Skip to content
Snippets Groups Projects
Commit 825d7a03 authored by Vojtech Moravec's avatar Vojtech Moravec
Browse files

Experimental rewrite of LloydMax algorithm.

- Calculate quantizaton values in range [min, max], where min and max correspond to the minimal and maximal values found in the data.

- Initialize centroids to the value with the highest pdf.
- Changed the terminal condition of the main loop, from distortion to the minimal MSE improvement.
parent 181b721d
No related branches found
No related tags found
No related merge requests found
......@@ -2,6 +2,7 @@ package azgracompress.quantization.scalar;
import azgracompress.U16;
import azgracompress.quantization.QTrainIteration;
import azgracompress.utilities.MinMaxResult;
import azgracompress.utilities.Stopwatch;
import azgracompress.utilities.Utils;
......@@ -12,6 +13,9 @@ public class LloydMaxU16ScalarQuantization {
private final int[] trainingData;
private int codebookSize;
private int dataMin;
private int dataMax;
private int dataSpan;
private int[] centroids;
private int[] boundaryPoints;
private double[] pdf;
......@@ -32,13 +36,18 @@ public class LloydMaxU16ScalarQuantization {
private void initialize() {
centroids = new int[codebookSize];
centroids[0] = 0;
boundaryPoints = new int[codebookSize + 1];
boundaryPoints[0] = U16.Min;
boundaryPoints[codebookSize] = U16.Max;
MinMaxResult<Integer> minMax = Utils.getMinAndMax(trainingData);
dataMin = minMax.getMin();
dataMax = minMax.getMax();
dataSpan = dataMax - dataMin;
centroids[0] = dataMin;
double intervalSize = (double) (U16.Max - U16.Min) / (double) codebookSize;
boundaryPoints[0] = dataMin;
boundaryPoints[codebookSize] = dataMax;
double intervalSize = (double) (dataSpan) / (double) codebookSize;
for (int i = 0; i < codebookSize; i++) {
centroids[i] = (int) Math.floor(((double) i + 0.5) * intervalSize);
}
......@@ -49,9 +58,11 @@ public class LloydMaxU16ScalarQuantization {
// Speedup - for now it is fast enough
Stopwatch s = new Stopwatch();
s.start();
for (int i = 0; i < trainingData.length; i++) {
pdf[trainingData[i]] += 1;
pdf[trainingData[i]] += 1.0;
}
s.stop();
if (verbose) {
System.out.println("Init_PDF: " + s.getElapsedTimeString());
......@@ -60,7 +71,24 @@ public class LloydMaxU16ScalarQuantization {
private void recalculateBoundaryPoints() {
for (int j = 1; j < codebookSize; j++) {
boundaryPoints[j] = (centroids[j] + centroids[j - 1]) / 2;
boundaryPoints[j] = Math.min(dataMax,
(int) Math.floor(((double) centroids[j] + (double) centroids[j - 1]) / 2.0));
}
}
private void initializeCentroids() {
int lowerBound, upperBound;
double[] centroidPdf = new double[codebookSize];
for (int centroidIndex = 0; centroidIndex < codebookSize; centroidIndex++) {
lowerBound = boundaryPoints[centroidIndex];
upperBound = boundaryPoints[centroidIndex + 1];
for (int rangeValue = lowerBound; rangeValue <= upperBound; rangeValue++) {
if (pdf[rangeValue] > centroidPdf[centroidIndex]) {
centroidPdf[centroidIndex] = pdf[rangeValue];
centroids[centroidIndex] = rangeValue;
}
}
}
}
......@@ -70,13 +98,13 @@ public class LloydMaxU16ScalarQuantization {
int lowerBound, upperBound;
for (int j = 0; j < codebookSize; j++) {
for (int centroidIndex = 0; centroidIndex < codebookSize; centroidIndex++) {
numerator = 0.0;
denominator = 0.0;
lowerBound = boundaryPoints[j];
upperBound = boundaryPoints[j + 1];
lowerBound = boundaryPoints[centroidIndex];
upperBound = boundaryPoints[centroidIndex + 1];
for (int n = lowerBound; n <= upperBound; n++) {
numerator += (double) n * pdf[n];
......@@ -84,7 +112,7 @@ public class LloydMaxU16ScalarQuantization {
}
if (denominator > 0) {
centroids[j] = (int) Math.floor(numerator / denominator);
centroids[centroidIndex] = (int) Math.floor(numerator / denominator);
}
}
}
......@@ -98,6 +126,16 @@ public class LloydMaxU16ScalarQuantization {
throw new RuntimeException("Value couldn't be quantized!");
}
private double calculateMAE() {
double mae = 0.0;
for (final int trainingDatum : trainingData) {
int quantizedValue = quantize(trainingDatum);
mae += Math.abs((double) trainingDatum - (double) quantizedValue);
}
return (mae / (double) trainingData.length);
}
private double getCurrentMse() {
double mse = 0.0;
......@@ -149,6 +187,8 @@ public class LloydMaxU16ScalarQuantization {
public QTrainIteration[] train(final boolean shouldBeVerbose) {
this.verbose = shouldBeVerbose;
final int RECALCULATE_N_TIMES = 10;
final int PATIENCE = 1;
int noImprovementCounter = 0;
if (verbose) {
System.out.println("Training data count: " + trainingData.length);
}
......@@ -156,6 +196,8 @@ public class LloydMaxU16ScalarQuantization {
initialize();
initializeProbabilityDensityFunction();
double currMAE = 1.0;
double prevMse = 1.0;
double currentMse = 1.0;
double psnr;
......@@ -163,7 +205,8 @@ public class LloydMaxU16ScalarQuantization {
ArrayList<QTrainIteration> solutionHistory = new ArrayList<>();
recalculateBoundaryPoints();
recalculateCentroids();
// recalculateCentroids();
initializeCentroids();
currentMse = getCurrentMse();
psnr = Utils.calculatePsnr(currentMse, U16.Max);
......@@ -174,27 +217,59 @@ public class LloydMaxU16ScalarQuantization {
solutionHistory.add(new QTrainIteration(0, currentMse, currentMse, psnr, psnr));
double dist = 1;
double mseImprovement = 1;
int iteration = 0;
do {
for (int i = 0; i < RECALCULATE_N_TIMES; i++) {
recalculateBoundaryPoints();
recalculateCentroids();
}
// TODO(Moravec): Check if we are improving MSE.
// Save the best centroids, the lowest MSE.
currMAE = calculateMAE();
prevMse = currentMse;
currentMse = getCurrentMse();
mseImprovement = prevMse - currentMse;
// System.out.println(String.format("Improvement: %.4f", mseImprovement));
// if ((prevMAE < currMAE) && (iteration != 0)) {
// System.err.println(String.format(
// "MAE = +%.5f",
// currMAE - prevMAE));
// }
psnr = Utils.calculatePsnr(currentMse, U16.Max);
solutionHistory.add(new QTrainIteration(++iteration, currentMse, currentMse, psnr, psnr));
dist = (prevMse - currentMse) / currentMse;
// dist = (prevMse - currentMse) / currentMse;
if (verbose) {
System.out.println(String.format("Current MSE: %.4f PSNR: %.4f dB", currentMse, psnr));
System.out.println(String.format("Current MAE: %.4f MSE: %.4f PSNR: %.4f dB",
currMAE,
currentMse,
psnr));
}
} while (dist > 0.001); //0.0005
// if (mseImprovement < 1.0 && mseImprovement > 0.0005) {
// System.out.println("----- low improvement " + mseImprovement);
// }
if (mseImprovement < 1.0) {
if ((++noImprovementCounter) >= PATIENCE) {
break;
}
}
} while (true); //0.001 //0.0005// || currMAE > 1500//mseImprovement > 0.0005
if (verbose) {
System.out.println("\nFinished training.");
}
System.out.println(String.format("Final MAE: %.4f after %d iterations", currMAE, iteration));
return solutionHistory.toArray(new QTrainIteration[0]);
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment