Skip to content
Snippets Groups Projects
Commit b88e2059 authored by Vojtech Moravec's avatar Vojtech Moravec
Browse files

Report MSE,PSNR in benchmark and save train file.

parent b5d7ab5a
Branches
No related tags found
No related merge requests found
...@@ -17,6 +17,8 @@ abstract class BenchmarkBase { ...@@ -17,6 +17,8 @@ abstract class BenchmarkBase {
protected final static String QUANTIZED_FILE_TEMPLATE = "%d_cb%d.raw"; protected final static String QUANTIZED_FILE_TEMPLATE = "%d_cb%d.raw";
protected final static String DIFFERENCE_FILE_TEMPLATE = "%d_cb%d.data"; protected final static String DIFFERENCE_FILE_TEMPLATE = "%d_cb%d.data";
protected final static String ABSOLUTE_DIFFERENCE_FILE_TEMPLATE = "%d_cb%d_abs.data"; protected final static String ABSOLUTE_DIFFERENCE_FILE_TEMPLATE = "%d_cb%d_abs.data";
protected final static String TRAIN_FILE_TEMPLATE = "%d_cb%d_trainLog.csv";
protected final String inputFile; protected final String inputFile;
protected final String outputDirectory; protected final String outputDirectory;
...@@ -186,7 +188,7 @@ abstract class BenchmarkBase { ...@@ -186,7 +188,7 @@ abstract class BenchmarkBase {
* @param trainingLog QTrainingLog * @param trainingLog QTrainingLog
*/ */
protected void saveQTrainLog(final String filename, final QTrainIteration[] trainingLog) { protected void saveQTrainLog(final String filename, final QTrainIteration[] trainingLog) {
final String CSV_HEADER = "It;AvgMSE;BestMSE;AvgPSNR;BestPSNR\n"; final String CSV_HEADER = "It;MSE;PSNR\n";
try { try {
FileOutputStream fileStream = new FileOutputStream(getFileNamePathIntoOutDir(filename)); FileOutputStream fileStream = new FileOutputStream(getFileNamePathIntoOutDir(filename));
OutputStreamWriter writer = new OutputStreamWriter(fileStream); OutputStreamWriter writer = new OutputStreamWriter(fileStream);
...@@ -194,12 +196,10 @@ abstract class BenchmarkBase { ...@@ -194,12 +196,10 @@ abstract class BenchmarkBase {
writer.write(CSV_HEADER); writer.write(CSV_HEADER);
for (final QTrainIteration it : trainingLog) { for (final QTrainIteration it : trainingLog) {
writer.write(String.format("%d;%.5f;%.5f;%.5f;%.5f\n", writer.write(String.format("%d;%.5f;%.5f\n",
it.getIteration(), it.getIteration(),
it.getAverageMSE(), it.getMse(),
it.getBestMSE(), it.getPSNR()));
it.getAveragePSNR(),
it.getBestPSNR()));
} }
writer.flush(); writer.flush();
fileStream.flush(); fileStream.flush();
......
...@@ -7,6 +7,7 @@ import azgracompress.quantization.QTrainIteration; ...@@ -7,6 +7,7 @@ import azgracompress.quantization.QTrainIteration;
import azgracompress.quantization.scalar.LloydMaxU16ScalarQuantization; import azgracompress.quantization.scalar.LloydMaxU16ScalarQuantization;
import azgracompress.quantization.scalar.SQCodebook; import azgracompress.quantization.scalar.SQCodebook;
import azgracompress.quantization.scalar.ScalarQuantizer; import azgracompress.quantization.scalar.ScalarQuantizer;
import azgracompress.utilities.Utils;
import java.io.File; import java.io.File;
import java.io.FileOutputStream; import java.io.FileOutputStream;
...@@ -46,7 +47,7 @@ public class SQBenchmark extends BenchmarkBase { ...@@ -46,7 +47,7 @@ public class SQBenchmark extends BenchmarkBase {
System.err.println("Failed to load middle plane data."); System.err.println("Failed to load middle plane data.");
return; return;
} }
quantizer = trainLloydMaxQuantizer(middlePlaneData, codebookSize); quantizer = trainLloydMaxQuantizer(middlePlaneData, codebookSize, null);
System.out.println("Created quantizer from middle plane."); System.out.println("Created quantizer from middle plane.");
} }
...@@ -59,8 +60,15 @@ public class SQBenchmark extends BenchmarkBase { ...@@ -59,8 +60,15 @@ public class SQBenchmark extends BenchmarkBase {
return; return;
} }
final String quantizedFile = String.format(QUANTIZED_FILE_TEMPLATE, planeIndex, codebookSize);
final String diffFile = String.format(DIFFERENCE_FILE_TEMPLATE, planeIndex, codebookSize);
final String absoluteDiffFile = String.format(ABSOLUTE_DIFFERENCE_FILE_TEMPLATE,
planeIndex,
codebookSize);
final String trainLogFile = String.format(TRAIN_FILE_TEMPLATE, planeIndex, codebookSize);
if (!hasGeneralQuantizer) { if (!hasGeneralQuantizer) {
quantizer = trainLloydMaxQuantizer(planeData, codebookSize); quantizer = trainLloydMaxQuantizer(planeData, codebookSize, trainLogFile);
System.out.println("Created plane quantizer"); System.out.println("Created plane quantizer");
} }
...@@ -69,15 +77,15 @@ public class SQBenchmark extends BenchmarkBase { ...@@ -69,15 +77,15 @@ public class SQBenchmark extends BenchmarkBase {
return; return;
} }
// TODO(Moravec): Add huffman coding. final int[] quantizedData = quantizer.quantize(planeData);
final String quantizedFile = String.format(QUANTIZED_FILE_TEMPLATE, planeIndex, codebookSize); {
final String diffFile = String.format(DIFFERENCE_FILE_TEMPLATE, planeIndex, codebookSize); final int[] diffArray = Utils.getDifference(planeData, quantizedData);
final String absoluteDiffFile = String.format(ABSOLUTE_DIFFERENCE_FILE_TEMPLATE, final double mse = Utils.calculateMse(diffArray);
planeIndex, final double PSNR = Utils.calculatePsnr(mse, U16.Max);
codebookSize); System.out.println(String.format("MSE: %.4f\tPNSR: %.4f(dB)", mse, PSNR));
}
final int[] quantizedData = quantizer.quantize(planeData);
if (!saveQuantizedPlaneData(quantizedData, quantizedFile)) { if (!saveQuantizedPlaneData(quantizedData, quantizedFile)) {
System.err.println("Failed to save quantized plane."); System.err.println("Failed to save quantized plane.");
...@@ -113,10 +121,15 @@ public class SQBenchmark extends BenchmarkBase { ...@@ -113,10 +121,15 @@ public class SQBenchmark extends BenchmarkBase {
} }
} }
private ScalarQuantizer trainLloydMaxQuantizer(final int[] data, final int codebookSize) { private ScalarQuantizer trainLloydMaxQuantizer(final int[] data,
final int codebookSize,
final String trainLogFile) {
LloydMaxU16ScalarQuantization lloydMax = new LloydMaxU16ScalarQuantization(data, codebookSize, workerCount); LloydMaxU16ScalarQuantization lloydMax = new LloydMaxU16ScalarQuantization(data, codebookSize, workerCount);
QTrainIteration[] trainingReport = lloydMax.train(false); QTrainIteration[] trainingReport = lloydMax.train(false);
// saveQTrainLog(String.format("p%d_cb_%d_lloyd.csv", planeIndex, codebookSize), trainingReport); if (trainLogFile != null) {
saveQTrainLog(trainLogFile, trainingReport);
System.out.println("Saved the train log file to: " + trainLogFile);
}
return new ScalarQuantizer(U16.Min, U16.Max, lloydMax.getCodebook()); return new ScalarQuantizer(U16.Min, U16.Max, lloydMax.getCodebook());
} }
} }
...@@ -2,37 +2,25 @@ package azgracompress.quantization; ...@@ -2,37 +2,25 @@ package azgracompress.quantization;
public class QTrainIteration { public class QTrainIteration {
private final int iteration; private final int iteration;
private final double averageMSE; private final double mse;
private final double bestMSE; private final double PSNR;
private final double averagePSNR;
private final double bestPSNR;
public QTrainIteration(int iteration, double averageMSE, double bestMSE, double averagePSNR, double bestPSNR) { public QTrainIteration(int iteration, double mse, double psnr) {
this.iteration = iteration; this.iteration = iteration;
this.averageMSE = averageMSE; this.mse = mse;
this.bestMSE = bestMSE; this.PSNR = psnr;
this.averagePSNR = averagePSNR;
this.bestPSNR = bestPSNR;
} }
public int getIteration() { public int getIteration() {
return iteration; return iteration;
} }
public double getAverageMSE() { public double getMse() {
return averageMSE; return mse;
} }
public double getBestMSE() { public double getPSNR() {
return bestMSE; return PSNR;
}
public double getAveragePSNR() {
return averagePSNR;
}
public double getBestPSNR() {
return bestPSNR;
} }
} }
...@@ -225,7 +225,7 @@ public class LloydMaxU16ScalarQuantization { ...@@ -225,7 +225,7 @@ public class LloydMaxU16ScalarQuantization {
System.out.println(String.format("Initial MSE: %f", currentMse)); System.out.println(String.format("Initial MSE: %f", currentMse));
} }
solutionHistory.add(new QTrainIteration(0, currentMse, currentMse, psnr, psnr)); solutionHistory.add(new QTrainIteration(0, currentMse, psnr));
double mseImprovement = 1; double mseImprovement = 1;
int iteration = 0; int iteration = 0;
...@@ -240,7 +240,7 @@ public class LloydMaxU16ScalarQuantization { ...@@ -240,7 +240,7 @@ public class LloydMaxU16ScalarQuantization {
mseImprovement = prevMse - currentMse; mseImprovement = prevMse - currentMse;
psnr = Utils.calculatePsnr(currentMse, U16.Max); psnr = Utils.calculatePsnr(currentMse, U16.Max);
solutionHistory.add(new QTrainIteration(++iteration, currentMse, currentMse, psnr, psnr)); solutionHistory.add(new QTrainIteration(++iteration, currentMse, psnr));
if (verbose) { if (verbose) {
System.out.println(String.format("Current MSE: %.4f PSNR: %.4f dB", System.out.println(String.format("Current MSE: %.4f PSNR: %.4f dB",
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment