Skip to content
Snippets Groups Projects
Commit 86e6a7e8 authored by Vojtech Moravec's avatar Vojtech Moravec
Browse files

Implemented train codebook method.

parent 6dee5fbc
Branches
No related tags found
No related merge requests found
Showing
with 230 additions and 27 deletions
...@@ -5,8 +5,6 @@ import azgracompress.cli.CliConstants; ...@@ -5,8 +5,6 @@ import azgracompress.cli.CliConstants;
import azgracompress.cli.ParsedCliOptions; import azgracompress.cli.ParsedCliOptions;
import azgracompress.compression.ImageCompressor; import azgracompress.compression.ImageCompressor;
import azgracompress.compression.ImageDecompressor; import azgracompress.compression.ImageDecompressor;
import azgracompress.quantization.QuantizationValueCache;
import azgracompress.quantization.vector.CodebookEntry;
import org.apache.commons.cli.*; import org.apache.commons.cli.*;
import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.NotNull;
...@@ -38,7 +36,9 @@ public class DataCompressor { ...@@ -38,7 +36,9 @@ public class DataCompressor {
return; return;
} }
// System.out.println(parsedCliOptions.report()); if (parsedCliOptions.isVerbose()) {
System.out.println(parsedCliOptions.report());
}
switch (parsedCliOptions.getMethod()) { switch (parsedCliOptions.getMethod()) {
...@@ -57,10 +57,16 @@ public class DataCompressor { ...@@ -57,10 +57,16 @@ public class DataCompressor {
return; return;
} }
case Benchmark: { case Benchmark: {
System.out.println(parsedCliOptions.report());
CompressionBenchmark.runBenchmark(parsedCliOptions); CompressionBenchmark.runBenchmark(parsedCliOptions);
return; return;
} }
case TrainCodebook: {
ImageCompressor compressor = new ImageCompressor(parsedCliOptions);
if (!compressor.trainAndSaveCodebook()) {
System.err.println("Errors occurred during training/saving of codebook.");
}
return;
}
case PrintHelp: { case PrintHelp: {
formatter.printHelp(CliConstants.MAIN_HELP, options); formatter.printHelp(CliConstants.MAIN_HELP, options);
} }
...@@ -104,6 +110,11 @@ public class DataCompressor { ...@@ -104,6 +110,11 @@ public class DataCompressor {
false, false,
"Benchmark")); "Benchmark"));
methodGroup.addOption(new Option(CliConstants.TRAIN_SHORT,
CliConstants.TRAIN_LONG,
false,
"Train codebook and save learned codebook to cache file."));
methodGroup.addOption(new Option(CliConstants.HELP_SHORT, CliConstants.HELP_LONG, false, "Print help")); methodGroup.addOption(new Option(CliConstants.HELP_SHORT, CliConstants.HELP_LONG, false, "Print help"));
OptionGroup compressionMethodGroup = new OptionGroup(); OptionGroup compressionMethodGroup = new OptionGroup();
...@@ -128,7 +139,12 @@ public class DataCompressor { ...@@ -128,7 +139,12 @@ public class DataCompressor {
CliConstants.VERBOSE_LONG, CliConstants.VERBOSE_LONG,
false, false,
"Make program verbose")); "Make program verbose"));
// options.addRequiredOption(INPUT_SHORT, INPUT_LONG, true, "Input file");
options.addOption(new Option(CliConstants.WORKER_COUNT_SHORT,
CliConstants.WORKER_COUNT_LONG,
true,
"Number of worker threads"));
options.addOption(CliConstants.OUTPUT_SHORT, CliConstants.OUTPUT_LONG, true, "Custom output file"); options.addOption(CliConstants.OUTPUT_SHORT, CliConstants.OUTPUT_LONG, true, "Custom output file");
return options; return options;
} }
......
...@@ -16,6 +16,9 @@ public class CliConstants { ...@@ -16,6 +16,9 @@ public class CliConstants {
public static final String BENCHMARK_SHORT = "bench"; public static final String BENCHMARK_SHORT = "bench";
public static final String BENCHMARK_LONG = "benchmark"; public static final String BENCHMARK_LONG = "benchmark";
public static final String TRAIN_SHORT = "tcb";
public static final String TRAIN_LONG = "train-codebook";
public static final String INSPECT_SHORT = "i"; public static final String INSPECT_SHORT = "i";
public static final String INSPECT_LONG = "inspect"; public static final String INSPECT_LONG = "inspect";
...@@ -28,6 +31,9 @@ public class CliConstants { ...@@ -28,6 +31,9 @@ public class CliConstants {
public static final String VERBOSE_SHORT = "v"; public static final String VERBOSE_SHORT = "v";
public static final String VERBOSE_LONG = "verbose"; public static final String VERBOSE_LONG = "verbose";
public static final String WORKER_COUNT_SHORT = "wc";
public static final String WORKER_COUNT_LONG = "worker-count";
public static final String SCALAR_QUANTIZATION_SHORT = "sq"; public static final String SCALAR_QUANTIZATION_SHORT = "sq";
public static final String SCALAR_QUANTIZATION_LONG = "scalar-quantization"; public static final String SCALAR_QUANTIZATION_LONG = "scalar-quantization";
......
...@@ -38,6 +38,8 @@ public class ParsedCliOptions { ...@@ -38,6 +38,8 @@ public class ParsedCliOptions {
private int fromPlaneIndex; private int fromPlaneIndex;
private int toPlaneIndex; private int toPlaneIndex;
private int workerCount = 1;
public ParsedCliOptions(CommandLine cmdInput) { public ParsedCliOptions(CommandLine cmdInput) {
parseCLI(cmdInput); parseCLI(cmdInput);
} }
...@@ -94,6 +96,17 @@ public class ParsedCliOptions { ...@@ -94,6 +96,17 @@ public class ParsedCliOptions {
verbose = cmd.hasOption(CliConstants.VERBOSE_LONG); verbose = cmd.hasOption(CliConstants.VERBOSE_LONG);
if (cmd.hasOption(CliConstants.WORKER_COUNT_LONG)) {
final String wcString = cmd.getOptionValue(CliConstants.WORKER_COUNT_LONG);
ParseResult<Integer> pr = tryParseInt(wcString);
if (pr.isSuccess()) {
workerCount = pr.getValue();
} else {
errorOccurred = true;
errorBuilder.append("Unable to parse worker count. Expected int got: ").append(wcString).append('\n');
}
}
if (!errorOccurred) { if (!errorOccurred) {
outputFile = cmd.getOptionValue(CliConstants.OUTPUT_LONG, getDefaultOutputFilePath(inputFile)); outputFile = cmd.getOptionValue(CliConstants.OUTPUT_LONG, getDefaultOutputFilePath(inputFile));
} }
...@@ -230,8 +243,15 @@ public class ParsedCliOptions { ...@@ -230,8 +243,15 @@ public class ParsedCliOptions {
} }
} }
private boolean hasQuantizationType(final ProgramMethod method) {
return (method == ProgramMethod.Compress) ||
(method == ProgramMethod.Benchmark) ||
(method == ProgramMethod.TrainCodebook);
}
private void parseCompressionType(CommandLine cmd, StringBuilder errorBuilder) { private void parseCompressionType(CommandLine cmd, StringBuilder errorBuilder) {
if ((method == ProgramMethod.Compress) || (method == ProgramMethod.Benchmark)) { if (hasQuantizationType(method)) {
if (cmd.hasOption(CliConstants.SCALAR_QUANTIZATION_LONG)) { if (cmd.hasOption(CliConstants.SCALAR_QUANTIZATION_LONG)) {
quantizationType = QuantizationType.Scalar; quantizationType = QuantizationType.Scalar;
} else if (cmd.hasOption(CliConstants.VECTOR_QUANTIZATION_LONG)) { } else if (cmd.hasOption(CliConstants.VECTOR_QUANTIZATION_LONG)) {
...@@ -293,6 +313,8 @@ public class ParsedCliOptions { ...@@ -293,6 +313,8 @@ public class ParsedCliOptions {
method = ProgramMethod.Decompress; method = ProgramMethod.Decompress;
} else if (cmd.hasOption(CliConstants.BENCHMARK_LONG)) { } else if (cmd.hasOption(CliConstants.BENCHMARK_LONG)) {
method = ProgramMethod.Benchmark; method = ProgramMethod.Benchmark;
} else if (cmd.hasOption(CliConstants.TRAIN_LONG)) {
method = ProgramMethod.TrainCodebook;
} else if (cmd.hasOption(CliConstants.INSPECT_LONG)) { } else if (cmd.hasOption(CliConstants.INSPECT_LONG)) {
method = ProgramMethod.InspectFile; method = ProgramMethod.InspectFile;
} else { } else {
...@@ -378,6 +400,10 @@ public class ParsedCliOptions { ...@@ -378,6 +400,10 @@ public class ParsedCliOptions {
return toPlaneIndex; return toPlaneIndex;
} }
public int getWorkerCount() {
return workerCount;
}
public String report() { public String report() {
StringBuilder sb = new StringBuilder(); StringBuilder sb = new StringBuilder();
...@@ -397,7 +423,8 @@ public class ParsedCliOptions { ...@@ -397,7 +423,8 @@ public class ParsedCliOptions {
break; break;
} }
if (method == ProgramMethod.Compress) {
if (hasQuantizationType(method)) {
sb.append("Quantization type: "); sb.append("Quantization type: ");
switch (quantizationType) { switch (quantizationType) {
case Scalar: case Scalar:
...@@ -417,7 +444,7 @@ public class ParsedCliOptions { ...@@ -417,7 +444,7 @@ public class ParsedCliOptions {
sb.append("Output: ").append(outputFile).append('\n'); sb.append("Output: ").append(outputFile).append('\n');
sb.append("InputFile: ").append(inputFile).append('\n'); sb.append("InputFile: ").append(inputFile).append('\n');
if (method == ProgramMethod.Compress) { if (hasQuantizationType(method)) {
sb.append("Input image dims: ").append(imageDimension.toString()).append('\n'); sb.append("Input image dims: ").append(imageDimension.toString()).append('\n');
} }
...@@ -432,6 +459,9 @@ public class ParsedCliOptions { ...@@ -432,6 +459,9 @@ public class ParsedCliOptions {
sb.append("ToPlaneIndex: ").append(toPlaneIndex).append('\n'); sb.append("ToPlaneIndex: ").append(toPlaneIndex).append('\n');
} }
sb.append("Verbose: ").append(verbose).append('\n');
sb.append("ThreadWorkerCount: ").append(workerCount).append('\n');
return sb.toString(); return sb.toString();
} }
......
...@@ -4,6 +4,7 @@ public enum ProgramMethod { ...@@ -4,6 +4,7 @@ public enum ProgramMethod {
Compress, Compress,
Decompress, Decompress,
Benchmark, Benchmark,
TrainCodebook,
PrintHelp, PrintHelp,
InspectFile InspectFile
} }
package azgracompress.compression; package azgracompress.compression;
import azgracompress.cli.ParsedCliOptions; import azgracompress.cli.ParsedCliOptions;
import azgracompress.compression.exception.ImageCompressionException;
import azgracompress.io.RawDataIO;
import java.io.IOException;
public abstract class CompressorDecompressorBase { public abstract class CompressorDecompressorBase {
public static final String EXTENSTION = ".QCMP"; public static final String EXTENSTION = ".QCMP";
...@@ -39,6 +43,38 @@ public abstract class CompressorDecompressorBase { ...@@ -39,6 +43,38 @@ public abstract class CompressorDecompressorBase {
return planeIndices; return planeIndices;
} }
protected int[] loadConfiguredPlanesData() throws ImageCompressionException {
int[] trainData = null;
if (options.hasPlaneIndexSet()) {
try {
Log("Loading single plane data.");
trainData = RawDataIO.loadImageU16(options.getInputFile(),
options.getImageDimension(),
options.getPlaneIndex()).getData();
} catch (IOException e) {
throw new ImageCompressionException("Failed to load reference image data.", e);
}
} else if (options.hasPlaneRangeSet()) {
Log("Loading plane range data.");
final int[] planes = getPlaneIndicesForCompression();
try {
trainData = RawDataIO.loadPlanesData(options.getInputFile(), options.getImageDimension(), planes);
} catch (IOException e) {
e.printStackTrace();
throw new ImageCompressionException("Failed to load plane range data.", e);
}
} else {
Log("Loading all planes data.");
try {
trainData = RawDataIO.loadAllPlanesData(options.getInputFile(), options.getImageDimension());
} catch (IOException e) {
throw new ImageCompressionException("Failed to load all planes data.", e);
}
}
return trainData;
}
protected void Log(final String message) { protected void Log(final String message) {
if (options.isVerbose()) { if (options.isVerbose()) {
System.out.println(message); System.out.println(message);
......
...@@ -8,8 +8,16 @@ public interface IImageCompressor { ...@@ -8,8 +8,16 @@ public interface IImageCompressor {
/** /**
* Compress the image planes. * Compress the image planes.
*
* @param compressStream Compressed data stream. * @param compressStream Compressed data stream.
* @throws ImageCompressionException when compression fails. * @throws ImageCompressionException when compression fails.
*/ */
void compress(DataOutputStream compressStream) throws ImageCompressionException; void compress(DataOutputStream compressStream) throws ImageCompressionException;
/**
* Train codebook from selected frames and save the learned codebook to cache file.
*
* @throws ImageCompressionException when training or saving fails.
*/
void trainAndSaveCodebook() throws ImageCompressionException;
} }
...@@ -44,6 +44,22 @@ public class ImageCompressor extends CompressorDecompressorBase { ...@@ -44,6 +44,22 @@ public class ImageCompressor extends CompressorDecompressorBase {
System.out.println(String.format("Compression ratio: %.5f", compressionRatio)); System.out.println(String.format("Compression ratio: %.5f", compressionRatio));
} }
public boolean trainAndSaveCodebook() {
Log("=== Training codebook ===");
IImageCompressor imageCompressor = getImageCompressor();
if (imageCompressor == null) {
return false;
}
try {
imageCompressor.trainAndSaveCodebook();
} catch (ImageCompressionException e) {
System.err.println(e.getMessage());
e.printStackTrace();
return false;
}
return true;
}
public boolean compress() { public boolean compress() {
IImageCompressor imageCompressor = getImageCompressor(); IImageCompressor imageCompressor = getImageCompressor();
......
...@@ -6,6 +6,7 @@ import azgracompress.compression.exception.ImageCompressionException; ...@@ -6,6 +6,7 @@ import azgracompress.compression.exception.ImageCompressionException;
import azgracompress.data.ImageU16; import azgracompress.data.ImageU16;
import azgracompress.io.OutBitStream; import azgracompress.io.OutBitStream;
import azgracompress.io.RawDataIO; import azgracompress.io.RawDataIO;
import azgracompress.quantization.QuantizationValueCache;
import azgracompress.quantization.scalar.LloydMaxU16ScalarQuantization; import azgracompress.quantization.scalar.LloydMaxU16ScalarQuantization;
import azgracompress.quantization.scalar.ScalarQuantizer; import azgracompress.quantization.scalar.ScalarQuantizer;
import azgracompress.utilities.Stopwatch; import azgracompress.utilities.Stopwatch;
...@@ -119,4 +120,29 @@ public class SQImageCompressor extends CompressorDecompressorBase implements IIm ...@@ -119,4 +120,29 @@ public class SQImageCompressor extends CompressorDecompressorBase implements IIm
Log(String.format("Finished processing of plane %d", planeIndex)); Log(String.format("Finished processing of plane %d", planeIndex));
} }
} }
@Override
public void trainAndSaveCodebook() throws ImageCompressionException {
int[] trainData = loadConfiguredPlanesData();
LloydMaxU16ScalarQuantization lloydMax = new LloydMaxU16ScalarQuantization(trainData,
codebookSize,
options.getWorkerCount());
Log("Starting LloydMax training.");
lloydMax.train(options.isVerbose());
final int[] qValues = lloydMax.getCentroids();
Log("Finished LloydMax training.");
Log(String.format("Saving cache file to %s", options.getOutputFile()));
QuantizationValueCache cache = new QuantizationValueCache(options.getOutputFile());
try {
cache.saveQuantizationValues(options.getInputFile(), qValues);
} catch (IOException e) {
throw new ImageCompressionException("Unable to write cache.", e);
}
Log("Operation completed.");
}
} }
...@@ -147,5 +147,10 @@ public class VQImageCompressor extends CompressorDecompressorBase implements IIm ...@@ -147,5 +147,10 @@ public class VQImageCompressor extends CompressorDecompressorBase implements IIm
} }
} }
@Override
public void trainAndSaveCodebook() throws ImageCompressionException {
throw new ImageCompressionException("Not implemented yet");
}
} }
...@@ -5,6 +5,7 @@ import azgracompress.data.V3i; ...@@ -5,6 +5,7 @@ import azgracompress.data.V3i;
import azgracompress.utilities.TypeConverter; import azgracompress.utilities.TypeConverter;
import java.io.*; import java.io.*;
import java.util.Arrays;
public class RawDataIO { public class RawDataIO {
/** /**
...@@ -48,6 +49,52 @@ public class RawDataIO { ...@@ -48,6 +49,52 @@ public class RawDataIO {
TypeConverter.unsignedShortBytesToIntArray(buffer)); TypeConverter.unsignedShortBytesToIntArray(buffer));
} }
public static int[] loadPlanesData(final String rawFile,
final V3i rawDataDims,
int[] planes) throws IOException {
if (planes.length < 1)
return new int[0];
final int planeValueCount = rawDataDims.getX() * rawDataDims.getY();
final long planeDataSize = 2 * (long) planeValueCount;
final long totalValueCount = (long) planeValueCount * planes.length;
int[] values = new int[(int) totalValueCount];
if (totalValueCount > (long) Integer.MAX_VALUE) {
throw new IOException("Integer count is too big.");
}
Arrays.sort(planes);
try (FileInputStream fileStream = new FileInputStream(rawFile);
DataInputStream dis = new DataInputStream(new BufferedInputStream(fileStream, 8192))) {
int lastIndex = 0;
int valIndex = 0;
for (final int planeIndex : planes) {
// Skip specific number of bytes to get to the next plane.
final int requestedSkip = (planeIndex == 0) ? 0 : ((planeIndex - lastIndex) - 1) * (int) planeDataSize;
lastIndex = planeIndex;
final int actualSkip = dis.skipBytes(requestedSkip);
if (requestedSkip != actualSkip) {
throw new IOException("Skip operation failed.");
}
for (int i = 0; i < planeValueCount; i++) {
values[valIndex++] = dis.readUnsignedShort();
}
}
}
return values;
}
public static int[] loadAllPlanesData(final String rawFile, final V3i imageDims) throws IOException { public static int[] loadAllPlanesData(final String rawFile, final V3i imageDims) throws IOException {
final long dataSize = (long) imageDims.getX() * (long) imageDims.getY() * (long) imageDims.getZ(); final long dataSize = (long) imageDims.getX() * (long) imageDims.getY() * (long) imageDims.getZ();
...@@ -93,4 +140,6 @@ public class RawDataIO { ...@@ -93,4 +140,6 @@ public class RawDataIO {
fileStream.flush(); fileStream.flush();
fileStream.close(); fileStream.close();
} }
} }
...@@ -14,8 +14,9 @@ public class QuantizationValueCache { ...@@ -14,8 +14,9 @@ public class QuantizationValueCache {
} }
private File getCacheFileForScalarValues(final String trainFile, final int quantizationValueCount) { private File getCacheFileForScalarValues(final String trainFile, final int quantizationValueCount) {
final File inputFile = new File(trainFile);
final File cacheFile = new File(cacheFolder, String.format("%s_%d_bits.qvc", final File cacheFile = new File(cacheFolder, String.format("%s_%d_bits.qvc",
trainFile, quantizationValueCount)); inputFile.getName(), quantizationValueCount));
return cacheFile; return cacheFile;
} }
...@@ -23,12 +24,13 @@ public class QuantizationValueCache { ...@@ -23,12 +24,13 @@ public class QuantizationValueCache {
final int codebookSize, final int codebookSize,
final int entryWidth, final int entryWidth,
final int entryHeight) { final int entryHeight) {
final File inputFile = new File(trainFile);
final File cacheFile = new File(cacheFolder, String.format("%s_%d_%dx%d.qvc", final File cacheFile = new File(cacheFolder, String.format("%s_%d_%dx%d.qvc",
trainFile, codebookSize, entryWidth, entryHeight)); inputFile.getName(), codebookSize, entryWidth, entryHeight));
return cacheFile; return cacheFile;
} }
public void saveQuantizationValues(final String trainFile, final int[] quantizationValues) { public void saveQuantizationValues(final String trainFile, final int[] quantizationValues) throws IOException {
final int quantizationValueCount = quantizationValues.length; final int quantizationValueCount = quantizationValues.length;
final String cacheFile = getCacheFileForScalarValues(trainFile, quantizationValueCount).getAbsolutePath(); final String cacheFile = getCacheFileForScalarValues(trainFile, quantizationValueCount).getAbsolutePath();
...@@ -38,13 +40,14 @@ public class QuantizationValueCache { ...@@ -38,13 +40,14 @@ public class QuantizationValueCache {
for (final int qv : quantizationValues) { for (final int qv : quantizationValues) {
dos.writeInt(qv); dos.writeInt(qv);
} }
} catch (IOException ioEx) { } catch (IOException ex) {
System.err.println("Failed to save scalar quantization values to cache."); throw new IOException(String.format("Failed to write cache to file: %s.\nInner Ex:\n%s",
ioEx.printStackTrace(); cacheFile,
ex.getMessage()));
} }
} }
public void saveQuantizationValues(final String trainFile, final CodebookEntry[] entries) { public void saveQuantizationValues(final String trainFile, final CodebookEntry[] entries) throws IOException {
final int codebookSize = entries.length; final int codebookSize = entries.length;
final int entryWidth = entries[0].getWidth(); final int entryWidth = entries[0].getWidth();
final int entryHeight = entries[0].getHeight(); final int entryHeight = entries[0].getHeight();
...@@ -66,9 +69,6 @@ public class QuantizationValueCache { ...@@ -66,9 +69,6 @@ public class QuantizationValueCache {
dos.writeInt(vectorValue); dos.writeInt(vectorValue);
} }
} }
} catch (IOException ioEx) {
System.err.println("Failed to save quantization vectors to cache.");
ioEx.printStackTrace();
} }
} }
......
...@@ -18,6 +18,8 @@ public class LloydMaxU16ScalarQuantization { ...@@ -18,6 +18,8 @@ public class LloydMaxU16ScalarQuantization {
private final int workerCount; private final int workerCount;
private boolean verbose = false;
public LloydMaxU16ScalarQuantization(final int[] trainData, final int codebookSize, final int workerCount) { public LloydMaxU16ScalarQuantization(final int[] trainData, final int codebookSize, final int workerCount) {
trainingData = trainData; trainingData = trainData;
this.codebookSize = codebookSize; this.codebookSize = codebookSize;
...@@ -51,8 +53,10 @@ public class LloydMaxU16ScalarQuantization { ...@@ -51,8 +53,10 @@ public class LloydMaxU16ScalarQuantization {
pdf[trainingData[i]] += 1; pdf[trainingData[i]] += 1;
} }
s.stop(); s.stop();
if (verbose) {
System.out.println("Init_PDF: " + s.getElapsedTimeString()); System.out.println("Init_PDF: " + s.getElapsedTimeString());
} }
}
private void recalculateBoundaryPoints() { private void recalculateBoundaryPoints() {
for (int j = 1; j < codebookSize; j++) { for (int j = 1; j < codebookSize; j++) {
...@@ -97,9 +101,9 @@ public class LloydMaxU16ScalarQuantization { ...@@ -97,9 +101,9 @@ public class LloydMaxU16ScalarQuantization {
private double getCurrentMse() { private double getCurrentMse() {
double mse = 0.0; double mse = 0.0;
if (workerCount > 1) {
Stopwatch s = new Stopwatch(); Stopwatch s = new Stopwatch();
s.start(); s.start();
if (workerCount > 1) {
// Speedup // Speedup
final int workSize = trainingData.length / workerCount; final int workSize = trainingData.length / workerCount;
...@@ -133,22 +137,28 @@ public class LloydMaxU16ScalarQuantization { ...@@ -133,22 +137,28 @@ public class LloydMaxU16ScalarQuantization {
} catch (InterruptedException e) { } catch (InterruptedException e) {
e.printStackTrace(); e.printStackTrace();
} }
s.stop();
System.out.println("\ngetCurrentMse time: " + s.getElapsedTimeString());
} else { } else {
for (final int trainingDatum : trainingData) { for (final int trainingDatum : trainingData) {
int quantizedValue = quantize(trainingDatum); int quantizedValue = quantize(trainingDatum);
mse += Math.pow((double) trainingDatum - (double) quantizedValue, 2); mse += Math.pow((double) trainingDatum - (double) quantizedValue, 2);
} }
} }
s.stop();
if (verbose) {
System.out.println("\ngetCurrentMse time: " + s.getElapsedTimeString());
}
mse /= (double) trainingData.length; mse /= (double) trainingData.length;
return mse; return mse;
} }
public QTrainIteration[] train(final boolean verbose) { public QTrainIteration[] train(final boolean shouldBeVerbose) {
System.out.println("Data len: " + trainingData.length); this.verbose = shouldBeVerbose;
if (verbose) {
System.out.println("Training data count: " + trainingData.length);
}
initialize(); initialize();
initializeProbabilityDensityFunction(); initializeProbabilityDensityFunction();
...@@ -184,7 +194,7 @@ public class LloydMaxU16ScalarQuantization { ...@@ -184,7 +194,7 @@ public class LloydMaxU16ScalarQuantization {
dist = (prevMse - currentMse) / currentMse; dist = (prevMse - currentMse) / currentMse;
if (verbose) { if (verbose) {
System.out.print(String.format("\rCurrent MSE: %.4f PSNR: %.4f dB", currentMse, psnr)); System.out.println(String.format("Current MSE: %.4f PSNR: %.4f dB", currentMse, psnr));
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment