Skip to content
Snippets Groups Projects
Commit ba983b54 authored by Vojtech Moravec's avatar Vojtech Moravec
Browse files

Replace `verbose` logic with optional listener in LloydMax class.

parent d688404f
No related branches found
No related tags found
No related merge requests found
...@@ -147,7 +147,7 @@ public class SQBenchmark extends BenchmarkBase { ...@@ -147,7 +147,7 @@ public class SQBenchmark extends BenchmarkBase {
final int codebookSize, final int codebookSize,
final String trainLogFile) { final String trainLogFile) {
LloydMaxU16ScalarQuantization lloydMax = new LloydMaxU16ScalarQuantization(data, codebookSize, workerCount); LloydMaxU16ScalarQuantization lloydMax = new LloydMaxU16ScalarQuantization(data, codebookSize, workerCount);
QTrainIteration[] trainingReport = lloydMax.train(false); QTrainIteration[] trainingReport = lloydMax.train();
if (trainLogFile != null) { if (trainLogFile != null) {
saveQTrainLog(trainLogFile, trainingReport); saveQTrainLog(trainLogFile, trainingReport);
System.out.println("Saved the train log file to: " + trainLogFile); System.out.println("Saved the train log file to: " + trainLogFile);
......
...@@ -31,9 +31,9 @@ public class SQImageCompressor extends CompressorDecompressorBase implements IIm ...@@ -31,9 +31,9 @@ public class SQImageCompressor extends CompressorDecompressorBase implements IIm
private ScalarQuantizer trainScalarQuantizerFromData(final int[] planeData) { private ScalarQuantizer trainScalarQuantizerFromData(final int[] planeData) {
LloydMaxU16ScalarQuantization lloydMax = new LloydMaxU16ScalarQuantization(planeData, LloydMaxU16ScalarQuantization lloydMax = new LloydMaxU16ScalarQuantization(planeData,
getCodebookSize(), getCodebookSize(),
options.getWorkerCount()); options.getWorkerCount());
lloydMax.train(false); lloydMax.train();
return new ScalarQuantizer(U16.Min, U16.Max, lloydMax.getCodebook()); return new ScalarQuantizer(U16.Min, U16.Max, lloydMax.getCodebook());
} }
...@@ -77,7 +77,8 @@ public class SQImageCompressor extends CompressorDecompressorBase implements IIm ...@@ -77,7 +77,8 @@ public class SQImageCompressor extends CompressorDecompressorBase implements IIm
trainAndSaveCodebook(); trainAndSaveCodebook();
} }
final SQCodebook codebook = cacheManager.loadSQCodebook(options.getInputDataInfo().getCacheFileName(), getCodebookSize()); final SQCodebook codebook = cacheManager.loadSQCodebook(options.getInputDataInfo().getCacheFileName(),
getCodebookSize());
if (codebook == null) { if (codebook == null) {
throw new ImageCompressionException("Failed to read quantization values from cache file."); throw new ImageCompressionException("Failed to read quantization values from cache file.");
} }
...@@ -166,7 +167,7 @@ public class SQImageCompressor extends CompressorDecompressorBase implements IIm ...@@ -166,7 +167,7 @@ public class SQImageCompressor extends CompressorDecompressorBase implements IIm
stopwatch.stop(); stopwatch.stop();
reportProgressToListeners(planeIndex, planeIndices.length, reportProgressToListeners(planeIndex, planeIndices.length,
"Compressed plane %d in %s.", planeIndex, stopwatch.getElapsedTimeString()); "Compressed plane %d in %s.", planeIndex, stopwatch.getElapsedTimeString());
} }
return planeDataSizes; return planeDataSizes;
} }
...@@ -212,11 +213,14 @@ public class SQImageCompressor extends CompressorDecompressorBase implements IIm ...@@ -212,11 +213,14 @@ public class SQImageCompressor extends CompressorDecompressorBase implements IIm
public void trainAndSaveCodebook() throws ImageCompressionException { public void trainAndSaveCodebook() throws ImageCompressionException {
int[] trainData = loadConfiguredPlanesData(); int[] trainData = loadConfiguredPlanesData();
LloydMaxU16ScalarQuantization lloydMax = new LloydMaxU16ScalarQuantization(trainData, LloydMaxU16ScalarQuantization lloydMax = new LloydMaxU16ScalarQuantization(trainData,
getCodebookSize(), getCodebookSize(),
options.getWorkerCount()); options.getWorkerCount());
reportStatusToListeners("Starting LloydMax training."); reportStatusToListeners("Starting LloydMax training.");
lloydMax.train(options.isVerbose());
lloydMax.setStatusListener(this::reportStatusToListeners);
lloydMax.train();
final SQCodebook codebook = lloydMax.getCodebook(); final SQCodebook codebook = lloydMax.getCodebook();
reportStatusToListeners("Finished LloydMax training."); reportStatusToListeners("Finished LloydMax training.");
......
package azgracompress.quantization.scalar; package azgracompress.quantization.scalar;
import azgracompress.U16; import azgracompress.U16;
import azgracompress.compression.listeners.IStatusListener;
import azgracompress.quantization.QTrainIteration; import azgracompress.quantization.QTrainIteration;
import azgracompress.utilities.MinMaxResult; import azgracompress.utilities.MinMaxResult;
import azgracompress.utilities.Stopwatch; import azgracompress.utilities.Stopwatch;
...@@ -25,7 +26,7 @@ public class LloydMaxU16ScalarQuantization { ...@@ -25,7 +26,7 @@ public class LloydMaxU16ScalarQuantization {
private final int workerCount; private final int workerCount;
private boolean verbose = false; private IStatusListener statusListener = null;
public LloydMaxU16ScalarQuantization(final int[] trainData, final int codebookSize, final int workerCount) { public LloydMaxU16ScalarQuantization(final int[] trainData, final int codebookSize, final int workerCount) {
trainingData = trainData; trainingData = trainData;
...@@ -37,6 +38,10 @@ public class LloydMaxU16ScalarQuantization { ...@@ -37,6 +38,10 @@ public class LloydMaxU16ScalarQuantization {
this(trainData, codebookSize, 1); this(trainData, codebookSize, 1);
} }
public void setStatusListener(final IStatusListener listener) {
this.statusListener = listener;
}
private void initialize() { private void initialize() {
frequencies = new long[codebookSize]; frequencies = new long[codebookSize];
centroids = new int[codebookSize]; centroids = new int[codebookSize];
...@@ -57,20 +62,20 @@ public class LloydMaxU16ScalarQuantization { ...@@ -57,20 +62,20 @@ public class LloydMaxU16ScalarQuantization {
} }
} }
private void reportStatus(final String message) {
if (statusListener != null)
statusListener.sendMessage(message);
}
private void reportStatus(final String format, final Object... arg) {
reportStatus(String.format(format, arg));
}
private void initializeProbabilityDensityFunction() { private void initializeProbabilityDensityFunction() {
pdf = new double[U16.Max + 1]; pdf = new double[U16.Max + 1];
// Speedup - for now it is fast enough
Stopwatch s = new Stopwatch();
s.start();
for (final int trainingDatum : trainingData) { for (final int trainingDatum : trainingData) {
pdf[trainingDatum] += 1.0; pdf[trainingDatum] += 1.0;
} }
s.stop();
if (verbose) {
System.out.println("Init_PDF: " + s.getElapsedTimeString());
}
} }
private void recalculateBoundaryPoints() { private void recalculateBoundaryPoints() {
...@@ -142,8 +147,6 @@ public class LloydMaxU16ScalarQuantization { ...@@ -142,8 +147,6 @@ public class LloydMaxU16ScalarQuantization {
double mse = 0.0; double mse = 0.0;
resetFrequencies(); resetFrequencies();
Stopwatch s = new Stopwatch();
s.start();
if (workerCount > 1) { if (workerCount > 1) {
final int workSize = trainingData.length / workerCount; final int workSize = trainingData.length / workerCount;
...@@ -178,11 +181,6 @@ public class LloydMaxU16ScalarQuantization { ...@@ -178,11 +181,6 @@ public class LloydMaxU16ScalarQuantization {
mse += Math.pow((double) trainingDatum - (double) quantizedValue, 2); mse += Math.pow((double) trainingDatum - (double) quantizedValue, 2);
} }
} }
s.stop();
if (verbose) {
System.out.println("\nLloydMax: getCurrentMse time: " + s.getElapsedTimeString());
}
mse /= (double) trainingData.length; mse /= (double) trainingData.length;
return mse; return mse;
...@@ -195,14 +193,14 @@ public class LloydMaxU16ScalarQuantization { ...@@ -195,14 +193,14 @@ public class LloydMaxU16ScalarQuantization {
} }
} }
public QTrainIteration[] train(final boolean shouldBeVerbose) { public QTrainIteration[] train() {
this.verbose = shouldBeVerbose;
final int RECALCULATE_N_TIMES = 10; final int RECALCULATE_N_TIMES = 10;
final int PATIENCE = 1; final int PATIENCE = 1;
int noImprovementCounter = 0; int noImprovementCounter = 0;
if (verbose) {
System.out.println("Training data count: " + trainingData.length); reportStatus("LloydMax::train() - Worker count: %d", workerCount);
} reportStatus("LloydMax::train() - Training data count: %d", trainingData.length);
initialize(); initialize();
initializeProbabilityDensityFunction(); initializeProbabilityDensityFunction();
...@@ -221,15 +219,15 @@ public class LloydMaxU16ScalarQuantization { ...@@ -221,15 +219,15 @@ public class LloydMaxU16ScalarQuantization {
currentMse = getCurrentMse(); currentMse = getCurrentMse();
psnr = Utils.calculatePsnr(currentMse, U16.Max); psnr = Utils.calculatePsnr(currentMse, U16.Max);
if (verbose) { reportStatus("LloydMax::train() - Initial MSE: %f", currentMse);
System.out.println(String.format("Initial MSE: %f", currentMse));
}
solutionHistory.add(new QTrainIteration(0, currentMse, psnr)); solutionHistory.add(new QTrainIteration(0, currentMse, psnr));
double mseImprovement = 1; double mseImprovement = 1;
int iteration = 0; int iteration = 0;
Stopwatch stopwatch = new Stopwatch();
do { do {
stopwatch.restart();
for (int i = 0; i < RECALCULATE_N_TIMES; i++) { for (int i = 0; i < RECALCULATE_N_TIMES; i++) {
recalculateBoundaryPoints(); recalculateBoundaryPoints();
recalculateCentroids(); recalculateCentroids();
...@@ -242,11 +240,10 @@ public class LloydMaxU16ScalarQuantization { ...@@ -242,11 +240,10 @@ public class LloydMaxU16ScalarQuantization {
psnr = Utils.calculatePsnr(currentMse, U16.Max); psnr = Utils.calculatePsnr(currentMse, U16.Max);
solutionHistory.add(new QTrainIteration(++iteration, currentMse, psnr)); solutionHistory.add(new QTrainIteration(++iteration, currentMse, psnr));
if (verbose) { stopwatch.stop();
System.out.println(String.format("Current MSE: %.4f PSNR: %.4f dB", reportStatus("LloydMax::train() - Iteration: %d MSE: %f PSNR: %f Time: %s",
currentMse, iteration, currentMse,
psnr)); psnr, stopwatch.getElapsedTimeString());
}
if (mseImprovement < 1.0) { if (mseImprovement < 1.0) {
if ((++noImprovementCounter) >= PATIENCE) { if ((++noImprovementCounter) >= PATIENCE) {
...@@ -256,9 +253,7 @@ public class LloydMaxU16ScalarQuantization { ...@@ -256,9 +253,7 @@ public class LloydMaxU16ScalarQuantization {
} while (true); } while (true);
if (verbose) { reportStatus("LloydMax::train() - Optimization is finished.");
System.out.println("\nFinished training.");
}
return solutionHistory.toArray(new QTrainIteration[0]); return solutionHistory.toArray(new QTrainIteration[0]);
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment