Skip to content
Snippets Groups Projects
Commit b88e2059 authored by Vojtech Moravec's avatar Vojtech Moravec
Browse files

Report MSE,PSNR in benchmark and save train file.

parent b5d7ab5a
No related branches found
No related tags found
No related merge requests found
......@@ -17,6 +17,8 @@ abstract class BenchmarkBase {
protected final static String QUANTIZED_FILE_TEMPLATE = "%d_cb%d.raw";
protected final static String DIFFERENCE_FILE_TEMPLATE = "%d_cb%d.data";
protected final static String ABSOLUTE_DIFFERENCE_FILE_TEMPLATE = "%d_cb%d_abs.data";
protected final static String TRAIN_FILE_TEMPLATE = "%d_cb%d_trainLog.csv";
protected final String inputFile;
protected final String outputDirectory;
......@@ -186,7 +188,7 @@ abstract class BenchmarkBase {
* @param trainingLog QTrainingLog
*/
protected void saveQTrainLog(final String filename, final QTrainIteration[] trainingLog) {
final String CSV_HEADER = "It;AvgMSE;BestMSE;AvgPSNR;BestPSNR\n";
final String CSV_HEADER = "It;MSE;PSNR\n";
try {
FileOutputStream fileStream = new FileOutputStream(getFileNamePathIntoOutDir(filename));
OutputStreamWriter writer = new OutputStreamWriter(fileStream);
......@@ -194,12 +196,10 @@ abstract class BenchmarkBase {
writer.write(CSV_HEADER);
for (final QTrainIteration it : trainingLog) {
writer.write(String.format("%d;%.5f;%.5f;%.5f;%.5f\n",
writer.write(String.format("%d;%.5f;%.5f\n",
it.getIteration(),
it.getAverageMSE(),
it.getBestMSE(),
it.getAveragePSNR(),
it.getBestPSNR()));
it.getMse(),
it.getPSNR()));
}
writer.flush();
fileStream.flush();
......
......@@ -7,6 +7,7 @@ import azgracompress.quantization.QTrainIteration;
import azgracompress.quantization.scalar.LloydMaxU16ScalarQuantization;
import azgracompress.quantization.scalar.SQCodebook;
import azgracompress.quantization.scalar.ScalarQuantizer;
import azgracompress.utilities.Utils;
import java.io.File;
import java.io.FileOutputStream;
......@@ -46,7 +47,7 @@ public class SQBenchmark extends BenchmarkBase {
System.err.println("Failed to load middle plane data.");
return;
}
quantizer = trainLloydMaxQuantizer(middlePlaneData, codebookSize);
quantizer = trainLloydMaxQuantizer(middlePlaneData, codebookSize, null);
System.out.println("Created quantizer from middle plane.");
}
......@@ -59,8 +60,15 @@ public class SQBenchmark extends BenchmarkBase {
return;
}
final String quantizedFile = String.format(QUANTIZED_FILE_TEMPLATE, planeIndex, codebookSize);
final String diffFile = String.format(DIFFERENCE_FILE_TEMPLATE, planeIndex, codebookSize);
final String absoluteDiffFile = String.format(ABSOLUTE_DIFFERENCE_FILE_TEMPLATE,
planeIndex,
codebookSize);
final String trainLogFile = String.format(TRAIN_FILE_TEMPLATE, planeIndex, codebookSize);
if (!hasGeneralQuantizer) {
quantizer = trainLloydMaxQuantizer(planeData, codebookSize);
quantizer = trainLloydMaxQuantizer(planeData, codebookSize, trainLogFile);
System.out.println("Created plane quantizer");
}
......@@ -69,15 +77,15 @@ public class SQBenchmark extends BenchmarkBase {
return;
}
// TODO(Moravec): Add huffman coding.
final int[] quantizedData = quantizer.quantize(planeData);
final String quantizedFile = String.format(QUANTIZED_FILE_TEMPLATE, planeIndex, codebookSize);
final String diffFile = String.format(DIFFERENCE_FILE_TEMPLATE, planeIndex, codebookSize);
final String absoluteDiffFile = String.format(ABSOLUTE_DIFFERENCE_FILE_TEMPLATE,
planeIndex,
codebookSize);
{
final int[] diffArray = Utils.getDifference(planeData, quantizedData);
final double mse = Utils.calculateMse(diffArray);
final double PSNR = Utils.calculatePsnr(mse, U16.Max);
System.out.println(String.format("MSE: %.4f\tPNSR: %.4f(dB)", mse, PSNR));
}
final int[] quantizedData = quantizer.quantize(planeData);
if (!saveQuantizedPlaneData(quantizedData, quantizedFile)) {
System.err.println("Failed to save quantized plane.");
......@@ -113,10 +121,15 @@ public class SQBenchmark extends BenchmarkBase {
}
}
private ScalarQuantizer trainLloydMaxQuantizer(final int[] data, final int codebookSize) {
private ScalarQuantizer trainLloydMaxQuantizer(final int[] data,
final int codebookSize,
final String trainLogFile) {
LloydMaxU16ScalarQuantization lloydMax = new LloydMaxU16ScalarQuantization(data, codebookSize, workerCount);
QTrainIteration[] trainingReport = lloydMax.train(false);
// saveQTrainLog(String.format("p%d_cb_%d_lloyd.csv", planeIndex, codebookSize), trainingReport);
if (trainLogFile != null) {
saveQTrainLog(trainLogFile, trainingReport);
System.out.println("Saved the train log file to: " + trainLogFile);
}
return new ScalarQuantizer(U16.Min, U16.Max, lloydMax.getCodebook());
}
}
......@@ -2,37 +2,25 @@ package azgracompress.quantization;
public class QTrainIteration {
private final int iteration;
private final double averageMSE;
private final double bestMSE;
private final double averagePSNR;
private final double bestPSNR;
private final double mse;
private final double PSNR;
public QTrainIteration(int iteration, double averageMSE, double bestMSE, double averagePSNR, double bestPSNR) {
public QTrainIteration(int iteration, double mse, double psnr) {
this.iteration = iteration;
this.averageMSE = averageMSE;
this.bestMSE = bestMSE;
this.averagePSNR = averagePSNR;
this.bestPSNR = bestPSNR;
this.mse = mse;
this.PSNR = psnr;
}
public int getIteration() {
return iteration;
}
public double getAverageMSE() {
return averageMSE;
public double getMse() {
return mse;
}
public double getBestMSE() {
return bestMSE;
}
public double getAveragePSNR() {
return averagePSNR;
}
public double getBestPSNR() {
return bestPSNR;
public double getPSNR() {
return PSNR;
}
}
......@@ -225,7 +225,7 @@ public class LloydMaxU16ScalarQuantization {
System.out.println(String.format("Initial MSE: %f", currentMse));
}
solutionHistory.add(new QTrainIteration(0, currentMse, currentMse, psnr, psnr));
solutionHistory.add(new QTrainIteration(0, currentMse, psnr));
double mseImprovement = 1;
int iteration = 0;
......@@ -240,7 +240,7 @@ public class LloydMaxU16ScalarQuantization {
mseImprovement = prevMse - currentMse;
psnr = Utils.calculatePsnr(currentMse, U16.Max);
solutionHistory.add(new QTrainIteration(++iteration, currentMse, currentMse, psnr, psnr));
solutionHistory.add(new QTrainIteration(++iteration, currentMse, psnr));
if (verbose) {
System.out.println(String.format("Current MSE: %.4f PSNR: %.4f dB",
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment