diff --git a/src/main/java/azgracompress/benchmark/ScalarQuantizationBenchmark.java b/src/main/java/azgracompress/benchmark/ScalarQuantizationBenchmark.java
index 7cd32233f7b98874c09a314290a0018d351f2666..44de36ee0c0789a2e854f0b004f816db4470b6d5 100644
--- a/src/main/java/azgracompress/benchmark/ScalarQuantizationBenchmark.java
+++ b/src/main/java/azgracompress/benchmark/ScalarQuantizationBenchmark.java
@@ -4,8 +4,9 @@ import azgracompress.U16;
 import azgracompress.cli.ParsedCliOptions;
 import azgracompress.data.V3i;
 import azgracompress.quantization.QTrainIteration;
-import azgracompress.quantization.QuantizationValueCache;
+import azgracompress.cache.QuantizationCacheManager;
 import azgracompress.quantization.scalar.LloydMaxU16ScalarQuantization;
+import azgracompress.quantization.scalar.SQCodebook;
 import azgracompress.quantization.scalar.ScalarQuantizer;
 
 import java.io.File;
@@ -36,16 +37,15 @@ public class ScalarQuantizationBenchmark extends BenchmarkBase {
         ScalarQuantizer quantizer = null;
         if (hasCacheFolder) {
             System.out.println("Loading codebook from cache");
-            QuantizationValueCache cache = new QuantizationValueCache(cacheFolder);
-            try {
-                final int[] quantizationValues = cache.readCachedValues(inputFile, codebookSize);
-                // TODO(Moravec): FIXME!
-                quantizer = null;//new ScalarQuantizer(U16.Min, U16.Max, quantizationValues);
-            } catch (IOException e) {
+            QuantizationCacheManager cacheManager = new QuantizationCacheManager(cacheFolder);
+            final SQCodebook codebook = cacheManager.loadSQCodebook(inputFile, codebookSize);
+
+            if (codebook == null) {
                 System.err.println("Failed to read quantization values from cache file.");
-                e.printStackTrace();
                 return;
             }
+
+            quantizer = new ScalarQuantizer(codebook);
             System.out.println("Created quantizer from cache");
         } else if (useMiddlePlane) {
             final int middlePlaneIndex = rawImageDims.getZ() / 2;
@@ -67,17 +67,17 @@ public class ScalarQuantizationBenchmark extends BenchmarkBase {
                 return;
             }
 
             if (!hasGeneralQuantizer) {
                 quantizer = trainLloydMaxQuantizer(planeData, codebookSize);
                 System.out.println("Created plane quantizer");
             }
 
+            if (quantizer == null) {
+                System.err.println("Failed to initialize scalar quantizer.");
+                return;
+            }
+
+            // TODO(Moravec): Add huffman coding.
 
             final String quantizedFile = String.format(QUANTIZED_FILE_TEMPLATE, planeIndex, codebookSize);
             final String diffFile = String.format(DIFFERENCE_FILE_TEMPLATE, planeIndex, codebookSize);
@@ -124,10 +124,7 @@ public class ScalarQuantizationBenchmark extends BenchmarkBase {
     private ScalarQuantizer trainLloydMaxQuantizer(final int[] data, final int codebookSize) {
         LloydMaxU16ScalarQuantization lloydMax = new LloydMaxU16ScalarQuantization(data, codebookSize, workerCount);
         QTrainIteration[] trainingReport = lloydMax.train(false);
-
-        //saveQTrainLog(String.format("p%d_cb_%d_lloyd.csv", planeIndex, codebookSize), trainingReport);
-
-        // TODO(Moravec): FIXME
-        return new ScalarQuantizer(U16.Min, U16.Max, null);//lloydMax.getCentroids());
+        // saveQTrainLog(String.format("p%d_cb_%d_lloyd.csv", planeIndex, codebookSize), trainingReport);
+        return new ScalarQuantizer(U16.Min, U16.Max, lloydMax.getCodebook());
     }
 }
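The cache path above is the shape every compressor in this patch now shares: a cache miss surfaces as a null codebook rather than an IOException, and the new single-argument ScalarQuantizer constructor implies the full U16 range. Condensed into one place (an editorial sketch reusing the benchmark's fields, not code from the patch):

    QuantizationCacheManager cacheManager = new QuantizationCacheManager(cacheFolder);
    final SQCodebook codebook = cacheManager.loadSQCodebook(inputFile, codebookSize);

    ScalarQuantizer quantizer = (codebook != null)
            ? new ScalarQuantizer(codebook)                    // defaults to [U16.Min, U16.Max]
            : trainLloydMaxQuantizer(planeData, codebookSize); // fall back to Lloyd-Max training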
diff --git a/src/main/java/azgracompress/benchmark/VectorQuantizationBenchmark.java b/src/main/java/azgracompress/benchmark/VectorQuantizationBenchmark.java
index 5224d38e73a5e5e0a8766477945cb9cb7ced53c8..c3df11c0830cb40dfd12926be0e3ae6e38732f21 100644
--- a/src/main/java/azgracompress/benchmark/VectorQuantizationBenchmark.java
+++ b/src/main/java/azgracompress/benchmark/VectorQuantizationBenchmark.java
@@ -1,17 +1,14 @@
 package azgracompress.benchmark;
 
+import azgracompress.cache.QuantizationCacheManager;
 import azgracompress.cli.ParsedCliOptions;
 import azgracompress.data.*;
-import azgracompress.quantization.QuantizationValueCache;
-import azgracompress.quantization.vector.CodebookEntry;
 import azgracompress.quantization.vector.LBGResult;
 import azgracompress.quantization.vector.LBGVectorQuantizer;
+import azgracompress.quantization.vector.VQCodebook;
 import azgracompress.quantization.vector.VectorQuantizer;
 
 import java.io.File;
-import java.io.FileOutputStream;
-import java.io.IOException;
-import java.io.OutputStreamWriter;
 
 public class VectorQuantizationBenchmark extends BenchmarkBase {
 
@@ -62,19 +59,13 @@ public class VectorQuantizationBenchmark extends BenchmarkBase {
 
         if (hasCacheFolder) {
             System.out.println("Loading codebook from cache");
-            QuantizationValueCache cache = new QuantizationValueCache(cacheFolder);
-            try {
-                final CodebookEntry[] codebook = cache.readCachedValues(inputFile,
-                                                                        codebookSize,
-                                                                        qVector.getX(),
-                                                                        qVector.getY());
-                quantizer = new VectorQuantizer(codebook);
-
-            } catch (IOException e) {
-                e.printStackTrace();
+            QuantizationCacheManager cacheManager = new QuantizationCacheManager(cacheFolder);
+            final VQCodebook codebook = cacheManager.loadVQCodebook(inputFile, codebookSize, qVector.toV3i());
+            if (codebook == null) {
                 System.err.println("Failed to read quantization vectors from cache.");
                 return;
             }
+            quantizer = new VectorQuantizer(codebook);
             System.out.println("Created quantizer from cache");
         } else if (useMiddlePlane) {
             final int middlePlaneIndex = rawImageDims.getZ() / 2;
@@ -86,7 +77,10 @@ public class VectorQuantizationBenchmark extends BenchmarkBase {
             }
 
             final int[][] refPlaneData = getPlaneVectors(middlePlane, qVector);
-            LBGVectorQuantizer vqInitializer = new LBGVectorQuantizer(refPlaneData, codebookSize, workerCount);
+            LBGVectorQuantizer vqInitializer = new LBGVectorQuantizer(refPlaneData,
+                                                                      codebookSize,
+                                                                      workerCount,
+                                                                      qVector.toV3i());
             final LBGResult vqResult = vqInitializer.findOptimalCodebook();
             quantizer = new VectorQuantizer(vqResult.getCodebook());
             System.out.println("Created quantizer from middle plane.");
@@ -105,7 +99,10 @@ public class VectorQuantizationBenchmark extends BenchmarkBase {
 
             if (!hasGeneralQuantizer) {
-                LBGVectorQuantizer vqInitializer = new LBGVectorQuantizer(planeData, codebookSize,workerCount);
+                LBGVectorQuantizer vqInitializer = new LBGVectorQuantizer(planeData,
+                                                                          codebookSize,
+                                                                          workerCount,
+                                                                          qVector.toV3i());
                 LBGResult vqResult = vqInitializer.findOptimalCodebook();
                 quantizer = new VectorQuantizer(vqResult.getCodebook());
                 System.out.println("Created plane quantizer.");
@@ -119,6 +116,8 @@ public class VectorQuantizationBenchmark extends BenchmarkBase {
 
             final int[][] quantizedData = quantizer.quantize(planeData, workerCount);
 
+            // TODO(Moravec): Add huffman coding.
+
             final ImageU16 quantizedImage = reconstructImageFromQuantizedVectors(plane, quantizedData, qVector);
 
             if (!saveQuantizedPlaneData(quantizedImage.getData(), quantizedFile)) {
@@ -129,25 +128,4 @@ public class VectorQuantizationBenchmark extends BenchmarkBase {
             saveDifference(plane.getData(), quantizedImage.getData(), diffFile, absoluteDiffFile);
         }
     }
-
-    private void saveCodebook(final CodebookEntry[] codebook, final String codebookFile) {
-        final String outFile = getFileNamePathIntoOutDir(codebookFile);
-        try {
-            FileOutputStream fileStream = new FileOutputStream(outFile);
-            OutputStreamWriter writer = new OutputStreamWriter(fileStream);
-
-            for (final CodebookEntry entry : codebook) {
-                writer.write(entry.getVectorString());
-            }
-
-            writer.flush();
-            fileStream.flush();
-            fileStream.close();
-        } catch (IOException ioE) {
-            ioE.printStackTrace();
-            System.err.println("Failed to save codebook vectors.");
-        }
-    }
-
-
 }
diff --git a/src/main/java/azgracompress/cache/QuantizationCacheManager.java b/src/main/java/azgracompress/cache/QuantizationCacheManager.java
index c908cd6e0ce25b9237b7a1fbc13bd01828edb404..dbd07c83d947793e49dab0e7e7109627868692a4 100644
--- a/src/main/java/azgracompress/cache/QuantizationCacheManager.java
+++ b/src/main/java/azgracompress/cache/QuantizationCacheManager.java
@@ -218,7 +218,7 @@ public class QuantizationCacheManager {
      * @param codebookSize Codebook size.
      * @return SQ codebook or null.
      */
-    public SQCodebook readSQCodebook(final String trainFile, final int codebookSize) {
+    public SQCodebook loadSQCodebook(final String trainFile, final int codebookSize) {
         final SQCacheFile cacheFile = loadSQCacheFile(trainFile, codebookSize);
         if (cacheFile != null)
             return cacheFile.getCodebook();
@@ -234,7 +234,7 @@ public class QuantizationCacheManager {
      * @param vDim         Quantization vector dimension.
      * @return VQ codebook.
      */
-    public VQCodebook readVQCodebook(final String trainFile,
+    public VQCodebook loadVQCodebook(final String trainFile,
                                      final int codebookSize,
                                      final V3i vDim) {
         final VQCacheFile cacheFile = loadVQCacheFile(trainFile, codebookSize, vDim);
diff --git a/src/main/java/azgracompress/compression/CompressorDecompressorBase.java b/src/main/java/azgracompress/compression/CompressorDecompressorBase.java
index 4519b12c88a86b47c1ef37de3a1cceb0d4c5b215..63094e6bd6901fea148998cac79862eed101ce70 100644
--- a/src/main/java/azgracompress/compression/CompressorDecompressorBase.java
+++ b/src/main/java/azgracompress/compression/CompressorDecompressorBase.java
@@ -1,6 +1,11 @@
 package azgracompress.compression;
 
 import azgracompress.cli.ParsedCliOptions;
+import azgracompress.compression.exception.ImageCompressionException;
+import azgracompress.huffman.Huffman;
+import azgracompress.io.OutBitStream;
+
+import java.io.DataOutputStream;
 
 public abstract class CompressorDecompressorBase {
     public static final int LONG_BYTES = 8;
@@ -22,6 +27,12 @@ public abstract class CompressorDecompressorBase {
         return symbols;
     }
 
+    protected Huffman createHuffmanCoder(final int[] symbols, final long[] frequencies) {
+        Huffman huffman = new Huffman(symbols, frequencies);
+        huffman.buildHuffmanTree();
+        return huffman;
+    }
+
     protected int[] getPlaneIndicesForCompression() {
         if (options.hasPlaneIndexSet()) {
             return new int[]{options.getPlaneIndex()};
@@ -73,9 +84,32 @@ public abstract class CompressorDecompressorBase {
 
     /**
      * Get index of the middle plane.
+     *
      * @return Index of the middle plane.
     */
     protected int getMiddlePlaneIndex() {
         return (options.getImageDimension().getZ() / 2);
     }
+
+    /**
+     * Write Huffman-encoded indices to the compress stream.
+     *
+     * @param compressStream Compress stream.
+     * @param huffman        Huffman encoder.
+     * @param indices        Indices to write.
+     * @return Number of bytes written.
+     * @throws ImageCompressionException when writing to the compress stream fails.
+     */
+    protected long writeHuffmanEncodedIndices(DataOutputStream compressStream,
+                                              final Huffman huffman,
+                                              final int[] indices) throws ImageCompressionException {
+        try (OutBitStream outBitStream = new OutBitStream(compressStream, options.getBitsPerPixel(), 2048)) {
+            for (final int index : indices) {
+                outBitStream.write(huffman.getCode(index));
+            }
+            return outBitStream.getBytesWritten();
+        } catch (Exception ex) {
+            throw new ImageCompressionException("Unable to write indices to OutBitStream.", ex);
+        }
+    }
 }
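Both compressors now funnel their index streams through this one helper and record the returned byte count per plane. Because an optimal Huffman code's payload size is fully determined by the symbol-frequency histogram, the output of writeHuffmanEncodedIndices can be estimated without building a tree: repeatedly merge the two smallest weights and sum the merged values. A self-contained sketch of that estimate (this is not the project's Huffman class, and it ignores OutBitStream's padding of the final byte):

    import java.util.PriorityQueue;

    final class HuffmanSizeEstimate {
        /**
         * Total payload bits of an optimal Huffman code for the given histogram.
         * Every merge of the two smallest weights contributes its sum to the total,
         * which equals sum(frequency[i] * codeLength[i]) over all symbols.
         */
        static long estimatedBits(final long[] frequencies) {
            final PriorityQueue<Long> heap = new PriorityQueue<>();
            for (final long f : frequencies) {
                if (f > 0) {
                    heap.add(f);
                }
            }
            if (heap.size() <= 1) {
                // Degenerate histogram: one used symbol still costs one bit per occurrence.
                return heap.isEmpty() ? 0 : heap.poll();
            }
            long totalBits = 0;
            while (heap.size() > 1) {
                final long merged = heap.poll() + heap.poll();
                totalBits += merged;
                heap.add(merged);
            }
            return totalBits;
        }

        public static void main(String[] args) {
            // Classic textbook histogram; optimal code lengths are 1,3,3,3,4,4 -> 224 bits.
            System.out.println(estimatedBits(new long[]{45, 13, 12, 16, 9, 5}) + " bits");
        }
    }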
diff --git a/src/main/java/azgracompress/compression/SQImageCompressor.java b/src/main/java/azgracompress/compression/SQImageCompressor.java
index 20bb0a222da8d20b0601ed840a924ecf82d78cde..0df7772f0699fc2b4b63e243c9b0d26b84ab27b3 100644
--- a/src/main/java/azgracompress/compression/SQImageCompressor.java
+++ b/src/main/java/azgracompress/compression/SQImageCompressor.java
@@ -1,15 +1,14 @@
 package azgracompress.compression;
 
 import azgracompress.U16;
+import azgracompress.cache.QuantizationCacheManager;
 import azgracompress.cli.ParsedCliOptions;
 import azgracompress.compression.exception.ImageCompressionException;
 import azgracompress.data.ImageU16;
 import azgracompress.huffman.Huffman;
-import azgracompress.io.OutBitStream;
 import azgracompress.io.RawDataIO;
-import azgracompress.quantization.QuantizationValueCache;
 import azgracompress.quantization.scalar.LloydMaxU16ScalarQuantization;
-import azgracompress.quantization.scalar.ScalarQuantizationCodebook;
+import azgracompress.quantization.scalar.SQCodebook;
 import azgracompress.quantization.scalar.ScalarQuantizer;
 import azgracompress.utilities.Stopwatch;
 
@@ -45,7 +44,7 @@ public class SQImageCompressor extends CompressorDecompressorBase implements IIm
      */
     private void writeCodebookToOutputStream(final ScalarQuantizer quantizer,
                                              DataOutputStream compressStream) throws ImageCompressionException {
-        final ScalarQuantizationCodebook codebook = quantizer.getCodebook();
+        final SQCodebook codebook = quantizer.getCodebook();
         final int[] centroids = codebook.getCentroids();
         final long[] frequencies = codebook.getSymbolFrequencies();
         try {
@@ -70,16 +69,13 @@ public class SQImageCompressor extends CompressorDecompressorBase implements IIm
      * @throws ImageCompressionException when fails to read cached codebook.
      */
     private ScalarQuantizer loadQuantizerFromCache() throws ImageCompressionException {
-        QuantizationValueCache cache = new QuantizationValueCache(options.getCodebookCacheFolder());
-        try {
+        QuantizationCacheManager cacheManager = new QuantizationCacheManager(options.getCodebookCacheFolder());
 
-            final int[] quantizationValues = cache.readCachedValues(options.getInputFile(),
-                                                                    codebookSize);
-            // TODO(Moravec): FIXME the null value.
-            return new ScalarQuantizer(U16.Min, U16.Max, null);
-        } catch (IOException e) {
-            throw new ImageCompressionException("Failed to read quantization values from cache file.", e);
+        final SQCodebook codebook = cacheManager.loadSQCodebook(options.getInputFile(), codebookSize);
+        if (codebook == null) {
+            throw new ImageCompressionException("Failed to read quantization values from cache file.");
         }
+        return new ScalarQuantizer(codebook);
     }
 
     /**
@@ -95,10 +91,12 @@ public class SQImageCompressor extends CompressorDecompressorBase implements IIm
         Huffman huffman = null;
         final int[] huffmanSymbols = createHuffmanSymbols();
         if (options.hasCodebookCacheFolder()) {
-            // TODO(Moravec): Create huffman.
             Log("Loading codebook from cache file.");
+
             quantizer = loadQuantizerFromCache();
-            Log("Cached quantizer created.");
+            huffman = createHuffmanCoder(huffmanSymbols, quantizer.getCodebook().getSymbolFrequencies());
+
+            Log("Cached quantizer with huffman coder created.");
             writeCodebookToOutputStream(quantizer, compressStream);
         } else if (options.shouldUseMiddlePlane()) {
             stopwatch.restart();
@@ -108,19 +106,17 @@ public class SQImageCompressor extends CompressorDecompressorBase implements IIm
                 middlePlane = RawDataIO.loadImageU16(options.getInputFile(),
                                                      options.getImageDimension(),
                                                      getMiddlePlaneIndex());
-                // TODO(Moravec): Create huffman.
             } catch (Exception ex) {
                 throw new ImageCompressionException("Unable to load plane data.", ex);
             }
-
             Log(String.format("Training scalar quantizer from middle plane %d.", middlePlaneIndex));
             quantizer = trainScalarQuantizerFromData(middlePlane.getData());
-            stopwatch.stop();
+            huffman = createHuffmanCoder(huffmanSymbols, quantizer.getCodebook().getSymbolFrequencies());
 
+            stopwatch.stop();
             writeCodebookToOutputStream(quantizer, compressStream);
-
-            Log("Middle plane codebook created in: " + stopwatch.getElapsedTimeString());
+            Log("Middle plane codebook with huffman coder created in: " + stopwatch.getElapsedTimeString());
         }
 
         final int[] planeIndices = getPlaneIndicesForCompression();
@@ -155,32 +151,8 @@ public class SQImageCompressor extends CompressorDecompressorBase implements IIm
             Log("Compressing plane...");
             final int[] indices = quantizer.quantizeIntoIndices(plane.getData(), 1);
 
-            // ////////////////////////
-            // for (int i = 0; i < indices.length; i++) {
-            //     final boolean[] huffmanCode = huffman.getCode(indices[i]);
-            //     HuffmanNode currentHuffmanNode = huffman.getRoot();
-            //     boolean bit;
-            //     int index = 0;
-            //     while (!currentHuffmanNode.isLeaf()) {
-            //         bit = huffmanCode[index++];
-            //         currentHuffmanNode = currentHuffmanNode.traverse(bit);
-            //     }
-            //     assert (indices[i] == currentHuffmanNode.getSymbol());
-            // }
-            // ////////////////////////////////
-
-
-            try (OutBitStream outBitStream = new OutBitStream(compressStream, options.getBitsPerPixel(), 2048)) {
-                for (final int index : indices) {
-                    outBitStream.write(huffman.getCode(index));
-                }
-                planeDataSizes[planeCounter++] = outBitStream.getBytesWritten();
-                //outBitStream.write(indices);
-            } catch (Exception ex) {
-                throw new ImageCompressionException("Unable to write indices to OutBitStream.", ex);
-            }
+            planeDataSizes[planeCounter++] = writeHuffmanEncodedIndices(compressStream, huffman, indices);
 
-            // TODO: Fill plane data size
             stopwatch.stop();
             Log("Plane time: " + stopwatch.getElapsedTimeString());
             Log(String.format("Finished processing of plane %d", planeIndex));
@@ -221,23 +193,21 @@ public class SQImageCompressor extends CompressorDecompressorBase implements IIm
 
     @Override
     public void trainAndSaveCodebook() throws ImageCompressionException {
-
-
         int[] trainData = loadConfiguredPlanesData();
 
         LloydMaxU16ScalarQuantization lloydMax = new LloydMaxU16ScalarQuantization(trainData, codebookSize, options.getWorkerCount());
-
         Log("Starting LloydMax training.");
         lloydMax.train(options.isVerbose());
-        final int[] qValues = lloydMax.getCentroids();
+        final SQCodebook codebook = lloydMax.getCodebook();
+        final int[] qValues = codebook.getCentroids();
         Log("Finished LloydMax training.");
 
         Log(String.format("Saving cache file to %s", options.getOutputFile()));
-        QuantizationValueCache cache = new QuantizationValueCache(options.getOutputFile());
+        QuantizationCacheManager cacheManager = new QuantizationCacheManager(options.getOutputFile());
         try {
-            cache.saveQuantizationValues(options.getInputFile(), qValues);
+            cacheManager.saveCodebook(options.getInputFile(), codebook);
         } catch (IOException e) {
             throw new ImageCompressionException("Unable to write cache.", e);
         }
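A point worth checking in review: the SQ codebook block written by writeCodebookToOutputStream grew by the frequency table, and any size accounting has to follow. Assuming each centroid goes out as a u16 via writeShort and each frequency as a long via writeLong (the layout the decompressor below reads back), the block size is:

    // Size in bytes of the SQ codebook block, as implied by this diff's stream layout.
    static long sqCodebookBlockBytes(final int codebookSize) {
        final int bytesPerCentroid = 2;   // DataOutputStream.writeShort
        final int bytesPerFrequency = 8;  // DataOutputStream.writeLong
        return (long) codebookSize * (bytesPerCentroid + bytesPerFrequency);
    }

For a 256-entry codebook that is 256 * (2 + 8) = 2560 bytes, up from 512 when only centroids were stored.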
diff --git a/src/main/java/azgracompress/compression/SQImageDecompressor.java b/src/main/java/azgracompress/compression/SQImageDecompressor.java
index 5d71d0da92c46310a6f608f0fb9e7e3c5b04c0e5..20214473d20739c35784a499b9e327747d4ee78f 100644
--- a/src/main/java/azgracompress/compression/SQImageDecompressor.java
+++ b/src/main/java/azgracompress/compression/SQImageDecompressor.java
@@ -6,7 +6,7 @@ import azgracompress.fileformat.QCMPFileHeader;
 import azgracompress.huffman.Huffman;
 import azgracompress.huffman.HuffmanNode;
 import azgracompress.io.InBitStream;
-import azgracompress.quantization.scalar.ScalarQuantizationCodebook;
+import azgracompress.quantization.scalar.SQCodebook;
 import azgracompress.utilities.Stopwatch;
 import azgracompress.utilities.TypeConverter;
 
@@ -19,7 +19,7 @@ public class SQImageDecompressor extends CompressorDecompressorBase implements I
         super(options);
     }
 
-    private ScalarQuantizationCodebook readScalarQuantizationValues(DataInputStream compressedStream) throws ImageDecompressionException {
+    private SQCodebook readScalarQuantizationValues(DataInputStream compressedStream) throws ImageDecompressionException {
         int[] quantizationValues = new int[codebookSize];
         long[] symbolFrequencies = new long[codebookSize];
         try {
@@ -32,7 +32,7 @@ public class SQImageDecompressor extends CompressorDecompressorBase implements I
         } catch (IOException ioEx) {
             throw new ImageDecompressionException("Unable to read quantization values from compressed stream.", ioEx);
         }
-        return new ScalarQuantizationCodebook(quantizationValues, symbolFrequencies);
+        return new SQCodebook(quantizationValues, symbolFrequencies);
     }
 
     @Override
@@ -51,14 +51,6 @@ public class SQImageDecompressor extends CompressorDecompressorBase implements I
             totalPlaneDataSize += planeDataSize;
         }
 
-        // // Data size of single plane indices.
-        // final long planeIndicesDataSize =
-        //         (long) Math.ceil(((header.getImageSizeX() * header.getImageSizeY()) * header.getBitsPerPixel()) / 8.0);
-        //
-        // // All planes data size.
-        // final long allPlaneIndicesDataSize = planeIndicesDataSize * header.getImageSizeZ();
-
         return (codebookDataSize + totalPlaneDataSize);
     }
 
@@ -74,14 +66,13 @@ public class SQImageDecompressor extends CompressorDecompressorBase implements I
         final int planePixelCount = header.getImageSizeX() * header.getImageSizeY();
         final int planeIndicesDataSize = (int) Math.ceil((planePixelCount * header.getBitsPerPixel()) / 8.0);
 
-        int[] quantizationValues = null;
+        SQCodebook codebook = null;
         Huffman huffman = null;
         if (!header.isCodebookPerPlane()) {
             // There is only one codebook.
-            huffman = null;
-            // TODO(Moravec): Handle loading of Huffman.
-            Log("Loading codebook from cache...");
-            //quantizationValues = readScalarQuantizationValues(compressedStream, codebookSize);
+            Log("Loading single codebook and huffman coder.");
+            codebook = readScalarQuantizationValues(compressedStream);
+            huffman = createHuffmanCoder(huffmanSymbols, codebook.getSymbolFrequencies());
         }
 
         Stopwatch stopwatch = new Stopwatch();
@@ -89,13 +80,10 @@ public class SQImageDecompressor extends CompressorDecompressorBase implements I
             stopwatch.restart();
             if (header.isCodebookPerPlane()) {
                 Log("Loading plane codebook...");
-                ScalarQuantizationCodebook codebook = readScalarQuantizationValues(compressedStream);
-                quantizationValues = codebook.getCentroids();
-                huffman = new Huffman(huffmanSymbols, codebook.getSymbolFrequencies());
-                huffman.buildHuffmanTree();
+                codebook = readScalarQuantizationValues(compressedStream);
+                huffman = createHuffmanCoder(huffmanSymbols, codebook.getSymbolFrequencies());
             }
-            assert (quantizationValues != null);
-            assert (huffman != null);
+            assert (codebook != null && huffman != null);
 
             Log(String.format("Decompressing plane %d...", planeIndex));
             byte[] decompressedPlaneData = null;
@@ -107,6 +95,7 @@ public class SQImageDecompressor extends CompressorDecompressorBase implements I
                 inBitStream.setAllowReadFromUnderlyingStream(false);
 
                 int[] decompressedValues = new int[planePixelCount];
+                final int[] quantizationValues = codebook.getCentroids();
                 for (int pixel = 0; pixel < planePixelCount; pixel++) {
                     HuffmanNode currentHuffmanNode = huffman.getRoot();
                     boolean bit;
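The loop this hunk sets up is the exact mirror of writeHuffmanEncodedIndices: consume bits one at a time and walk the Huffman tree until a leaf yields a codebook index. The continuation (cut off above) has roughly this shape; readBit() is an assumed InBitStream method, the remaining names are the ones in the hunk:

    for (int pixel = 0; pixel < planePixelCount; pixel++) {
        HuffmanNode currentHuffmanNode = huffman.getRoot();
        while (!currentHuffmanNode.isLeaf()) {
            final boolean bit = inBitStream.readBit();           // assumed InBitStream API
            currentHuffmanNode = currentHuffmanNode.traverse(bit);
        }
        // The leaf's symbol is the codebook index; map it back to a quantization value.
        decompressedValues[pixel] = quantizationValues[currentHuffmanNode.getSymbol()];
    }

This only decodes correctly because the tree is rebuilt from the same frequencies the compressor wrote into the stream.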
diff --git a/src/main/java/azgracompress/compression/VQImageCompressor.java b/src/main/java/azgracompress/compression/VQImageCompressor.java
index 8bf557a8e91994dc6df339e78b7d15a09ef7e359..718680f014ceee1b16e655a7a0f76e03f8388825 100644
--- a/src/main/java/azgracompress/compression/VQImageCompressor.java
+++ b/src/main/java/azgracompress/compression/VQImageCompressor.java
@@ -1,16 +1,13 @@
 package azgracompress.compression;
 
+import azgracompress.cache.QuantizationCacheManager;
 import azgracompress.cli.ParsedCliOptions;
 import azgracompress.compression.exception.ImageCompressionException;
 import azgracompress.data.Chunk2D;
 import azgracompress.data.ImageU16;
-import azgracompress.io.OutBitStream;
+import azgracompress.huffman.Huffman;
 import azgracompress.io.RawDataIO;
-import azgracompress.quantization.QuantizationValueCache;
-import azgracompress.quantization.vector.CodebookEntry;
-import azgracompress.quantization.vector.LBGResult;
-import azgracompress.quantization.vector.LBGVectorQuantizer;
-import azgracompress.quantization.vector.VectorQuantizer;
+import azgracompress.quantization.vector.*;
 import azgracompress.utilities.Stopwatch;
 
 import java.io.DataOutputStream;
@@ -29,7 +26,11 @@ public class VQImageCompressor extends CompressorDecompressorBase implements IIm
      * @return Trained vector quantizer with codebook of set size.
     */
     private VectorQuantizer trainVectorQuantizerFromPlaneVectors(final int[][] planeVectors) {
-        LBGVectorQuantizer vqInitializer = new LBGVectorQuantizer(planeVectors, codebookSize, options.getWorkerCount());
+
+        LBGVectorQuantizer vqInitializer = new LBGVectorQuantizer(planeVectors,
+                                                                  codebookSize,
+                                                                  options.getWorkerCount(),
+                                                                  options.getVectorDimension().toV3i());
         LBGResult vqResult = vqInitializer.findOptimalCodebook(false);
         return new VectorQuantizer(vqResult.getCodebook());
     }
@@ -43,7 +44,7 @@ public class VQImageCompressor extends CompressorDecompressorBase implements IIm
      */
     private void writeQuantizerToCompressStream(final VectorQuantizer quantizer,
                                                 DataOutputStream compressStream) throws ImageCompressionException {
-        final CodebookEntry[] codebook = quantizer.getCodebook();
+        final CodebookEntry[] codebook = quantizer.getCodebookVectors();
         try {
             for (final CodebookEntry entry : codebook) {
                 final int[] entryVector = entry.getVector();
                 for (final int vecVal : entryVector) {
                     compressStream.writeShort(vecVal);
                 }
             }
+            final long[] frequencies = quantizer.getFrequencies();
+            for (final long symbolFrequency : frequencies) {
+                compressStream.writeLong(symbolFrequency);
+            }
         } catch (IOException ioEx) {
             throw new ImageCompressionException("Unable to write codebook to compress stream.", ioEx);
         }
     }
@@ -66,17 +71,15 @@ public class VQImageCompressor extends CompressorDecompressorBase implements IIm
      * @throws ImageCompressionException when fails to read cached codebook.
     */
     private VectorQuantizer loadQuantizerFromCache() throws ImageCompressionException {
-        QuantizationValueCache cache = new QuantizationValueCache(options.getCodebookCacheFolder());
-        try {
-            final CodebookEntry[] codebook = cache.readCachedValues(options.getInputFile(),
-                                                                    codebookSize,
-                                                                    options.getVectorDimension().getX(),
-                                                                    options.getVectorDimension().getY());
-            return new VectorQuantizer(codebook);
+        QuantizationCacheManager cacheManager = new QuantizationCacheManager(options.getCodebookCacheFolder());
 
-        } catch (IOException e) {
-            throw new ImageCompressionException("Failed to read quantization vectors from cache.", e);
+        final VQCodebook codebook = cacheManager.loadVQCodebook(options.getInputFile(),
+                                                                codebookSize,
+                                                                options.getVectorDimension().toV3i());
+        if (codebook == null) {
+            throw new ImageCompressionException("Failed to read quantization vectors from cache.");
         }
+        return new VectorQuantizer(codebook);
     }
 
     /**
@@ -86,15 +89,20 @@ public class VQImageCompressor extends CompressorDecompressorBase implements IIm
      * @throws ImageCompressionException When compress process fails.
      */
     public long[] compress(DataOutputStream compressStream) throws ImageCompressionException {
-        long[] planeDataSizes = new long[options.getImageDimension().getZ()];
         Stopwatch stopwatch = new Stopwatch();
         final boolean hasGeneralQuantizer = options.hasCodebookCacheFolder() || options.shouldUseMiddlePlane();
 
+        final int[] huffmanSymbols = createHuffmanSymbols();
         VectorQuantizer quantizer = null;
+        Huffman huffman = null;
 
         if (options.hasCodebookCacheFolder()) {
             Log("Loading codebook from cache file.");
             quantizer = loadQuantizerFromCache();
-            Log("Cached quantizer created.");
+            huffman = createHuffmanCoder(huffmanSymbols, quantizer.getFrequencies());
+            Log("Cached quantizer with huffman coder created.");
             writeQuantizerToCompressStream(quantizer, compressStream);
         } else if (options.shouldUseMiddlePlane()) {
             stopwatch.restart();
@@ -103,8 +111,8 @@ public class VQImageCompressor extends CompressorDecompressorBase implements IIm
             ImageU16 middlePlane = null;
             try {
                 middlePlane = RawDataIO.loadImageU16(options.getInputFile(),
-                                                     options.getImageDimension(),
-                                                     middlePlaneIndex);
+                                                     options.getImageDimension(),
+                                                     middlePlaneIndex);
             } catch (Exception ex) {
                 throw new ImageCompressionException("Unable to load plane data.", ex);
             }
@@ -112,12 +120,15 @@ public class VQImageCompressor extends CompressorDecompressorBase implements IIm
             Log(String.format("Training vector quantizer from middle plane %d.", middlePlaneIndex));
             final int[][] refPlaneVectors = middlePlane.toQuantizationVectors(options.getVectorDimension());
             quantizer = trainVectorQuantizerFromPlaneVectors(refPlaneVectors);
+            huffman = createHuffmanCoder(huffmanSymbols, quantizer.getFrequencies());
             writeQuantizerToCompressStream(quantizer, compressStream);
             stopwatch.stop();
             Log("Middle plane codebook created in: " + stopwatch.getElapsedTimeString());
         }
 
         final int[] planeIndices = getPlaneIndicesForCompression();
+        long[] planeDataSizes = new long[planeIndices.length];
+        int planeCounter = 0;
 
         for (final int planeIndex : planeIndices) {
             stopwatch.restart();
@@ -137,21 +148,18 @@ public class VQImageCompressor extends CompressorDecompressorBase implements IIm
             if (!hasGeneralQuantizer) {
                 Log(String.format("Training vector quantizer from plane %d.", planeIndex));
                 quantizer = trainVectorQuantizerFromPlaneVectors(planeVectors);
+                huffman = createHuffmanCoder(huffmanSymbols, quantizer.getFrequencies());
                 writeQuantizerToCompressStream(quantizer, compressStream);
                 Log("Wrote plane codebook.");
             }
 
             assert (quantizer != null);
 
-            Log("Compression plane...");
+            Log("Compressing plane...");
             final int[] indices = quantizer.quantizeIntoIndices(planeVectors, options.getWorkerCount());
 
-            try (OutBitStream outBitStream = new OutBitStream(compressStream, options.getBitsPerPixel(), 2048)) {
-                outBitStream.write(indices);
-            } catch (Exception ex) {
-                throw new ImageCompressionException("Unable to write indices to OutBitStream.", ex);
-            }
-            // TODO: Fill plane data size
+            planeDataSizes[planeCounter++] = writeHuffmanEncodedIndices(compressStream, huffman, indices);
+
             stopwatch.stop();
             Log("Plane time: " + stopwatch.getElapsedTimeString());
             Log(String.format("Finished processing of plane %d.", planeIndex));
@@ -222,18 +230,21 @@ public class VQImageCompressor extends CompressorDecompressorBase implements IIm
     public void trainAndSaveCodebook() throws ImageCompressionException {
         final int[][] trainingData = loadConfiguredPlanesData();
 
-        LBGVectorQuantizer vqInitializer = new LBGVectorQuantizer(trainingData, codebookSize, options.getWorkerCount());
+        LBGVectorQuantizer vqInitializer = new LBGVectorQuantizer(trainingData,
+                                                                  codebookSize,
+                                                                  options.getWorkerCount(),
+                                                                  options.getVectorDimension().toV3i());
 
         Log("Starting LBG optimization.");
         LBGResult lbgResult = vqInitializer.findOptimalCodebook(options.isVerbose());
         Log("Learned the optimal codebook.");
 
         Log("Saving cache file to %s", options.getOutputFile());
-        QuantizationValueCache cache = new QuantizationValueCache(options.getOutputFile());
+        QuantizationCacheManager cacheManager = new QuantizationCacheManager(options.getOutputFile());
         try {
-            cache.saveQuantizationValues(options.getInputFile(), lbgResult.getCodebook(), options.getVectorDimension());
+            cacheManager.saveCodebook(options.getInputFile(), lbgResult.getCodebook());
         } catch (IOException e) {
-            throw new ImageCompressionException("Unable to write cache.", e);
+            throw new ImageCompressionException("Unable to write VQ cache.", e);
         }
         Log("Operation completed.");
     }
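As on the scalar side, writeQuantizerToCompressStream now appends the frequency table after the raw vectors, so the VQ codebook block size becomes:

    // Size in bytes of the VQ codebook block written by writeQuantizerToCompressStream:
    // every vector component as a u16, then one long frequency per codebook entry.
    static long vqCodebookBlockBytes(final int codebookSize, final int vectorSize) {
        return (long) codebookSize * vectorSize * 2   // writeShort per vector component
                + (long) codebookSize * 8;            // writeLong per frequency
    }

For example, 256 entries of 3x3 vectors: 256 * 9 * 2 + 256 * 8 = 6656 bytes.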
diff --git a/src/main/java/azgracompress/compression/VQImageDecompressor.java b/src/main/java/azgracompress/compression/VQImageDecompressor.java
index 5045688502d048a0c1c993d40ecb98433143f13e..8965733987b0e87e3f1d324b4c8947c85e742014 100644
--- a/src/main/java/azgracompress/compression/VQImageDecompressor.java
+++ b/src/main/java/azgracompress/compression/VQImageDecompressor.java
@@ -12,6 +12,8 @@ import java.io.DataInputStream;
 import java.io.DataOutputStream;
 import java.io.IOException;
 
+// TODO(Moravec): Handle huffman decoding.
+
 public class VQImageDecompressor extends CompressorDecompressorBase implements IImageDecompressor {
     public VQImageDecompressor(ParsedCliOptions options) {
         super(options);
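Whatever resolves this TODO will need a reader that mirrors writeQuantizerToCompressStream exactly: all vector components as u16 first, then all frequencies as longs, then a Huffman tree rebuilt from those frequencies. A hypothetical sketch of that reader (CodebookEntry is assumed to be constructible from an int[], and V3i to expose getX/getY/getZ; neither is shown in this diff):

    private VQCodebook readVQCodebook(final DataInputStream compressedStream,
                                      final int codebookSize,
                                      final V3i vDim) throws IOException {
        final int vectorSize = vDim.getX() * vDim.getY() * vDim.getZ();
        final CodebookEntry[] vectors = new CodebookEntry[codebookSize];
        final long[] frequencies = new long[codebookSize];
        for (int i = 0; i < codebookSize; i++) {
            final int[] vector = new int[vectorSize];
            for (int j = 0; j < vectorSize; j++) {
                vector[j] = compressedStream.readUnsignedShort();
            }
            vectors[i] = new CodebookEntry(vector);   // assumed constructor
        }
        for (int i = 0; i < codebookSize; i++) {
            frequencies[i] = compressedStream.readLong();
        }
        return new VQCodebook(vDim, vectors, frequencies);
    }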
diff --git a/src/main/java/azgracompress/quantization/scalar/LloydMaxU16ScalarQuantization.java b/src/main/java/azgracompress/quantization/scalar/LloydMaxU16ScalarQuantization.java
index a625461ba3b44fcd6b200c517dddd7784f98f4c1..14bf0a18b165bf334f44643d5fa263d878af9870 100644
--- a/src/main/java/azgracompress/quantization/scalar/LloydMaxU16ScalarQuantization.java
+++ b/src/main/java/azgracompress/quantization/scalar/LloydMaxU16ScalarQuantization.java
@@ -262,12 +262,8 @@ public class LloydMaxU16ScalarQuantization {
         return solutionHistory.toArray(new QTrainIteration[0]);
     }
 
-    public int[] getCentroids() {
-        return centroids;
-    }
-
-    public ScalarQuantizationCodebook getCodebook() {
-        return new ScalarQuantizationCodebook(centroids, frequencies);
+    public SQCodebook getCodebook() {
+        return new SQCodebook(centroids, frequencies);
     }
 }
diff --git a/src/main/java/azgracompress/quantization/scalar/ScalarQuantizationCodebook.java b/src/main/java/azgracompress/quantization/scalar/SQCodebook.java
similarity index 57%
rename from src/main/java/azgracompress/quantization/scalar/ScalarQuantizationCodebook.java
rename to src/main/java/azgracompress/quantization/scalar/SQCodebook.java
index 94378bc5f52c9be85a3e59deebabd572ee36d0c7..f55749e892de0902a707037fe51b2f83b4b2ec02 100644
--- a/src/main/java/azgracompress/quantization/scalar/ScalarQuantizationCodebook.java
+++ b/src/main/java/azgracompress/quantization/scalar/SQCodebook.java
@@ -1,7 +1,10 @@
 package azgracompress.quantization.scalar;
 
-public class ScalarQuantizationCodebook {
+/**
+ * Codebook for scalar quantizer.
+ */
+public class SQCodebook {
     /**
      * Quantization values.
     */
     final int[] centroids;
@@ -12,26 +15,45 @@ public class SQCodebook {
      */
     final long[] indexFrequencies;
 
+    /**
+     * Size of the codebook.
+     */
     final int codebookSize;
 
     /**
      * @param centroids        Quantization values.
      * @param indexFrequencies Absolute frequencies of quantization values.
     */
-    public ScalarQuantizationCodebook(final int[] centroids, final long[] indexFrequencies) {
+    public SQCodebook(final int[] centroids, final long[] indexFrequencies) {
+        assert (centroids.length == indexFrequencies.length);
         this.centroids = centroids;
         this.indexFrequencies = indexFrequencies;
         this.codebookSize = this.centroids.length;
     }
 
+    /**
+     * Get centroids (quantization values) from the codebook.
+     *
+     * @return Quantization values.
+     */
     public int[] getCentroids() {
         return centroids;
     }
 
+    /**
+     * Get frequencies of codebook symbols at indices.
+     *
+     * @return Frequencies of symbols.
+     */
     public long[] getSymbolFrequencies() {
         return indexFrequencies;
     }
 
+    /**
+     * Get codebook size.
+     *
+     * @return Codebook size.
+     */
     public int getCodebookSize() {
         return codebookSize;
     }
 }
diff --git a/src/main/java/azgracompress/quantization/scalar/ScalarQuantizer.java b/src/main/java/azgracompress/quantization/scalar/ScalarQuantizer.java
index 1caf0679452e6858b77ab010c0976d7b4d117d3d..e9fb89b0b0b0723badbfb4cf4979ab317a8d70e7 100644
--- a/src/main/java/azgracompress/quantization/scalar/ScalarQuantizer.java
+++ b/src/main/java/azgracompress/quantization/scalar/ScalarQuantizer.java
@@ -1,12 +1,14 @@
 package azgracompress.quantization.scalar;
 
+import azgracompress.U16;
+
 public class ScalarQuantizer {
     private final int min;
     private final int max;
-    private final ScalarQuantizationCodebook codebook;
+    private final SQCodebook codebook;
     private int[] boundaryPoints;
 
-    public ScalarQuantizer(final int min, final int max, final ScalarQuantizationCodebook codebook) {
+    public ScalarQuantizer(final int min, final int max, final SQCodebook codebook) {
         this.codebook = codebook;
         boundaryPoints = new int[codebook.getCodebookSize() + 1];
         this.min = min;
@@ -15,6 +17,10 @@ public class ScalarQuantizer {
         calculateBoundaryPoints();
     }
 
+    public ScalarQuantizer(final SQCodebook codebook) {
+        this(U16.Min, U16.Max, codebook);
+    }
+
     public int[] quantize(int[] data) {
         int[] result = new int[data.length];
         for (int i = 0; i < data.length; i++) {
@@ -93,7 +99,21 @@ public class ScalarQuantizer {
         return mse;
     }
 
-    public ScalarQuantizationCodebook getCodebook() {
+    public SQCodebook getCodebook() {
         return codebook;
     }
+
+    public long[] calculateFrequencies(int[] trainData) {
+        long[] frequencies = new long[codebook.getCodebookSize()];
+
+        // Linear scan over the quantization intervals; a binary search over
+        // boundaryPoints would be faster for large codebooks.
+        for (final int value : trainData) {
+            for (int intervalId = 1; intervalId <= codebook.getCodebookSize(); intervalId++) {
+                if ((value >= boundaryPoints[intervalId - 1]) && (value <= boundaryPoints[intervalId])) {
+                    ++frequencies[intervalId - 1];
+                    break; // Count each value exactly once; boundary values belong to two intervals.
+                }
+            }
+        }
+        return frequencies;
+    }
 }
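To make the interval logic concrete: boundaryPoints holds codebookSize + 1 points spanning [min, max], and a value v belongs to interval i when boundaryPoints[i - 1] <= v <= boundaryPoints[i]. With made-up boundaries for a 4-entry codebook over [0, 100]:

    final int[] boundaryPoints = {0, 25, 50, 75, 100};   // codebookSize + 1 points
    final long[] frequencies = new long[4];
    for (final int value : new int[]{3, 25, 25, 80, 99}) {
        for (int intervalId = 1; intervalId <= 4; intervalId++) {
            if (value >= boundaryPoints[intervalId - 1] && value <= boundaryPoints[intervalId]) {
                ++frequencies[intervalId - 1];
                break;   // without the break, 25 would be counted in both adjacent intervals
            }
        }
    }
    // frequencies == {3, 0, 0, 2}

The break added in calculateFrequencies above is what keeps values landing exactly on a shared boundary from being counted twice.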
diff --git a/src/main/java/azgracompress/quantization/vector/LBGResult.java b/src/main/java/azgracompress/quantization/vector/LBGResult.java
index f2a930de069c33aad44df49812246cffa9db8f1b..da4a27c5064defbe180639aff5e8c32a9f34e2d2 100644
--- a/src/main/java/azgracompress/quantization/vector/LBGResult.java
+++ b/src/main/java/azgracompress/quantization/vector/LBGResult.java
@@ -1,19 +1,29 @@
 package azgracompress.quantization.vector;
 
+import azgracompress.data.V3i;
+
 public class LBGResult {
 
-    private final CodebookEntry[] codebook;
+    private final CodebookEntry[] codebookVectors;
+    private final long[] frequencies;
     private final double averageMse;
     private final double psnr;
+    private final V3i vectorDims;
 
-    public LBGResult(CodebookEntry[] codebook, double averageMse, double psnr) {
-        this.codebook = codebook;
+    public LBGResult(final V3i vectorDims,
+                     final CodebookEntry[] codebook,
+                     final long[] frequencies,
+                     final double averageMse,
+                     final double psnr) {
+        this.vectorDims = vectorDims;
+        this.codebookVectors = codebook;
+        this.frequencies = frequencies;
         this.averageMse = averageMse;
         this.psnr = psnr;
     }
 
-    public CodebookEntry[] getCodebook() {
-        return codebook;
+    public VQCodebook getCodebook() {
+        return new VQCodebook(vectorDims, codebookVectors, frequencies);
     }
 
     public double getAverageMse() {
@@ -25,6 +35,6 @@ public class LBGResult {
     }
 
     public int getCodebookSize() {
-        return codebook.length;
+        return codebookVectors.length;
     }
 }
diff --git a/src/main/java/azgracompress/quantization/vector/LBGVectorQuantizer.java b/src/main/java/azgracompress/quantization/vector/LBGVectorQuantizer.java
index 47ddf3d162ace1d53d9fe825bc4579eacf307333..a248ec5278c3d1a9d5a650cfc471b4ad165f78c2 100644
--- a/src/main/java/azgracompress/quantization/vector/LBGVectorQuantizer.java
+++ b/src/main/java/azgracompress/quantization/vector/LBGVectorQuantizer.java
@@ -1,6 +1,7 @@
 package azgracompress.quantization.vector;
 
 import azgracompress.U16;
+import azgracompress.data.V3i;
 import azgracompress.utilities.Stopwatch;
 import azgracompress.utilities.Utils;
 
@@ -11,6 +12,7 @@ import java.util.Random;
 public class LBGVectorQuantizer {
     public final static double PRT_VECTOR_DIVIDER = 4.0;
     private final double EPSILON = 0.005;
+    final V3i vectorDimensions;
     private final int vectorSize;
     private final int codebookSize;
     private final int workerCount;
@@ -20,12 +22,18 @@ public class LBGVectorQuantizer {
     private final TrainingVector[] trainingVectors;
     private final VectorDistanceMetric metric = VectorDistanceMetric.Euclidean;
 
+    private long[] frequencies;
+
     boolean verbose = false;
     private double _mse = 0.0;
 
-    public LBGVectorQuantizer(final int[][] vectors, final int codebookSize, final int workerCount) {
+    public LBGVectorQuantizer(final int[][] vectors,
+                              final int codebookSize,
+                              final int workerCount,
+                              final V3i vectorDimensions) {
         assert (vectors.length > 0) : "No training vectors provided";
+        this.vectorDimensions = vectorDimensions;
 
         this.vectorSize = vectors[0].length;
 
@@ -34,13 +42,6 @@ public class LBGVectorQuantizer {
             trainingVectors[i] = new TrainingVector(Arrays.copyOf(vectors[i], vectorSize));
         }
 
-        // boolean allzero = true;
-        // for (int i = 0; i < vectors.length; i++) {
-        //     if (!VectorQuantizer.isZeroVector(trainingVectors[i].getVector())) {
-        //         allzero = false;
-        //     }
-        // }
-
         this.codebookSize = codebookSize;
         this.workerCount = workerCount;
+        // The histogram is reset and refilled during training; allocate it up front
+        // so that resetFrequencies() below cannot hit a null array.
+        this.frequencies = new long[codebookSize];
@@ -92,7 +93,7 @@ public class LBGVectorQuantizer {
         if (verbose) {
             System.out.println(String.format("Final MSE: %.4f\nFinal PSNR: %.4f (dB)", mse, psnr));
         }
-        return new LBGResult(codebook, mse, psnr);
+        return new LBGResult(vectorDimensions, codebook, frequencies, mse, psnr);
     }
 
     /**
@@ -135,7 +136,7 @@ public class LBGVectorQuantizer {
         if (verbose) {
             System.out.println(stopwatch);
         }
-        return new LBGResult(learningCodebookToCodebook(codebook), finalMse, psnr);
+        return new LBGResult(vectorDimensions, learningCodebookToCodebook(codebook), frequencies, finalMse, psnr);
     }
 
     /**
@@ -161,6 +162,17 @@ public class LBGVectorQuantizer {
         _mse += threadMse;
     }
 
+    private void resetFrequencies() {
+        Arrays.fill(frequencies, 0);
+    }
+
+    private synchronized void addWorkerFrequencies(final long[] workerFrequencies) {
+        assert (frequencies.length == workerFrequencies.length) : "Frequency array length mismatch.";
+        for (int i = 0; i < frequencies.length; i++) {
+            frequencies[i] += workerFrequencies[i];
+        }
+    }
+
     /**
      * Calculate the average mean square error of the codebook.
     *
@@ -170,6 +182,7 @@ public class LBGVectorQuantizer {
     private double averageMse(final CodebookEntry[] codebook) {
         Stopwatch s = Stopwatch.startNew("averageMse");
         double mse = 0.0;
+        resetFrequencies();
         if (workerCount > 1) {
             // Reset the global mse
             _mse = 0.0;
@@ -180,7 +193,11 @@ public class LBGVectorQuantizer {
                 final int toIndex = (wId == workerCount - 1) ? trainingVectors.length : (workSize + (wId * workSize));
 
                 workers[wId] = new Thread(() -> {
-                    VectorQuantizer quantizer = new VectorQuantizer(codebook);
+                    VectorQuantizer quantizer = new VectorQuantizer(new VQCodebook(vectorDimensions,
+                                                                                   codebook,
+                                                                                   frequencies));
+
                     double threadMse = 0.0;
                     int cnt = 0;
                     int[] vector;
@@ -199,6 +216,7 @@ public class LBGVectorQuantizer {
 
                     // Update global mse, updateMse function is synchronized.
                     updateMse(threadMse);
+                    addWorkerFrequencies(quantizer.getFrequencies());
                 });
 
                 workers[wId].start();
@@ -213,7 +231,9 @@ public class LBGVectorQuantizer {
             }
             mse = _mse / (double) workerCount;
         } else {
-            VectorQuantizer quantizer = new VectorQuantizer(codebook);
+            VectorQuantizer quantizer = new VectorQuantizer(new VQCodebook(vectorDimensions,
+                                                                           codebook,
+                                                                           frequencies));
 
             for (final TrainingVector trV : trainingVectors) {
                 int[] quantizedV = quantizer.quantize(trV.getVector());
@@ -373,8 +393,8 @@ public class LBGVectorQuantizer {
 
         LBG(codebook);
 
+        final double avgMse = averageMse(codebook);
         if (verbose) {
-            final double avgMse = averageMse(codebook);
             System.out.println(String.format("Average MSE: %.4f", avgMse));
         }
     }
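The threading pattern introduced here (and repeated in VectorQuantizer below) deserves a note: each worker counts into a private histogram and folds it into the shared one inside a synchronized method, so the hot inner loop never takes a lock. Reduced to its core (a standalone sketch, not the patch code):

    final int codebookSize = 256;
    final long[] shared = new long[codebookSize];
    final Object mergeLock = new Object();

    Runnable worker = () -> {
        final long[] local = new long[codebookSize];
        // ... count closest-codebook-entry hits into local, lock-free ...
        synchronized (mergeLock) {                 // one short critical section per worker
            for (int i = 0; i < codebookSize; i++) {
                shared[i] += local[i];
            }
        }
    };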
diff --git a/src/main/java/azgracompress/quantization/vector/VQCodebook.java b/src/main/java/azgracompress/quantization/vector/VQCodebook.java
new file mode 100644
index 0000000000000000000000000000000000000000..25548891ba0e4ebaafcffdbf01d47f3dfd50388c
--- /dev/null
+++ b/src/main/java/azgracompress/quantization/vector/VQCodebook.java
@@ -0,0 +1,72 @@
+package azgracompress.quantization.vector;
+
+import azgracompress.data.V3i;
+
+/**
+ * Codebook for vector quantizer.
+ */
+public class VQCodebook {
+    /**
+     * Quantization vectors.
+     */
+    private final CodebookEntry[] vectors;
+
+    /**
+     * Absolute frequencies of quantization vectors.
+     */
+    private long[] vectorFrequencies;
+
+    /**
+     * Size of the codebook.
+     */
+    private final int codebookSize;
+
+    /**
+     * Vector dimensions.
+     */
+    private final V3i vectorDims;
+
+    public VQCodebook(final V3i vectorDims, final CodebookEntry[] vectors, final long[] vectorFrequencies) {
+        assert (vectors.length == vectorFrequencies.length);
+        this.vectorDims = vectorDims;
+        this.vectors = vectors;
+        this.vectorFrequencies = vectorFrequencies;
+        this.codebookSize = vectors.length;
+    }
+
+    /**
+     * Get vectors (quantization vectors) from the codebook.
+     *
+     * @return Quantization vectors.
+     */
+    public CodebookEntry[] getVectors() {
+        return vectors;
+    }
+
+    /**
+     * Get frequencies of codebook vectors at indices.
+     *
+     * @return Frequencies of vectors.
+     */
+    public long[] getVectorFrequencies() {
+        return vectorFrequencies;
+    }
+
+    /**
+     * Get codebook size.
+     *
+     * @return Codebook size.
+     */
+    public int getCodebookSize() {
+        return codebookSize;
+    }
+
+    /**
+     * Get vector dimensions.
+     *
+     * @return Vector dimensions.
+     */
+    public V3i getVectorDims() {
+        return vectorDims;
+    }
+}
diff --git a/src/main/java/azgracompress/quantization/vector/VectorQuantizer.java b/src/main/java/azgracompress/quantization/vector/VectorQuantizer.java
index 5ea733f455b71e568c6fae67bd8ed50309356a1e..afbb6b64994f89c97509bc47c10e0e7dc6bfb5b8 100644
--- a/src/main/java/azgracompress/quantization/vector/VectorQuantizer.java
+++ b/src/main/java/azgracompress/quantization/vector/VectorQuantizer.java
@@ -1,14 +1,21 @@
 package azgracompress.quantization.vector;
 
+import java.util.Arrays;
+
 public class VectorQuantizer {
 
     private final VectorDistanceMetric metric = VectorDistanceMetric.Euclidean;
 
-    private final CodebookEntry[] codebook;
+    private final VQCodebook codebook;
+    private final CodebookEntry[] codebookVectors;
     private final int vectorSize;
 
-    public VectorQuantizer(final CodebookEntry[] codebook) {
+    private long[] frequencies;
+
+    public VectorQuantizer(final VQCodebook codebook) {
         this.codebook = codebook;
-        vectorSize = codebook[0].getVector().length;
+        this.codebookVectors = codebook.getVectors();
+        vectorSize = codebookVectors[0].getVector().length;
+        frequencies = codebook.getVectorFrequencies();
     }
 
     public int[] quantize(final int[] dataVector) {
@@ -29,7 +36,7 @@ public class VectorQuantizer {
         } else {
             final int[] indices = quantizeIntoIndices(dataVectors, workerCount);
             for (int i = 0; i < dataVectors.length; i++) {
-                result[i] = codebook[indices[i]].getVector();
+                result[i] = codebookVectors[indices[i]].getVector();
             }
         }
 
@@ -40,15 +47,25 @@ public class VectorQuantizer {
         return quantizeIntoIndices(dataVectors, 1);
     }
 
+    private synchronized void addWorkerFrequencies(final long[] workerFrequencies) {
+        assert (frequencies.length == workerFrequencies.length) : "Frequency array length mismatch.";
+        for (int i = 0; i < frequencies.length; i++) {
+            frequencies[i] += workerFrequencies[i];
+        }
+    }
+
     public int[] quantizeIntoIndices(final int[][] dataVectors, final int maxWorkerCount) {
         assert (dataVectors.length > 0 && dataVectors[0].length % vectorSize == 0) : "Wrong vector size";
 
         int[] indices = new int[dataVectors.length];
+        Arrays.fill(frequencies, 0);
 
         if (maxWorkerCount == 1) {
+            int closestIndex;
             for (int vectorIndex = 0; vectorIndex < dataVectors.length; vectorIndex++) {
-                indices[vectorIndex] = findClosestCodebookEntryIndex(dataVectors[vectorIndex], metric);
+                closestIndex = findClosestCodebookEntryIndex(dataVectors[vectorIndex], metric);
+                indices[vectorIndex] = closestIndex;
+                ++frequencies[closestIndex];
             }
         } else {
             // Cap the worker count on 8
             final int workerCount = Math.min(maxWorkerCount, 8);
             Thread[] workers = new Thread[workerCount];
             final int workSize = dataVectors.length / workerCount;
 
             for (int wId = 0; wId < workerCount; wId++) {
                 final int fromIndex = wId * workSize;
                 final int toIndex = (wId == workerCount - 1) ? dataVectors.length : (workSize + (wId * workSize));
 
+                workers[wId] = new Thread(() -> {
+                    int closestIndex;
+                    long[] workerFrequencies = new long[codebookVectors.length];
                     for (int vectorIndex = fromIndex; vectorIndex < toIndex; vectorIndex++) {
-                        indices[vectorIndex] = findClosestCodebookEntryIndex(dataVectors[vectorIndex], metric);
+                        closestIndex = findClosestCodebookEntryIndex(dataVectors[vectorIndex], metric);
+                        indices[vectorIndex] = closestIndex;
+                        ++workerFrequencies[closestIndex];
                     }
+                    addWorkerFrequencies(workerFrequencies);
                 });
 
                 workers[wId].start();
@@ -119,16 +142,16 @@ public class VectorQuantizer {
     }
 
     private CodebookEntry findClosestCodebookEntry(final int[] dataVector, final VectorDistanceMetric metric) {
-        return codebook[findClosestCodebookEntryIndex(dataVector, metric)];
+        return codebookVectors[findClosestCodebookEntryIndex(dataVector, metric)];
     }
 
     private int findClosestCodebookEntryIndex(final int[] dataVector, final VectorDistanceMetric metric) {
         double minDist = Double.MAX_VALUE;
         int closestEntryIndex = 0;
-        for (int entryIndex = 0; entryIndex < codebook.length; entryIndex++) {
-            final double dist = distanceBetweenVectors(dataVector, codebook[entryIndex].getVector(), metric);
+        for (int entryIndex = 0; entryIndex < codebookVectors.length; entryIndex++) {
+            final double dist = distanceBetweenVectors(dataVector, codebookVectors[entryIndex].getVector(), metric);
             if (dist < minDist) {
                 minDist = dist;
                 closestEntryIndex = entryIndex;
@@ -138,8 +161,51 @@ public class VectorQuantizer {
         return closestEntryIndex;
     }
 
-    public CodebookEntry[] getCodebook() {
-        return codebook;
+    public CodebookEntry[] getCodebookVectors() {
+        return codebookVectors;
+    }
+
+    public long[] getFrequencies() {
+        return frequencies;
+    }
+
+    public long[] calculateFrequencies(int[][] dataVectors, final int maxWorkerCount) {
+        Arrays.fill(frequencies, 0);
+        assert (dataVectors.length > 0 && dataVectors[0].length % vectorSize == 0) : "Wrong vector size";
+
+        if (maxWorkerCount == 1) {
+            for (final int[] dataVector : dataVectors) {
+                ++frequencies[findClosestCodebookEntryIndex(dataVector, metric)];
+            }
+        } else {
+            // Cap the worker count at 8.
+            final int workerCount = Math.min(maxWorkerCount, 8);
+            Thread[] workers = new Thread[workerCount];
+            final int workSize = dataVectors.length / workerCount;
+
+            for (int wId = 0; wId < workerCount; wId++) {
+                final int fromIndex = wId * workSize;
+                final int toIndex = (wId == workerCount - 1) ? dataVectors.length : (workSize + (wId * workSize));
+
+                workers[wId] = new Thread(() -> {
+                    long[] workerFrequencies = new long[codebookVectors.length];
+                    for (int vectorIndex = fromIndex; vectorIndex < toIndex; vectorIndex++) {
+                        ++workerFrequencies[findClosestCodebookEntryIndex(dataVectors[vectorIndex], metric)];
+                    }
+                    addWorkerFrequencies(workerFrequencies);
+                });
+
+                workers[wId].start();
+            }
+            try {
+                for (int wId = 0; wId < workerCount; wId++) {
+                    workers[wId].join();
+                }
+            } catch (InterruptedException e) {
+                e.printStackTrace();
+            }
+        }
+        return frequencies;
+    }
 }
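Note: the worker loop in quantizeIntoIndices originally incremented workerFrequencies[vectorIndex]; that indexes the histogram by data-vector position instead of by codebook index (and overruns it whenever a worker's range exceeds the codebook size), so it is corrected to workerFrequencies[closestIndex] above, matching both the single-threaded branch and calculateFrequencies.

Taken together, the per-plane VQ compression path now reads as follows (a condensed restatement of VQImageCompressor.compress using the helpers from this diff):

    // 1) The codebook (trained, cached, or from the middle plane) carries symbol frequencies.
    VectorQuantizer quantizer = new VectorQuantizer(codebook);
    // 2) The Huffman coder is derived from those frequencies; the same frequencies go into
    //    the stream, so the decompressor can rebuild an identical tree.
    Huffman huffman = createHuffmanCoder(huffmanSymbols, quantizer.getFrequencies());
    writeQuantizerToCompressStream(quantizer, compressStream);
    // 3) Indices are Huffman-coded; the returned byte count feeds the file header.
    final int[] indices = quantizer.quantizeIntoIndices(planeVectors, options.getWorkerCount());
    planeDataSizes[planeCounter++] = writeHuffmanEncodedIndices(compressStream, huffman, indices);

One subtlety: quantizeIntoIndices zeroes and refills the same frequency array the coder was built from, but since both the tree and the stream used the pre-quantization values, the two sides stay consistent with each other.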