diff --git a/src/main/java/azgracompress/benchmark/BenchmarkBase.java b/src/main/java/azgracompress/benchmark/BenchmarkBase.java index 18aad38e617da1e8fa0cd6474cf350977437d300..1d159e2360e85c366b1f6a07dd43f1376a1cb4ba 100644 --- a/src/main/java/azgracompress/benchmark/BenchmarkBase.java +++ b/src/main/java/azgracompress/benchmark/BenchmarkBase.java @@ -17,6 +17,8 @@ abstract class BenchmarkBase { protected final static String QUANTIZED_FILE_TEMPLATE = "%d_cb%d.raw"; protected final static String DIFFERENCE_FILE_TEMPLATE = "%d_cb%d.data"; protected final static String ABSOLUTE_DIFFERENCE_FILE_TEMPLATE = "%d_cb%d_abs.data"; + protected final static String TRAIN_FILE_TEMPLATE = "%d_cb%d_trainLog.csv"; + protected final String inputFile; protected final String outputDirectory; @@ -186,7 +188,7 @@ abstract class BenchmarkBase { * @param trainingLog QTrainingLog */ protected void saveQTrainLog(final String filename, final QTrainIteration[] trainingLog) { - final String CSV_HEADER = "It;AvgMSE;BestMSE;AvgPSNR;BestPSNR\n"; + final String CSV_HEADER = "It;MSE;PSNR\n"; try { FileOutputStream fileStream = new FileOutputStream(getFileNamePathIntoOutDir(filename)); OutputStreamWriter writer = new OutputStreamWriter(fileStream); @@ -194,12 +196,10 @@ abstract class BenchmarkBase { writer.write(CSV_HEADER); for (final QTrainIteration it : trainingLog) { - writer.write(String.format("%d;%.5f;%.5f;%.5f;%.5f\n", + writer.write(String.format("%d;%.5f;%.5f\n", it.getIteration(), - it.getAverageMSE(), - it.getBestMSE(), - it.getAveragePSNR(), - it.getBestPSNR())); + it.getMse(), + it.getPSNR())); } writer.flush(); fileStream.flush(); diff --git a/src/main/java/azgracompress/benchmark/SQBenchmark.java b/src/main/java/azgracompress/benchmark/SQBenchmark.java index 4ec378a0310c67cbc6a6b48d389a3560c64bd433..50b9e0f373bf289626341187da40701abd8ba901 100644 --- a/src/main/java/azgracompress/benchmark/SQBenchmark.java +++ b/src/main/java/azgracompress/benchmark/SQBenchmark.java @@ -7,6 +7,7 @@ import azgracompress.quantization.QTrainIteration; import azgracompress.quantization.scalar.LloydMaxU16ScalarQuantization; import azgracompress.quantization.scalar.SQCodebook; import azgracompress.quantization.scalar.ScalarQuantizer; +import azgracompress.utilities.Utils; import java.io.File; import java.io.FileOutputStream; @@ -46,7 +47,7 @@ public class SQBenchmark extends BenchmarkBase { System.err.println("Failed to load middle plane data."); return; } - quantizer = trainLloydMaxQuantizer(middlePlaneData, codebookSize); + quantizer = trainLloydMaxQuantizer(middlePlaneData, codebookSize, null); System.out.println("Created quantizer from middle plane."); } @@ -59,8 +60,15 @@ public class SQBenchmark extends BenchmarkBase { return; } + final String quantizedFile = String.format(QUANTIZED_FILE_TEMPLATE, planeIndex, codebookSize); + final String diffFile = String.format(DIFFERENCE_FILE_TEMPLATE, planeIndex, codebookSize); + final String absoluteDiffFile = String.format(ABSOLUTE_DIFFERENCE_FILE_TEMPLATE, + planeIndex, + codebookSize); + final String trainLogFile = String.format(TRAIN_FILE_TEMPLATE, planeIndex, codebookSize); + if (!hasGeneralQuantizer) { - quantizer = trainLloydMaxQuantizer(planeData, codebookSize); + quantizer = trainLloydMaxQuantizer(planeData, codebookSize, trainLogFile); System.out.println("Created plane quantizer"); } @@ -69,15 +77,15 @@ public class SQBenchmark extends BenchmarkBase { return; } - // TODO(Moravec): Add huffman coding. + final int[] quantizedData = quantizer.quantize(planeData); - final String quantizedFile = String.format(QUANTIZED_FILE_TEMPLATE, planeIndex, codebookSize); - final String diffFile = String.format(DIFFERENCE_FILE_TEMPLATE, planeIndex, codebookSize); - final String absoluteDiffFile = String.format(ABSOLUTE_DIFFERENCE_FILE_TEMPLATE, - planeIndex, - codebookSize); + { + final int[] diffArray = Utils.getDifference(planeData, quantizedData); + final double mse = Utils.calculateMse(diffArray); + final double PSNR = Utils.calculatePsnr(mse, U16.Max); + System.out.println(String.format("MSE: %.4f\tPNSR: %.4f(dB)", mse, PSNR)); + } - final int[] quantizedData = quantizer.quantize(planeData); if (!saveQuantizedPlaneData(quantizedData, quantizedFile)) { System.err.println("Failed to save quantized plane."); @@ -113,10 +121,15 @@ public class SQBenchmark extends BenchmarkBase { } } - private ScalarQuantizer trainLloydMaxQuantizer(final int[] data, final int codebookSize) { + private ScalarQuantizer trainLloydMaxQuantizer(final int[] data, + final int codebookSize, + final String trainLogFile) { LloydMaxU16ScalarQuantization lloydMax = new LloydMaxU16ScalarQuantization(data, codebookSize, workerCount); QTrainIteration[] trainingReport = lloydMax.train(false); -// saveQTrainLog(String.format("p%d_cb_%d_lloyd.csv", planeIndex, codebookSize), trainingReport); + if (trainLogFile != null) { + saveQTrainLog(trainLogFile, trainingReport); + System.out.println("Saved the train log file to: " + trainLogFile); + } return new ScalarQuantizer(U16.Min, U16.Max, lloydMax.getCodebook()); } } diff --git a/src/main/java/azgracompress/quantization/QTrainIteration.java b/src/main/java/azgracompress/quantization/QTrainIteration.java index c2b00aca4bff3dda377479cb58b393abe7c188ff..0f6d44498f5017955cacdbcc20fefc812518dc5d 100644 --- a/src/main/java/azgracompress/quantization/QTrainIteration.java +++ b/src/main/java/azgracompress/quantization/QTrainIteration.java @@ -2,37 +2,25 @@ package azgracompress.quantization; public class QTrainIteration { private final int iteration; - private final double averageMSE; - private final double bestMSE; - private final double averagePSNR; - private final double bestPSNR; + private final double mse; + private final double PSNR; - public QTrainIteration(int iteration, double averageMSE, double bestMSE, double averagePSNR, double bestPSNR) { + public QTrainIteration(int iteration, double mse, double psnr) { this.iteration = iteration; - this.averageMSE = averageMSE; - this.bestMSE = bestMSE; - this.averagePSNR = averagePSNR; - this.bestPSNR = bestPSNR; + this.mse = mse; + this.PSNR = psnr; } public int getIteration() { return iteration; } - public double getAverageMSE() { - return averageMSE; + public double getMse() { + return mse; } - public double getBestMSE() { - return bestMSE; - } - - public double getAveragePSNR() { - return averagePSNR; - } - - public double getBestPSNR() { - return bestPSNR; + public double getPSNR() { + return PSNR; } } diff --git a/src/main/java/azgracompress/quantization/scalar/LloydMaxU16ScalarQuantization.java b/src/main/java/azgracompress/quantization/scalar/LloydMaxU16ScalarQuantization.java index 14bf0a18b165bf334f44643d5fa263d878af9870..4e3f41dc51ed2b84c618c905bee893a82201f380 100644 --- a/src/main/java/azgracompress/quantization/scalar/LloydMaxU16ScalarQuantization.java +++ b/src/main/java/azgracompress/quantization/scalar/LloydMaxU16ScalarQuantization.java @@ -225,7 +225,7 @@ public class LloydMaxU16ScalarQuantization { System.out.println(String.format("Initial MSE: %f", currentMse)); } - solutionHistory.add(new QTrainIteration(0, currentMse, currentMse, psnr, psnr)); + solutionHistory.add(new QTrainIteration(0, currentMse, psnr)); double mseImprovement = 1; int iteration = 0; @@ -240,7 +240,7 @@ public class LloydMaxU16ScalarQuantization { mseImprovement = prevMse - currentMse; psnr = Utils.calculatePsnr(currentMse, U16.Max); - solutionHistory.add(new QTrainIteration(++iteration, currentMse, currentMse, psnr, psnr)); + solutionHistory.add(new QTrainIteration(++iteration, currentMse, psnr)); if (verbose) { System.out.println(String.format("Current MSE: %.4f PSNR: %.4f dB",