diff --git a/src/main/java/azgracompress/DataCompressor.java b/src/main/java/azgracompress/DataCompressor.java index c9717146585c95b74042b1ec046d7fd931a01012..ec68519a3d3a4525fda93cbc14f0be79ffc9fe3d 100644 --- a/src/main/java/azgracompress/DataCompressor.java +++ b/src/main/java/azgracompress/DataCompressor.java @@ -5,8 +5,6 @@ import azgracompress.cli.CliConstants; import azgracompress.cli.ParsedCliOptions; import azgracompress.compression.ImageCompressor; import azgracompress.compression.ImageDecompressor; -import azgracompress.quantization.QuantizationValueCache; -import azgracompress.quantization.vector.CodebookEntry; import org.apache.commons.cli.*; import org.jetbrains.annotations.NotNull; @@ -38,7 +36,9 @@ public class DataCompressor { return; } - // System.out.println(parsedCliOptions.report()); + if (parsedCliOptions.isVerbose()) { + System.out.println(parsedCliOptions.report()); + } switch (parsedCliOptions.getMethod()) { @@ -57,10 +57,16 @@ public class DataCompressor { return; } case Benchmark: { - System.out.println(parsedCliOptions.report()); CompressionBenchmark.runBenchmark(parsedCliOptions); return; } + case TrainCodebook: { + ImageCompressor compressor = new ImageCompressor(parsedCliOptions); + if (!compressor.trainAndSaveCodebook()) { + System.err.println("Errors occurred during training/saving of codebook."); + } + return; + } case PrintHelp: { formatter.printHelp(CliConstants.MAIN_HELP, options); } @@ -104,6 +110,11 @@ public class DataCompressor { false, "Benchmark")); + methodGroup.addOption(new Option(CliConstants.TRAIN_SHORT, + CliConstants.TRAIN_LONG, + false, + "Train codebook and save learned codebook to cache file.")); + methodGroup.addOption(new Option(CliConstants.HELP_SHORT, CliConstants.HELP_LONG, false, "Print help")); OptionGroup compressionMethodGroup = new OptionGroup(); @@ -128,7 +139,12 @@ public class DataCompressor { CliConstants.VERBOSE_LONG, false, "Make program verbose")); - // options.addRequiredOption(INPUT_SHORT, INPUT_LONG, true, "Input file"); + + options.addOption(new Option(CliConstants.WORKER_COUNT_SHORT, + CliConstants.WORKER_COUNT_LONG, + true, + "Number of worker threads")); + options.addOption(CliConstants.OUTPUT_SHORT, CliConstants.OUTPUT_LONG, true, "Custom output file"); return options; } diff --git a/src/main/java/azgracompress/cli/CliConstants.java b/src/main/java/azgracompress/cli/CliConstants.java index 1930897d1ba68ed21dca6af06f584952dab9b7b4..644e3003ecfb5a1a306b480c1903a93881a3aa4e 100644 --- a/src/main/java/azgracompress/cli/CliConstants.java +++ b/src/main/java/azgracompress/cli/CliConstants.java @@ -16,6 +16,9 @@ public class CliConstants { public static final String BENCHMARK_SHORT = "bench"; public static final String BENCHMARK_LONG = "benchmark"; + public static final String TRAIN_SHORT = "tcb"; + public static final String TRAIN_LONG = "train-codebook"; + public static final String INSPECT_SHORT = "i"; public static final String INSPECT_LONG = "inspect"; @@ -28,6 +31,9 @@ public class CliConstants { public static final String VERBOSE_SHORT = "v"; public static final String VERBOSE_LONG = "verbose"; + public static final String WORKER_COUNT_SHORT = "wc"; + public static final String WORKER_COUNT_LONG = "worker-count"; + public static final String SCALAR_QUANTIZATION_SHORT = "sq"; public static final String SCALAR_QUANTIZATION_LONG = "scalar-quantization"; diff --git a/src/main/java/azgracompress/cli/ParsedCliOptions.java b/src/main/java/azgracompress/cli/ParsedCliOptions.java index 15d25da640338b79f8fa3b5b32c7d913f2e1588e..d60f48881d2f70e2b3fd23d41e8c641c857622b7 100644 --- a/src/main/java/azgracompress/cli/ParsedCliOptions.java +++ b/src/main/java/azgracompress/cli/ParsedCliOptions.java @@ -38,6 +38,8 @@ public class ParsedCliOptions { private int fromPlaneIndex; private int toPlaneIndex; + private int workerCount = 1; + public ParsedCliOptions(CommandLine cmdInput) { parseCLI(cmdInput); } @@ -94,6 +96,17 @@ public class ParsedCliOptions { verbose = cmd.hasOption(CliConstants.VERBOSE_LONG); + if (cmd.hasOption(CliConstants.WORKER_COUNT_LONG)) { + final String wcString = cmd.getOptionValue(CliConstants.WORKER_COUNT_LONG); + ParseResult<Integer> pr = tryParseInt(wcString); + if (pr.isSuccess()) { + workerCount = pr.getValue(); + } else { + errorOccurred = true; + errorBuilder.append("Unable to parse worker count. Expected int got: ").append(wcString).append('\n'); + } + } + if (!errorOccurred) { outputFile = cmd.getOptionValue(CliConstants.OUTPUT_LONG, getDefaultOutputFilePath(inputFile)); } @@ -230,8 +243,15 @@ public class ParsedCliOptions { } } + private boolean hasQuantizationType(final ProgramMethod method) { + return (method == ProgramMethod.Compress) || + (method == ProgramMethod.Benchmark) || + (method == ProgramMethod.TrainCodebook); + } + private void parseCompressionType(CommandLine cmd, StringBuilder errorBuilder) { - if ((method == ProgramMethod.Compress) || (method == ProgramMethod.Benchmark)) { + if (hasQuantizationType(method)) { + if (cmd.hasOption(CliConstants.SCALAR_QUANTIZATION_LONG)) { quantizationType = QuantizationType.Scalar; } else if (cmd.hasOption(CliConstants.VECTOR_QUANTIZATION_LONG)) { @@ -293,6 +313,8 @@ public class ParsedCliOptions { method = ProgramMethod.Decompress; } else if (cmd.hasOption(CliConstants.BENCHMARK_LONG)) { method = ProgramMethod.Benchmark; + } else if (cmd.hasOption(CliConstants.TRAIN_LONG)) { + method = ProgramMethod.TrainCodebook; } else if (cmd.hasOption(CliConstants.INSPECT_LONG)) { method = ProgramMethod.InspectFile; } else { @@ -378,6 +400,10 @@ public class ParsedCliOptions { return toPlaneIndex; } + public int getWorkerCount() { + return workerCount; + } + public String report() { StringBuilder sb = new StringBuilder(); @@ -397,7 +423,8 @@ public class ParsedCliOptions { break; } - if (method == ProgramMethod.Compress) { + + if (hasQuantizationType(method)) { sb.append("Quantization type: "); switch (quantizationType) { case Scalar: @@ -417,7 +444,7 @@ public class ParsedCliOptions { sb.append("Output: ").append(outputFile).append('\n'); sb.append("InputFile: ").append(inputFile).append('\n'); - if (method == ProgramMethod.Compress) { + if (hasQuantizationType(method)) { sb.append("Input image dims: ").append(imageDimension.toString()).append('\n'); } @@ -432,6 +459,9 @@ public class ParsedCliOptions { sb.append("ToPlaneIndex: ").append(toPlaneIndex).append('\n'); } + sb.append("Verbose: ").append(verbose).append('\n'); + sb.append("ThreadWorkerCount: ").append(workerCount).append('\n'); + return sb.toString(); } diff --git a/src/main/java/azgracompress/cli/ProgramMethod.java b/src/main/java/azgracompress/cli/ProgramMethod.java index a7f689fde254c66b927df65d90d4ccbc034d6ee6..cf1e71ca08bf3fea3e91db49cefc1927fc78ae61 100644 --- a/src/main/java/azgracompress/cli/ProgramMethod.java +++ b/src/main/java/azgracompress/cli/ProgramMethod.java @@ -4,6 +4,7 @@ public enum ProgramMethod { Compress, Decompress, Benchmark, + TrainCodebook, PrintHelp, InspectFile } diff --git a/src/main/java/azgracompress/compression/CompressorDecompressorBase.java b/src/main/java/azgracompress/compression/CompressorDecompressorBase.java index fa6c3b78a6f6c2b1cabe4c8e3db77b86ab50fc05..43d8e191df9153565816da250d42757386f73bd2 100644 --- a/src/main/java/azgracompress/compression/CompressorDecompressorBase.java +++ b/src/main/java/azgracompress/compression/CompressorDecompressorBase.java @@ -1,6 +1,10 @@ package azgracompress.compression; import azgracompress.cli.ParsedCliOptions; +import azgracompress.compression.exception.ImageCompressionException; +import azgracompress.io.RawDataIO; + +import java.io.IOException; public abstract class CompressorDecompressorBase { public static final String EXTENSTION = ".QCMP"; @@ -39,6 +43,38 @@ public abstract class CompressorDecompressorBase { return planeIndices; } + protected int[] loadConfiguredPlanesData() throws ImageCompressionException { + int[] trainData = null; + if (options.hasPlaneIndexSet()) { + try { + Log("Loading single plane data."); + trainData = RawDataIO.loadImageU16(options.getInputFile(), + options.getImageDimension(), + options.getPlaneIndex()).getData(); + } catch (IOException e) { + throw new ImageCompressionException("Failed to load reference image data.", e); + } + } else if (options.hasPlaneRangeSet()) { + Log("Loading plane range data."); + final int[] planes = getPlaneIndicesForCompression(); + try { + trainData = RawDataIO.loadPlanesData(options.getInputFile(), options.getImageDimension(), planes); + } catch (IOException e) { + e.printStackTrace(); + throw new ImageCompressionException("Failed to load plane range data.", e); + } + } else { + Log("Loading all planes data."); + try { + trainData = RawDataIO.loadAllPlanesData(options.getInputFile(), options.getImageDimension()); + } catch (IOException e) { + throw new ImageCompressionException("Failed to load all planes data.", e); + } + } + return trainData; + } + + protected void Log(final String message) { if (options.isVerbose()) { System.out.println(message); diff --git a/src/main/java/azgracompress/compression/IImageCompressor.java b/src/main/java/azgracompress/compression/IImageCompressor.java index 3dece34df62877e16d498ac283ce8f1c79929d3b..d5bb08f1aa054f26df2accb3dca80415bdfc1e91 100644 --- a/src/main/java/azgracompress/compression/IImageCompressor.java +++ b/src/main/java/azgracompress/compression/IImageCompressor.java @@ -8,8 +8,16 @@ public interface IImageCompressor { /** * Compress the image planes. + * * @param compressStream Compressed data stream. * @throws ImageCompressionException when compression fails. */ void compress(DataOutputStream compressStream) throws ImageCompressionException; + + /** + * Train codebook from selected frames and save the learned codebook to cache file. + * + * @throws ImageCompressionException when training or saving fails. + */ + void trainAndSaveCodebook() throws ImageCompressionException; } diff --git a/src/main/java/azgracompress/compression/ImageCompressor.java b/src/main/java/azgracompress/compression/ImageCompressor.java index 06ee3689bd5bc8a18a08661ee84802a44c1be8fd..20cb982594b32b98606cd17a66e439d28809225e 100644 --- a/src/main/java/azgracompress/compression/ImageCompressor.java +++ b/src/main/java/azgracompress/compression/ImageCompressor.java @@ -44,6 +44,22 @@ public class ImageCompressor extends CompressorDecompressorBase { System.out.println(String.format("Compression ratio: %.5f", compressionRatio)); } + public boolean trainAndSaveCodebook() { + Log("=== Training codebook ==="); + IImageCompressor imageCompressor = getImageCompressor(); + if (imageCompressor == null) { + return false; + } + try { + imageCompressor.trainAndSaveCodebook(); + } catch (ImageCompressionException e) { + System.err.println(e.getMessage()); + e.printStackTrace(); + return false; + } + return true; + } + public boolean compress() { IImageCompressor imageCompressor = getImageCompressor(); diff --git a/src/main/java/azgracompress/compression/SQImageCompressor.java b/src/main/java/azgracompress/compression/SQImageCompressor.java index 6484150235fe96c466e14a353c9c2ff0a4a02cbb..765ab5eadf913801264d8c11708b9bdd364667ef 100644 --- a/src/main/java/azgracompress/compression/SQImageCompressor.java +++ b/src/main/java/azgracompress/compression/SQImageCompressor.java @@ -6,6 +6,7 @@ import azgracompress.compression.exception.ImageCompressionException; import azgracompress.data.ImageU16; import azgracompress.io.OutBitStream; import azgracompress.io.RawDataIO; +import azgracompress.quantization.QuantizationValueCache; import azgracompress.quantization.scalar.LloydMaxU16ScalarQuantization; import azgracompress.quantization.scalar.ScalarQuantizer; import azgracompress.utilities.Stopwatch; @@ -119,4 +120,29 @@ public class SQImageCompressor extends CompressorDecompressorBase implements IIm Log(String.format("Finished processing of plane %d", planeIndex)); } } + + @Override + public void trainAndSaveCodebook() throws ImageCompressionException { + + + int[] trainData = loadConfiguredPlanesData(); + + LloydMaxU16ScalarQuantization lloydMax = new LloydMaxU16ScalarQuantization(trainData, + codebookSize, + options.getWorkerCount()); + + Log("Starting LloydMax training."); + lloydMax.train(options.isVerbose()); + final int[] qValues = lloydMax.getCentroids(); + Log("Finished LloydMax training."); + + Log(String.format("Saving cache file to %s", options.getOutputFile())); + QuantizationValueCache cache = new QuantizationValueCache(options.getOutputFile()); + try { + cache.saveQuantizationValues(options.getInputFile(), qValues); + } catch (IOException e) { + throw new ImageCompressionException("Unable to write cache.", e); + } + Log("Operation completed."); + } } diff --git a/src/main/java/azgracompress/compression/VQImageCompressor.java b/src/main/java/azgracompress/compression/VQImageCompressor.java index 07f03c4d811eecf768b55e4622dc243e90f67677..3bc8ee5bcfb39aaae90bb8cc17aeb1e7706b8a5e 100644 --- a/src/main/java/azgracompress/compression/VQImageCompressor.java +++ b/src/main/java/azgracompress/compression/VQImageCompressor.java @@ -147,5 +147,10 @@ public class VQImageCompressor extends CompressorDecompressorBase implements IIm } } + @Override + public void trainAndSaveCodebook() throws ImageCompressionException { + throw new ImageCompressionException("Not implemented yet"); + } + } diff --git a/src/main/java/azgracompress/io/RawDataIO.java b/src/main/java/azgracompress/io/RawDataIO.java index a0984118bb6efb404def09899e6f9ad3013e1a77..5cf711f0271b8642282e2343d0eef24533218d9f 100644 --- a/src/main/java/azgracompress/io/RawDataIO.java +++ b/src/main/java/azgracompress/io/RawDataIO.java @@ -5,6 +5,7 @@ import azgracompress.data.V3i; import azgracompress.utilities.TypeConverter; import java.io.*; +import java.util.Arrays; public class RawDataIO { /** @@ -48,6 +49,52 @@ public class RawDataIO { TypeConverter.unsignedShortBytesToIntArray(buffer)); } + public static int[] loadPlanesData(final String rawFile, + final V3i rawDataDims, + int[] planes) throws IOException { + + if (planes.length < 1) + return new int[0]; + + final int planeValueCount = rawDataDims.getX() * rawDataDims.getY(); + final long planeDataSize = 2 * (long) planeValueCount; + + final long totalValueCount = (long) planeValueCount * planes.length; + int[] values = new int[(int) totalValueCount]; + + + if (totalValueCount > (long) Integer.MAX_VALUE) { + throw new IOException("Integer count is too big."); + } + + Arrays.sort(planes); + + try (FileInputStream fileStream = new FileInputStream(rawFile); + DataInputStream dis = new DataInputStream(new BufferedInputStream(fileStream, 8192))) { + + int lastIndex = 0; + int valIndex = 0; + + for (final int planeIndex : planes) { + // Skip specific number of bytes to get to the next plane. + final int requestedSkip = (planeIndex == 0) ? 0 : ((planeIndex - lastIndex) - 1) * (int) planeDataSize; + lastIndex = planeIndex; + + final int actualSkip = dis.skipBytes(requestedSkip); + if (requestedSkip != actualSkip) { + throw new IOException("Skip operation failed."); + } + + for (int i = 0; i < planeValueCount; i++) { + values[valIndex++] = dis.readUnsignedShort(); + } + + } + } + + return values; + } + public static int[] loadAllPlanesData(final String rawFile, final V3i imageDims) throws IOException { final long dataSize = (long) imageDims.getX() * (long) imageDims.getY() * (long) imageDims.getZ(); @@ -93,4 +140,6 @@ public class RawDataIO { fileStream.flush(); fileStream.close(); } + + } diff --git a/src/main/java/azgracompress/quantization/QuantizationValueCache.java b/src/main/java/azgracompress/quantization/QuantizationValueCache.java index 05de1415f202268c6ef9c600a622a0fb5bfcbc9e..87a497245841fb49c55a891a06c2b0c528a2685a 100644 --- a/src/main/java/azgracompress/quantization/QuantizationValueCache.java +++ b/src/main/java/azgracompress/quantization/QuantizationValueCache.java @@ -14,8 +14,9 @@ public class QuantizationValueCache { } private File getCacheFileForScalarValues(final String trainFile, final int quantizationValueCount) { + final File inputFile = new File(trainFile); final File cacheFile = new File(cacheFolder, String.format("%s_%d_bits.qvc", - trainFile, quantizationValueCount)); + inputFile.getName(), quantizationValueCount)); return cacheFile; } @@ -23,12 +24,13 @@ public class QuantizationValueCache { final int codebookSize, final int entryWidth, final int entryHeight) { + final File inputFile = new File(trainFile); final File cacheFile = new File(cacheFolder, String.format("%s_%d_%dx%d.qvc", - trainFile, codebookSize, entryWidth, entryHeight)); + inputFile.getName(), codebookSize, entryWidth, entryHeight)); return cacheFile; } - public void saveQuantizationValues(final String trainFile, final int[] quantizationValues) { + public void saveQuantizationValues(final String trainFile, final int[] quantizationValues) throws IOException { final int quantizationValueCount = quantizationValues.length; final String cacheFile = getCacheFileForScalarValues(trainFile, quantizationValueCount).getAbsolutePath(); @@ -38,13 +40,14 @@ public class QuantizationValueCache { for (final int qv : quantizationValues) { dos.writeInt(qv); } - } catch (IOException ioEx) { - System.err.println("Failed to save scalar quantization values to cache."); - ioEx.printStackTrace(); + } catch (IOException ex) { + throw new IOException(String.format("Failed to write cache to file: %s.\nInner Ex:\n%s", + cacheFile, + ex.getMessage())); } } - public void saveQuantizationValues(final String trainFile, final CodebookEntry[] entries) { + public void saveQuantizationValues(final String trainFile, final CodebookEntry[] entries) throws IOException { final int codebookSize = entries.length; final int entryWidth = entries[0].getWidth(); final int entryHeight = entries[0].getHeight(); @@ -66,9 +69,6 @@ public class QuantizationValueCache { dos.writeInt(vectorValue); } } - } catch (IOException ioEx) { - System.err.println("Failed to save quantization vectors to cache."); - ioEx.printStackTrace(); } } diff --git a/src/main/java/azgracompress/quantization/scalar/LloydMaxU16ScalarQuantization.java b/src/main/java/azgracompress/quantization/scalar/LloydMaxU16ScalarQuantization.java index f027dd4c981dc3e86197fdcc9dcfe8300f92947a..b850a9e6e17df4ccab37bbe2b3e5349a63065aa5 100644 --- a/src/main/java/azgracompress/quantization/scalar/LloydMaxU16ScalarQuantization.java +++ b/src/main/java/azgracompress/quantization/scalar/LloydMaxU16ScalarQuantization.java @@ -18,6 +18,8 @@ public class LloydMaxU16ScalarQuantization { private final int workerCount; + private boolean verbose = false; + public LloydMaxU16ScalarQuantization(final int[] trainData, final int codebookSize, final int workerCount) { trainingData = trainData; this.codebookSize = codebookSize; @@ -25,7 +27,7 @@ public class LloydMaxU16ScalarQuantization { } public LloydMaxU16ScalarQuantization(final int[] trainData, final int codebookSize) { - this(trainData, codebookSize,1); + this(trainData, codebookSize, 1); } private void initialize() { @@ -51,7 +53,9 @@ public class LloydMaxU16ScalarQuantization { pdf[trainingData[i]] += 1; } s.stop(); - System.out.println("Init_PDF: " + s.getElapsedTimeString()); + if (verbose) { + System.out.println("Init_PDF: " + s.getElapsedTimeString()); + } } private void recalculateBoundaryPoints() { @@ -97,9 +101,9 @@ public class LloydMaxU16ScalarQuantization { private double getCurrentMse() { double mse = 0.0; + Stopwatch s = new Stopwatch(); + s.start(); if (workerCount > 1) { - Stopwatch s = new Stopwatch(); - s.start(); // Speedup final int workSize = trainingData.length / workerCount; @@ -133,22 +137,28 @@ public class LloydMaxU16ScalarQuantization { } catch (InterruptedException e) { e.printStackTrace(); } - s.stop(); - System.out.println("\ngetCurrentMse time: " + s.getElapsedTimeString()); } else { for (final int trainingDatum : trainingData) { int quantizedValue = quantize(trainingDatum); mse += Math.pow((double) trainingDatum - (double) quantizedValue, 2); } } + s.stop(); + if (verbose) { + System.out.println("\ngetCurrentMse time: " + s.getElapsedTimeString()); + } mse /= (double) trainingData.length; return mse; } - public QTrainIteration[] train(final boolean verbose) { - System.out.println("Data len: " + trainingData.length); + public QTrainIteration[] train(final boolean shouldBeVerbose) { + this.verbose = shouldBeVerbose; + if (verbose) { + System.out.println("Training data count: " + trainingData.length); + } + initialize(); initializeProbabilityDensityFunction(); @@ -184,7 +194,7 @@ public class LloydMaxU16ScalarQuantization { dist = (prevMse - currentMse) / currentMse; if (verbose) { - System.out.print(String.format("\rCurrent MSE: %.4f PSNR: %.4f dB", currentMse, psnr)); + System.out.println(String.format("Current MSE: %.4f PSNR: %.4f dB", currentMse, psnr)); }