diff --git a/src/main/java/azgracompress/benchmark/BenchmarkBase.java b/src/main/java/azgracompress/benchmark/BenchmarkBase.java
index 18aad38e617da1e8fa0cd6474cf350977437d300..1d159e2360e85c366b1f6a07dd43f1376a1cb4ba 100644
--- a/src/main/java/azgracompress/benchmark/BenchmarkBase.java
+++ b/src/main/java/azgracompress/benchmark/BenchmarkBase.java
@@ -17,6 +17,8 @@ abstract class BenchmarkBase {
protected final static String QUANTIZED_FILE_TEMPLATE = "%d_cb%d.raw";
protected final static String DIFFERENCE_FILE_TEMPLATE = "%d_cb%d.data";
protected final static String ABSOLUTE_DIFFERENCE_FILE_TEMPLATE = "%d_cb%d_abs.data";
+ protected final static String TRAIN_FILE_TEMPLATE = "%d_cb%d_trainLog.csv";
+
protected final String inputFile;
protected final String outputDirectory;
@@ -186,7 +188,7 @@ abstract class BenchmarkBase {
* @param trainingLog QTrainingLog
*/
protected void saveQTrainLog(final String filename, final QTrainIteration[] trainingLog) {
- final String CSV_HEADER = "It;AvgMSE;BestMSE;AvgPSNR;BestPSNR\n";
+ final String CSV_HEADER = "It;MSE;PSNR\n";
try {
FileOutputStream fileStream = new FileOutputStream(getFileNamePathIntoOutDir(filename));
OutputStreamWriter writer = new OutputStreamWriter(fileStream);
@@ -194,12 +196,10 @@ abstract class BenchmarkBase {
writer.write(CSV_HEADER);
for (final QTrainIteration it : trainingLog) {
- writer.write(String.format("%d;%.5f;%.5f;%.5f;%.5f\n",
+ writer.write(String.format("%d;%.5f;%.5f\n",
it.getIteration(),
- it.getAverageMSE(),
- it.getBestMSE(),
- it.getAveragePSNR(),
- it.getBestPSNR()));
+ it.getMse(),
+ it.getPSNR()));
}
writer.flush();
fileStream.flush();
diff --git a/src/main/java/azgracompress/benchmark/SQBenchmark.java b/src/main/java/azgracompress/benchmark/SQBenchmark.java
index 4ec378a0310c67cbc6a6b48d389a3560c64bd433..50b9e0f373bf289626341187da40701abd8ba901 100644
--- a/src/main/java/azgracompress/benchmark/SQBenchmark.java
+++ b/src/main/java/azgracompress/benchmark/SQBenchmark.java
@@ -7,6 +7,7 @@ import azgracompress.quantization.QTrainIteration;
import azgracompress.quantization.scalar.LloydMaxU16ScalarQuantization;
import azgracompress.quantization.scalar.SQCodebook;
import azgracompress.quantization.scalar.ScalarQuantizer;
+import azgracompress.utilities.Utils;
import java.io.File;
import java.io.FileOutputStream;
@@ -46,7 +47,7 @@ public class SQBenchmark extends BenchmarkBase {
System.err.println("Failed to load middle plane data.");
return;
}
- quantizer = trainLloydMaxQuantizer(middlePlaneData, codebookSize);
+ quantizer = trainLloydMaxQuantizer(middlePlaneData, codebookSize, null);
System.out.println("Created quantizer from middle plane.");
}
@@ -59,8 +60,15 @@ public class SQBenchmark extends BenchmarkBase {
return;
}
+ final String quantizedFile = String.format(QUANTIZED_FILE_TEMPLATE, planeIndex, codebookSize);
+ final String diffFile = String.format(DIFFERENCE_FILE_TEMPLATE, planeIndex, codebookSize);
+ final String absoluteDiffFile = String.format(ABSOLUTE_DIFFERENCE_FILE_TEMPLATE,
+ planeIndex,
+ codebookSize);
+ final String trainLogFile = String.format(TRAIN_FILE_TEMPLATE, planeIndex, codebookSize);
+
if (!hasGeneralQuantizer) {
- quantizer = trainLloydMaxQuantizer(planeData, codebookSize);
+ quantizer = trainLloydMaxQuantizer(planeData, codebookSize, trainLogFile);
System.out.println("Created plane quantizer");
}
@@ -69,15 +77,15 @@ public class SQBenchmark extends BenchmarkBase {
return;
}
- // TODO(Moravec): Add huffman coding.
+ final int[] quantizedData = quantizer.quantize(planeData);
- final String quantizedFile = String.format(QUANTIZED_FILE_TEMPLATE, planeIndex, codebookSize);
- final String diffFile = String.format(DIFFERENCE_FILE_TEMPLATE, planeIndex, codebookSize);
- final String absoluteDiffFile = String.format(ABSOLUTE_DIFFERENCE_FILE_TEMPLATE,
- planeIndex,
- codebookSize);
+ {
+ final int[] diffArray = Utils.getDifference(planeData, quantizedData);
+ final double mse = Utils.calculateMse(diffArray);
+ final double PSNR = Utils.calculatePsnr(mse, U16.Max);
+ System.out.println(String.format("MSE: %.4f\tPNSR: %.4f(dB)", mse, PSNR));
+ }
- final int[] quantizedData = quantizer.quantize(planeData);
if (!saveQuantizedPlaneData(quantizedData, quantizedFile)) {
System.err.println("Failed to save quantized plane.");
@@ -113,10 +121,15 @@ public class SQBenchmark extends BenchmarkBase {
}
}
- private ScalarQuantizer trainLloydMaxQuantizer(final int[] data, final int codebookSize) {
+ private ScalarQuantizer trainLloydMaxQuantizer(final int[] data,
+ final int codebookSize,
+ final String trainLogFile) {
LloydMaxU16ScalarQuantization lloydMax = new LloydMaxU16ScalarQuantization(data, codebookSize, workerCount);
QTrainIteration[] trainingReport = lloydMax.train(false);
-// saveQTrainLog(String.format("p%d_cb_%d_lloyd.csv", planeIndex, codebookSize), trainingReport);
+ if (trainLogFile != null) {
+ saveQTrainLog(trainLogFile, trainingReport);
+ System.out.println("Saved the train log file to: " + trainLogFile);
+ }
return new ScalarQuantizer(U16.Min, U16.Max, lloydMax.getCodebook());
}
}
diff --git a/src/main/java/azgracompress/quantization/QTrainIteration.java b/src/main/java/azgracompress/quantization/QTrainIteration.java
index c2b00aca4bff3dda377479cb58b393abe7c188ff..0f6d44498f5017955cacdbcc20fefc812518dc5d 100644
--- a/src/main/java/azgracompress/quantization/QTrainIteration.java
+++ b/src/main/java/azgracompress/quantization/QTrainIteration.java
@@ -2,37 +2,25 @@ package azgracompress.quantization;
public class QTrainIteration {
private final int iteration;
- private final double averageMSE;
- private final double bestMSE;
- private final double averagePSNR;
- private final double bestPSNR;
+ private final double mse;
+ private final double PSNR;
- public QTrainIteration(int iteration, double averageMSE, double bestMSE, double averagePSNR, double bestPSNR) {
+ public QTrainIteration(int iteration, double mse, double psnr) {
this.iteration = iteration;
- this.averageMSE = averageMSE;
- this.bestMSE = bestMSE;
- this.averagePSNR = averagePSNR;
- this.bestPSNR = bestPSNR;
+ this.mse = mse;
+ this.PSNR = psnr;
}
public int getIteration() {
return iteration;
}
- public double getAverageMSE() {
- return averageMSE;
+ public double getMse() {
+ return mse;
}
- public double getBestMSE() {
- return bestMSE;
- }
-
- public double getAveragePSNR() {
- return averagePSNR;
- }
-
- public double getBestPSNR() {
- return bestPSNR;
+ public double getPSNR() {
+ return PSNR;
}
}
diff --git a/src/main/java/azgracompress/quantization/scalar/LloydMaxU16ScalarQuantization.java b/src/main/java/azgracompress/quantization/scalar/LloydMaxU16ScalarQuantization.java
index 14bf0a18b165bf334f44643d5fa263d878af9870..4e3f41dc51ed2b84c618c905bee893a82201f380 100644
--- a/src/main/java/azgracompress/quantization/scalar/LloydMaxU16ScalarQuantization.java
+++ b/src/main/java/azgracompress/quantization/scalar/LloydMaxU16ScalarQuantization.java
@@ -225,7 +225,7 @@ public class LloydMaxU16ScalarQuantization {
System.out.println(String.format("Initial MSE: %f", currentMse));
}
- solutionHistory.add(new QTrainIteration(0, currentMse, currentMse, psnr, psnr));
+ solutionHistory.add(new QTrainIteration(0, currentMse, psnr));
double mseImprovement = 1;
int iteration = 0;
@@ -240,7 +240,7 @@ public class LloydMaxU16ScalarQuantization {
mseImprovement = prevMse - currentMse;
psnr = Utils.calculatePsnr(currentMse, U16.Max);
- solutionHistory.add(new QTrainIteration(++iteration, currentMse, currentMse, psnr, psnr));
+ solutionHistory.add(new QTrainIteration(++iteration, currentMse, psnr));
if (verbose) {
System.out.println(String.format("Current MSE: %.4f PSNR: %.4f dB",