// Patch summary (unified git diff) for
// src/main/java/azgracompress/quantization/scalar/LloydMaxU16ScalarQuantization.java.
//
// What the hunks below change:
//  * Imports azgracompress.utilities.MinMaxResult and adds the int fields dataMin, dataMax, dataSpan.
//  * initialize(): reads min/max of trainingData through Utils.getMinAndMax and uses them for the
//    outermost boundary points and for intervalSize, replacing U16.Min / U16.Max.
//  * initializeProbabilityDensityFunction(): the histogram increment literal becomes 1.0
//    (pdf is declared as double[]).
//  * recalculateBoundaryPoints(): the midpoint of two neighbouring centroids is computed in double,
//    floored, and clamped to at most dataMax.
//  * New initializeCentroids(): for every interval [boundaryPoints[i], boundaryPoints[i + 1]] picks the
//    value with the largest pdf entry as the centroid; train() now calls it where the first
//    recalculateCentroids() call used to be (that call is left commented out).
//  * recalculateCentroids(): loop variable renamed from j to centroidIndex.
//  * New calculateMAE(): average of |value - quantize(value)| over trainingData.
//  * train(): the relative stop criterion `dist > 0.001` is replaced by `do { ... } while (true)` with a
//    break once (prevMse - currentMse) < 1.0 has been seen PATIENCE (= 1) times; the verbose line also
//    prints the MAE, and a "Final MAE" line is printed after the loop.
//
// NOTE(review): in initialize() the unchanged loop `centroids[i] = floor((i + 0.5) * intervalSize)` starts
//   at i = 0, so it overwrites the newly added `centroids[0] = dataMin`, and it does not add dataMin even
//   though intervalSize is now derived from dataSpan - looks like a missing offset; confirm.
// NOTE(review): initializeCentroids() and recalculateCentroids() iterate lowerBound..upperBound inclusive,
//   so a shared boundary value is visited by both neighbouring intervals - confirm this is intended.
// NOTE(review): noImprovementCounter is only ever incremented, never reset - TODO confirm for PATIENCE > 1.
// NOTE(review): the "Final MAE" println sits outside the `if (verbose)` guard - confirm it should always print.
diff --git a/src/main/java/azgracompress/quantization/scalar/LloydMaxU16ScalarQuantization.java b/src/main/java/azgracompress/quantization/scalar/LloydMaxU16ScalarQuantization.java index dc463c38a3f4d999cf272f8b680635ff9a4dd6cd..681823e09d7334190814598b9bb2e485434d6287 100644 --- a/src/main/java/azgracompress/quantization/scalar/LloydMaxU16ScalarQuantization.java +++ b/src/main/java/azgracompress/quantization/scalar/LloydMaxU16ScalarQuantization.java @@ -2,6 +2,7 @@ package azgracompress.quantization.scalar; import azgracompress.U16; import azgracompress.quantization.QTrainIteration; +import azgracompress.utilities.MinMaxResult; import azgracompress.utilities.Stopwatch; import azgracompress.utilities.Utils; @@ -12,6 +13,9 @@ public class LloydMaxU16ScalarQuantization { private final int[] trainingData; private int codebookSize; + private int dataMin; + private int dataMax; + private int dataSpan; private int[] centroids; private int[] boundaryPoints; private double[] pdf; @@ -32,13 +36,18 @@ public class LloydMaxU16ScalarQuantization { private void initialize() { centroids = new int[codebookSize]; - centroids[0] = 0; + boundaryPoints = new int[codebookSize + 1]; - boundaryPoints[0] = U16.Min; - boundaryPoints[codebookSize] = U16.Max; + MinMaxResult<Integer> minMax = Utils.getMinAndMax(trainingData); + dataMin = minMax.getMin(); + dataMax = minMax.getMax(); + dataSpan = dataMax - dataMin; + centroids[0] = dataMin; - double intervalSize = (double) (U16.Max - U16.Min) / (double) codebookSize; + boundaryPoints[0] = dataMin; + boundaryPoints[codebookSize] = dataMax; + double intervalSize = (double) (dataSpan) / (double) codebookSize; for (int i = 0; i < codebookSize; i++) { centroids[i] = (int) Math.floor(((double) i + 0.5) * intervalSize); } @@ -49,9 +58,11 @@ public class LloydMaxU16ScalarQuantization { // Speedup - for now it is fast enough Stopwatch s = new Stopwatch(); s.start(); + for (int i = 0; i < trainingData.length; i++) { - pdf[trainingData[i]] += 1; + 
pdf[trainingData[i]] += 1.0; } + s.stop(); if (verbose) { System.out.println("Init_PDF: " + s.getElapsedTimeString()); @@ -60,7 +71,24 @@ public class LloydMaxU16ScalarQuantization { private void recalculateBoundaryPoints() { for (int j = 1; j < codebookSize; j++) { - boundaryPoints[j] = (centroids[j] + centroids[j - 1]) / 2; + boundaryPoints[j] = Math.min(dataMax, + (int) Math.floor(((double) centroids[j] + (double) centroids[j - 1]) / 2.0)); + } + } + + private void initializeCentroids() { + int lowerBound, upperBound; + double[] centroidPdf = new double[codebookSize]; + for (int centroidIndex = 0; centroidIndex < codebookSize; centroidIndex++) { + lowerBound = boundaryPoints[centroidIndex]; + upperBound = boundaryPoints[centroidIndex + 1]; + + for (int rangeValue = lowerBound; rangeValue <= upperBound; rangeValue++) { + if (pdf[rangeValue] > centroidPdf[centroidIndex]) { + centroidPdf[centroidIndex] = pdf[rangeValue]; + centroids[centroidIndex] = rangeValue; + } + } } } @@ -70,13 +98,13 @@ public class LloydMaxU16ScalarQuantization { int lowerBound, upperBound; - for (int j = 0; j < codebookSize; j++) { + for (int centroidIndex = 0; centroidIndex < codebookSize; centroidIndex++) { numerator = 0.0; denominator = 0.0; - lowerBound = boundaryPoints[j]; - upperBound = boundaryPoints[j + 1]; + lowerBound = boundaryPoints[centroidIndex]; + upperBound = boundaryPoints[centroidIndex + 1]; for (int n = lowerBound; n <= upperBound; n++) { numerator += (double) n * pdf[n]; @@ -84,7 +112,7 @@ public class LloydMaxU16ScalarQuantization { } if (denominator > 0) { - centroids[j] = (int) Math.floor(numerator / denominator); + centroids[centroidIndex] = (int) Math.floor(numerator / denominator); } } } @@ -98,6 +126,16 @@ public class LloydMaxU16ScalarQuantization { throw new RuntimeException("Value couldn't be quantized!"); } + private double calculateMAE() { + double mae = 0.0; + for (final int trainingDatum : trainingData) { + int quantizedValue = quantize(trainingDatum); + 
mae += Math.abs((double) trainingDatum - (double) quantizedValue); + } + return (mae / (double) trainingData.length); + } + + private double getCurrentMse() { double mse = 0.0; @@ -149,6 +187,8 @@ public class LloydMaxU16ScalarQuantization { public QTrainIteration[] train(final boolean shouldBeVerbose) { this.verbose = shouldBeVerbose; final int RECALCULATE_N_TIMES = 10; + final int PATIENCE = 1; + int noImprovementCounter = 0; if (verbose) { System.out.println("Training data count: " + trainingData.length); } @@ -156,6 +196,8 @@ public class LloydMaxU16ScalarQuantization { initialize(); initializeProbabilityDensityFunction(); + + double currMAE = 1.0; double prevMse = 1.0; double currentMse = 1.0; double psnr; @@ -163,7 +205,8 @@ public class LloydMaxU16ScalarQuantization { ArrayList<QTrainIteration> solutionHistory = new ArrayList<>(); recalculateBoundaryPoints(); - recalculateCentroids(); + // recalculateCentroids(); + initializeCentroids(); currentMse = getCurrentMse(); psnr = Utils.calculatePsnr(currentMse, U16.Max); @@ -174,27 +217,59 @@ public class LloydMaxU16ScalarQuantization { solutionHistory.add(new QTrainIteration(0, currentMse, currentMse, psnr, psnr)); - double dist = 1; + double mseImprovement = 1; int iteration = 0; do { for (int i = 0; i < RECALCULATE_N_TIMES; i++) { recalculateBoundaryPoints(); recalculateCentroids(); } + + // TODO(Moravec): Check if we are improving MSE. + // Save the best centroids, the lowest MSE. 
+ + currMAE = calculateMAE(); + prevMse = currentMse; currentMse = getCurrentMse(); + mseImprovement = prevMse - currentMse; + + // System.out.println(String.format("Improvement: %.4f", mseImprovement)); + + // if ((prevMAE < currMAE) && (iteration != 0)) { + // System.err.println(String.format( + // "MAE = +%.5f", + // currMAE - prevMAE)); + // } + psnr = Utils.calculatePsnr(currentMse, U16.Max); solutionHistory.add(new QTrainIteration(++iteration, currentMse, currentMse, psnr, psnr)); - dist = (prevMse - currentMse) / currentMse; + // dist = (prevMse - currentMse) / currentMse; if (verbose) { - System.out.println(String.format("Current MSE: %.4f PSNR: %.4f dB", currentMse, psnr)); + System.out.println(String.format("Current MAE: %.4f MSE: %.4f PSNR: %.4f dB", + currMAE, + currentMse, + psnr)); } - } while (dist > 0.001); //0.0005 + // if (mseImprovement < 1.0 && mseImprovement > 0.0005) { + // System.out.println("----- low improvement " + mseImprovement); + // } + + if (mseImprovement < 1.0) { + if ((++noImprovementCounter) >= PATIENCE) { + break; + } + + } + + + } while (true); //0.001 //0.0005// || currMAE > 1500//mseImprovement > 0.0005 if (verbose) { System.out.println("\nFinished training."); } + System.out.println(String.format("Final MAE: %.4f after %d iterations", currMAE, iteration)); return solutionHistory.toArray(new QTrainIteration[0]); }