diff --git a/src/main/java/azgracompress/cli/CliConstants.java b/src/main/java/azgracompress/cli/CliConstants.java index aa5dd83adeb0a67bc37ae3e1ea69a4c6a6c7c1f0..98f8dbbf623256397b11810c758c1399bb9621f4 100644 --- a/src/main/java/azgracompress/cli/CliConstants.java +++ b/src/main/java/azgracompress/cli/CliConstants.java @@ -51,7 +51,7 @@ public class CliConstants { public static final String VECTOR_QUANTIZATION_SHORT = "vq"; public static final String VECTOR_QUANTIZATION_LONG = "vector-quantization"; - public static final String USE_MIDDLE_PLANE_SHORT = "md"; + public static final String USE_MIDDLE_PLANE_SHORT = "mp"; public static final String USE_MIDDLE_PLANE_LONG = "middle-plane"; @NotNull diff --git a/src/main/java/azgracompress/compression/CompressorDecompressorBase.java b/src/main/java/azgracompress/compression/CompressorDecompressorBase.java index 63094e6bd6901fea148998cac79862eed101ce70..2c5bf47e0c65debd3453637bba06cbb9f1ef3a11 100644 --- a/src/main/java/azgracompress/compression/CompressorDecompressorBase.java +++ b/src/main/java/azgracompress/compression/CompressorDecompressorBase.java @@ -12,14 +12,14 @@ public abstract class CompressorDecompressorBase { public static final String EXTENSION = ".QCMP"; protected final ParsedCliOptions options; - protected final int codebookSize; + private final int codebookSize; public CompressorDecompressorBase(ParsedCliOptions options) { this.options = options; this.codebookSize = (int) Math.pow(2, this.options.getBitsPerPixel()); } - protected int[] createHuffmanSymbols() { + protected int[] createHuffmanSymbols(final int codebookSize) { int[] symbols = new int[codebookSize]; for (int i = 0; i < codebookSize; i++) { symbols[i] = i; @@ -112,4 +112,8 @@ public abstract class CompressorDecompressorBase { throw new ImageCompressionException("Unable to write indices to OutBitStream.", ex); } } + + protected int getCodebookSize() { + return codebookSize; + } } diff --git 
a/src/main/java/azgracompress/compression/ImageDecompressor.java b/src/main/java/azgracompress/compression/ImageDecompressor.java index a12f6c83919c36145872e086146cbeaedd560f2e..8c9efe11dba4469d672f334d94a5d8f85b7915f1 100644 --- a/src/main/java/azgracompress/compression/ImageDecompressor.java +++ b/src/main/java/azgracompress/compression/ImageDecompressor.java @@ -97,9 +97,13 @@ public class ImageDecompressor extends CompressorDecompressorBase { break; } logBuilder.append("Bits per pixel:\t\t").append(header.getBitsPerPixel()).append('\n'); + logBuilder.append("Codebook:\t\t").append(header.isCodebookPerPlane() ? "one per plane\n" : "one for " + "all\n"); + final int codebookSize = (int)Math.pow(2, header.getBitsPerPixel()); + logBuilder.append("Codebook size:\t\t").append(codebookSize).append('\n'); + logBuilder.append("Image size X:\t\t").append(header.getImageSizeX()).append('\n'); logBuilder.append("Image size Y:\t\t").append(header.getImageSizeY()).append('\n'); logBuilder.append("Image size Z:\t\t").append(header.getImageSizeZ()).append('\n'); @@ -121,6 +125,10 @@ public class ImageDecompressor extends CompressorDecompressorBase { (fileSize / 1000), ((fileSize / 1000) / 1000))); logBuilder.append("Data size:\t\t").append(dataSize).append(" Bytes ").append(dataSize == expectedDataSize ? 
"(correct)\n" : "(INVALID)\n"); + + final long uncompressedSize = header.getImageDims().multiplyTogether() * 2; + final double compressionRatio = (double)fileSize / (double)uncompressedSize; + logBuilder.append(String.format("Compression ratio:\t%.5f\n", compressionRatio)); } } diff --git a/src/main/java/azgracompress/compression/SQImageCompressor.java b/src/main/java/azgracompress/compression/SQImageCompressor.java index 0df7772f0699fc2b4b63e243c9b0d26b84ab27b3..48654a8bc30b33e8fc6e85d768d9ef9ea06b2d47 100644 --- a/src/main/java/azgracompress/compression/SQImageCompressor.java +++ b/src/main/java/azgracompress/compression/SQImageCompressor.java @@ -28,8 +28,9 @@ public class SQImageCompressor extends CompressorDecompressorBase implements IIm * @return Trained scalar quantizer. */ private ScalarQuantizer trainScalarQuantizerFromData(final int[] planeData) { + LloydMaxU16ScalarQuantization lloydMax = new LloydMaxU16ScalarQuantization(planeData, - codebookSize, + getCodebookSize(), options.getWorkerCount()); lloydMax.train(false); return new ScalarQuantizer(U16.Min, U16.Max, lloydMax.getCodebook()); @@ -71,7 +72,7 @@ public class SQImageCompressor extends CompressorDecompressorBase implements IIm private ScalarQuantizer loadQuantizerFromCache() throws ImageCompressionException { QuantizationCacheManager cacheManager = new QuantizationCacheManager(options.getCodebookCacheFolder()); - final SQCodebook codebook = cacheManager.loadSQCodebook(options.getInputFile(), codebookSize); + final SQCodebook codebook = cacheManager.loadSQCodebook(options.getInputFile(), getCodebookSize()); if (codebook == null) { throw new ImageCompressionException("Failed to read quantization values from cache file."); } @@ -89,7 +90,7 @@ public class SQImageCompressor extends CompressorDecompressorBase implements IIm final boolean hasGeneralQuantizer = options.hasCodebookCacheFolder() || options.shouldUseMiddlePlane(); ScalarQuantizer quantizer = null; Huffman huffman = null; - final int[] 
huffmanSymbols = createHuffmanSymbols(); + final int[] huffmanSymbols = createHuffmanSymbols(getCodebookSize()); if (options.hasCodebookCacheFolder()) { Log("Loading codebook from cache file."); @@ -196,7 +197,7 @@ public class SQImageCompressor extends CompressorDecompressorBase implements IIm int[] trainData = loadConfiguredPlanesData(); LloydMaxU16ScalarQuantization lloydMax = new LloydMaxU16ScalarQuantization(trainData, - codebookSize, + getCodebookSize(), options.getWorkerCount()); Log("Starting LloydMax training."); lloydMax.train(options.isVerbose()); diff --git a/src/main/java/azgracompress/compression/SQImageDecompressor.java b/src/main/java/azgracompress/compression/SQImageDecompressor.java index 20214473d20739c35784a499b9e327747d4ee78f..1ef5f6ac5cdc349f0605d6977c5fbe94e77484f0 100644 --- a/src/main/java/azgracompress/compression/SQImageDecompressor.java +++ b/src/main/java/azgracompress/compression/SQImageDecompressor.java @@ -19,7 +19,8 @@ public class SQImageDecompressor extends CompressorDecompressorBase implements I super(options); } - private SQCodebook readScalarQuantizationValues(DataInputStream compressedStream) throws ImageDecompressionException { + private SQCodebook readScalarQuantizationValues(DataInputStream compressedStream, + final int codebookSize) throws ImageDecompressionException { int[] quantizationValues = new int[codebookSize]; long[] symbolFrequencies = new long[codebookSize]; try { @@ -59,8 +60,8 @@ public class SQImageDecompressor extends CompressorDecompressorBase implements I DataOutputStream decompressStream, QCMPFileHeader header) throws ImageDecompressionException { - final int[] huffmanSymbols = createHuffmanSymbols(); final int codebookSize = (int) Math.pow(2, header.getBitsPerPixel()); + final int[] huffmanSymbols = createHuffmanSymbols(codebookSize); final int planeCountForDecompression = header.getImageSizeZ(); final int planePixelCount = header.getImageSizeX() * header.getImageSizeY(); @@ -71,7 +72,7 @@ public class 
SQImageDecompressor extends CompressorDecompressorBase implements I if (!header.isCodebookPerPlane()) { // There is only one codebook. Log("Loading single codebook and huffman coder."); - codebook = readScalarQuantizationValues(compressedStream); + codebook = readScalarQuantizationValues(compressedStream, codebookSize); huffman = createHuffmanCoder(huffmanSymbols, codebook.getSymbolFrequencies()); } @@ -80,7 +81,7 @@ public class SQImageDecompressor extends CompressorDecompressorBase implements I stopwatch.restart(); if (header.isCodebookPerPlane()) { Log("Loading plane codebook..."); - codebook = readScalarQuantizationValues(compressedStream); + codebook = readScalarQuantizationValues(compressedStream, codebookSize); huffman = createHuffmanCoder(huffmanSymbols, codebook.getSymbolFrequencies()); } assert (codebook != null && huffman != null); diff --git a/src/main/java/azgracompress/compression/VQImageCompressor.java b/src/main/java/azgracompress/compression/VQImageCompressor.java index 718680f014ceee1b16e655a7a0f76e03f8388825..9960c8d6a6b32e609c99399f2cd9174b037e5d8d 100644 --- a/src/main/java/azgracompress/compression/VQImageCompressor.java +++ b/src/main/java/azgracompress/compression/VQImageCompressor.java @@ -28,7 +28,7 @@ public class VQImageCompressor extends CompressorDecompressorBase implements IIm private VectorQuantizer trainVectorQuantizerFromPlaneVectors(final int[][] planeVectors) { LBGVectorQuantizer vqInitializer = new LBGVectorQuantizer(planeVectors, - codebookSize, + getCodebookSize(), options.getWorkerCount(), options.getVectorDimension().toV3i()); LBGResult vqResult = vqInitializer.findOptimalCodebook(false); @@ -74,7 +74,7 @@ public class VQImageCompressor extends CompressorDecompressorBase implements IIm QuantizationCacheManager cacheManager = new QuantizationCacheManager(options.getCodebookCacheFolder()); final VQCodebook codebook = cacheManager.loadVQCodebook(options.getInputFile(), - codebookSize, + getCodebookSize(), 
options.getVectorDimension().toV3i()); if (codebook == null) { throw new ImageCompressionException("Failed to read quantization vectors from cache."); @@ -94,7 +94,7 @@ public class VQImageCompressor extends CompressorDecompressorBase implements IIm final boolean hasGeneralQuantizer = options.hasCodebookCacheFolder() || options.shouldUseMiddlePlane(); - final int[] huffmanSymbols = createHuffmanSymbols(); + final int[] huffmanSymbols = createHuffmanSymbols(getCodebookSize()); VectorQuantizer quantizer = null; Huffman huffman = null; @@ -231,7 +231,7 @@ public class VQImageCompressor extends CompressorDecompressorBase implements IIm final int[][] trainingData = loadConfiguredPlanesData(); LBGVectorQuantizer vqInitializer = new LBGVectorQuantizer(trainingData, - codebookSize, + getCodebookSize(), options.getWorkerCount(), options.getVectorDimension().toV3i()); Log("Starting LBG optimization."); diff --git a/src/main/java/azgracompress/compression/VQImageDecompressor.java b/src/main/java/azgracompress/compression/VQImageDecompressor.java index 8965733987b0e87e3f1d324b4c8947c85e742014..f37b1192d15f34efe10a620ba06ffc01caf206ca 100644 --- a/src/main/java/azgracompress/compression/VQImageDecompressor.java +++ b/src/main/java/azgracompress/compression/VQImageDecompressor.java @@ -4,7 +4,11 @@ import azgracompress.cli.ParsedCliOptions; import azgracompress.compression.exception.ImageDecompressionException; import azgracompress.data.*; import azgracompress.fileformat.QCMPFileHeader; +import azgracompress.huffman.Huffman; +import azgracompress.huffman.HuffmanNode; import azgracompress.io.InBitStream; +import azgracompress.quantization.vector.CodebookEntry; +import azgracompress.quantization.vector.VQCodebook; import azgracompress.utilities.Stopwatch; import azgracompress.utilities.TypeConverter; @@ -31,21 +35,29 @@ public class VQImageDecompressor extends CompressorDecompressorBase implements I return (long) Math.ceil((planeVectorCount * bpp) / 8.0); } - private int[][] 
readCodebookVectors(DataInputStream compressedStream, - final int codebookSize, - final int vectorSize) throws ImageDecompressionException { + private VQCodebook readCodebook(DataInputStream compressedStream, + final int codebookSize, + final int vectorSize) throws ImageDecompressionException { - int[][] codebook = new int[codebookSize][vectorSize]; + final CodebookEntry[] codebookVectors = new CodebookEntry[codebookSize]; + final long[] frequencies = new long[codebookSize]; try { for (int codebookIndex = 0; codebookIndex < codebookSize; codebookIndex++) { + final int[] vector = new int[vectorSize]; for (int vecIndex = 0; vecIndex < vectorSize; vecIndex++) { - codebook[codebookIndex][vecIndex] = compressedStream.readUnsignedShort(); + vector[vecIndex] = compressedStream.readUnsignedShort(); } + codebookVectors[codebookIndex] = new CodebookEntry(vector); + } + for (int codebookIndex = 0; codebookIndex < codebookSize; codebookIndex++) { + frequencies[codebookIndex] = compressedStream.readLong(); } } catch (IOException ioEx) { throw new ImageDecompressionException("Unable to read quantization values from compressed stream.", ioEx); } - return codebook; + + // We don't care about vector dimensions in here. + return new VQCodebook(new V3i(0), codebookVectors, frequencies); } @@ -73,19 +85,17 @@ public class VQImageDecompressor extends CompressorDecompressorBase implements I final int vectorDataSize = 2 * header.getVectorSizeX() * header.getVectorSizeY() * header.getVectorSizeZ(); // Total codebook size in bytes. - final long codebookDataSize = (codebookSize * vectorDataSize) * (header.isCodebookPerPlane() ? - header.getImageSizeZ() : 1); - - // Number of vectors per plane. - final long planeVectorCount = calculatePlaneVectorCount(header); - - // Data size of single plane indices. - final long planeDataSize = calculatePlaneDataSize(planeVectorCount, header.getBitsPerPixel()); - - // All planes data size. 
- final long allPlanesDataSize = planeDataSize * header.getImageSizeZ(); + final long codebookDataSize = ((codebookSize * vectorDataSize) + (codebookSize * LONG_BYTES)) * + (header.isCodebookPerPlane() ? header.getImageSizeZ() : 1); + + // Indices are encoded using huffman. Plane data size is written in the header. + long[] planeDataSizes = header.getPlaneDataSizes(); + long totalPlaneDataSize = 0; + for (final long planeDataSize : planeDataSizes) { + totalPlaneDataSize += planeDataSize; + } - return (codebookDataSize + allPlanesDataSize); + return (codebookDataSize + totalPlaneDataSize); } @Override @@ -97,15 +107,18 @@ public class VQImageDecompressor extends CompressorDecompressorBase implements I final int vectorSize = header.getVectorSizeX() * header.getVectorSizeY() * header.getVectorSizeZ(); final int planeCountForDecompression = header.getImageSizeZ(); final long planeVectorCount = calculatePlaneVectorCount(header); - final long planeDataSize = calculatePlaneDataSize(planeVectorCount, header.getBitsPerPixel()); + //final long planeDataSize = calculatePlaneDataSize(planeVectorCount, header.getBitsPerPixel()); final V2i qVector = new V2i(header.getVectorSizeX(), header.getVectorSizeY()); + final int[] huffmanSymbols = createHuffmanSymbols(codebookSize); - int[][] quantizationVectors = null; + VQCodebook codebook = null; + Huffman huffman = null; if (!header.isCodebookPerPlane()) { // There is only one codebook. 
Log("Loading codebook from cache..."); - quantizationVectors = readCodebookVectors(compressedStream, codebookSize, vectorSize); + codebook = readCodebook(compressedStream, codebookSize, vectorSize); + huffman = createHuffmanCoder(huffmanSymbols, codebook.getVectorFrequencies()); } Stopwatch stopwatch = new Stopwatch(); @@ -113,25 +126,31 @@ public class VQImageDecompressor extends CompressorDecompressorBase implements I stopwatch.restart(); if (header.isCodebookPerPlane()) { Log("Loading plane codebook..."); - quantizationVectors = readCodebookVectors(compressedStream, codebookSize, vectorSize); + codebook = readCodebook(compressedStream, codebookSize, vectorSize); + huffman = createHuffmanCoder(huffmanSymbols, codebook.getVectorFrequencies()); } - assert (quantizationVectors != null); + assert (codebook != null && huffman != null); Log(String.format("Decompressing plane %d...", planeIndex)); byte[] decompressedPlaneData = null; + final int planeDataSize = (int) header.getPlaneDataSizes()[planeIndex]; try (InBitStream inBitStream = new InBitStream(compressedStream, header.getBitsPerPixel(), - (int) planeDataSize)) { + planeDataSize)) { inBitStream.readToBuffer(); inBitStream.setAllowReadFromUnderlyingStream(false); - final int[] indices = inBitStream.readNValues((int) planeVectorCount); int[][] decompressedVectors = new int[(int) planeVectorCount][vectorSize]; for (int vecIndex = 0; vecIndex < planeVectorCount; vecIndex++) { - - System.arraycopy(quantizationVectors[indices[vecIndex]], + HuffmanNode currentHuffmanNode = huffman.getRoot(); + boolean bit; + while (!currentHuffmanNode.isLeaf()) { + bit = inBitStream.readBit(); + currentHuffmanNode = currentHuffmanNode.traverse(bit); + } + System.arraycopy(codebook.getVectors()[currentHuffmanNode.getSymbol()].getVector(), 0, decompressedVectors[vecIndex], 0, diff --git a/src/main/java/azgracompress/data/V3i.java b/src/main/java/azgracompress/data/V3i.java index 
cd238d77b39cfe330ad92a5fee25d409f44d867c..9d09bee0445c2216a7968ae374cf18f285b98519 100644 --- a/src/main/java/azgracompress/data/V3i.java +++ b/src/main/java/azgracompress/data/V3i.java @@ -12,7 +12,7 @@ public class V3i { } public V3i(final int x, final int y) { - this(x,y,1); + this(x, y, 1); } public V3i(final int universalValue) { @@ -67,4 +67,13 @@ public class V3i { public V2i toV2i() { return new V2i(x, y); } + + /** + * Multiplies the x, y and z components together. + * + * @return Product of the three components, computed in 64-bit arithmetic so large dimensions cannot overflow int. + */ + public long multiplyTogether() { + return ((long) x * y * z); + } } diff --git a/src/main/java/azgracompress/quantization/scalar/ScalarQuantizer.java b/src/main/java/azgracompress/quantization/scalar/ScalarQuantizer.java index e9fb89b0b0b0723badbfb4cf4979ab317a8d70e7..480b3d99b2e4e577c3e67d470e071169a481d373 100644 --- a/src/main/java/azgracompress/quantization/scalar/ScalarQuantizer.java +++ b/src/main/java/azgracompress/quantization/scalar/ScalarQuantizer.java @@ -102,18 +102,4 @@ public class ScalarQuantizer { public SQCodebook getCodebook() { return codebook; } - - public long[] calculateFrequencies(int[] trainData) { - long[] frequencies = new long[codebook.getCodebookSize()]; - - // Speedup maybe? 
- for (final int value : trainData) { - for (int intervalId = 1; intervalId <= codebook.getCodebookSize(); intervalId++) { - if ((value >= boundaryPoints[intervalId - 1]) && (value <= boundaryPoints[intervalId])) { - ++frequencies[intervalId - 1]; - } - } - } - return frequencies; - } } diff --git a/src/main/java/azgracompress/quantization/vector/LBGVectorQuantizer.java b/src/main/java/azgracompress/quantization/vector/LBGVectorQuantizer.java index a248ec5278c3d1a9d5a650cfc471b4ad165f78c2..070d5d97868b25eb10f95016d19ff13b16170ca2 100644 --- a/src/main/java/azgracompress/quantization/vector/LBGVectorQuantizer.java +++ b/src/main/java/azgracompress/quantization/vector/LBGVectorQuantizer.java @@ -45,6 +45,7 @@ public class LBGVectorQuantizer { this.codebookSize = codebookSize; this.workerCount = workerCount; + frequencies = new long[this.codebookSize]; findUniqueVectors(); } @@ -167,8 +168,7 @@ public class LBGVectorQuantizer { } private synchronized void addWorkerFrequencies(final long[] workerFrequencies) { - assert (frequencies.length == workerFrequencies.length) : "Frequency array length mismatch."; - for (int i = 0; i < frequencies.length; i++) { + for (int i = 0; i < workerFrequencies.length; i++) { frequencies[i] += workerFrequencies[i]; } } @@ -193,30 +193,30 @@ public class LBGVectorQuantizer { final int toIndex = (wId == workerCount - 1) ? 
trainingVectors.length : (workSize + (wId * workSize)); workers[wId] = new Thread(() -> { + long[] workerFrequencies = new long[codebook.length]; VectorQuantizer quantizer = new VectorQuantizer(new VQCodebook(vectorDimensions, codebook, frequencies)); - double threadMse = 0.0; - int cnt = 0; int[] vector; - int[] quantizedVector; + int qIndex; + int[] qVector; for (int i = fromIndex; i < toIndex; i++) { - ++cnt; vector = trainingVectors[i].getVector(); - quantizedVector = quantizer.quantize(vector); + qIndex = quantizer.quantizeToIndex(vector); + ++workerFrequencies[qIndex]; + qVector = quantizer.getCodebookVectors()[qIndex].getVector(); for (int vI = 0; vI < vectorSize; vI++) { - threadMse += Math.pow(((double) vector[vI] - (double) quantizedVector[vI]), 2); + threadMse += Math.pow(((double) vector[vI] - (double) qVector[vI]), 2); } } - assert (cnt == toIndex - fromIndex); threadMse /= (double) (toIndex - fromIndex); // Update global mse, updateMse function is synchronized. updateMse(threadMse); - addWorkerFrequencies(quantizer.getFrequencies()); + addWorkerFrequencies(workerFrequencies); }); workers[wId].start(); @@ -234,11 +234,14 @@ public class LBGVectorQuantizer { VectorQuantizer quantizer = new VectorQuantizer(new VQCodebook(vectorDimensions, codebook, frequencies)); + int qIndex; + int[] qVector; for (final TrainingVector trV : trainingVectors) { - int[] quantizedV = quantizer.quantize(trV.getVector()); - + qIndex = quantizer.quantizeToIndex(trV.getVector()); + qVector = quantizer.getCodebookVectors()[qIndex].getVector(); + ++frequencies[qIndex]; for (int i = 0; i < vectorSize; i++) { - mse += Math.pow(((double) trV.getVector()[i] - (double) quantizedV[i]), 2); + mse += Math.pow(((double) trV.getVector()[i] - (double) qVector[i]), 2); } } mse /= (double) trainingVectors.length; diff --git a/src/main/java/azgracompress/quantization/vector/VQCodebook.java b/src/main/java/azgracompress/quantization/vector/VQCodebook.java index 
25548891ba0e4ebaafcffdbf01d47f3dfd50388c..c2a53f64643a45d33cb568980c6a9f5deaf41037 100644 --- a/src/main/java/azgracompress/quantization/vector/VQCodebook.java +++ b/src/main/java/azgracompress/quantization/vector/VQCodebook.java @@ -27,7 +27,7 @@ public class VQCodebook { private final V3i vectorDims; public VQCodebook(final V3i vectorDims, final CodebookEntry[] vectors, final long[] vectorFrequencies) { - assert (vectors.length == vectorFrequencies.length); + //assert (vectors.length == vectorFrequencies.length); this.vectorDims = vectorDims; this.vectors = vectors; this.vectorFrequencies = vectorFrequencies; diff --git a/src/main/java/azgracompress/quantization/vector/VectorQuantizer.java b/src/main/java/azgracompress/quantization/vector/VectorQuantizer.java index afbb6b64994f89c97509bc47c10e0e7dc6bfb5b8..94a44241eaa50c7b1cc184ffa7272b024bc11efa 100644 --- a/src/main/java/azgracompress/quantization/vector/VectorQuantizer.java +++ b/src/main/java/azgracompress/quantization/vector/VectorQuantizer.java @@ -1,21 +1,18 @@ package azgracompress.quantization.vector; -import java.util.Arrays; - public class VectorQuantizer { private final VectorDistanceMetric metric = VectorDistanceMetric.Euclidean; private final VQCodebook codebook; private final CodebookEntry[] codebookVectors; private final int vectorSize; - - private long[] frequencies; + private final long[] frequencies; public VectorQuantizer(final VQCodebook codebook) { this.codebook = codebook; this.codebookVectors = codebook.getVectors(); vectorSize = codebookVectors[0].getVector().length; - frequencies = codebook.getVectorFrequencies(); + this.frequencies = codebook.getVectorFrequencies(); } public int[] quantize(final int[] dataVector) { @@ -24,6 +21,11 @@ public class VectorQuantizer { return closestEntry.getVector(); } + public int quantizeToIndex(final int[] dataVector) { + assert (dataVector.length > 0 && dataVector.length % vectorSize == 0) : "Wrong vector size"; + return 
findClosestCodebookEntryIndex(dataVector, metric); + } + public int[][] quantize(final int[][] dataVectors, final int workerCount) { assert (dataVectors.length > 0 && dataVectors[0].length % vectorSize == 0) : "Wrong vector size"; int[][] result = new int[dataVectors.length][vectorSize]; @@ -47,25 +49,16 @@ public class VectorQuantizer { return quantizeIntoIndices(dataVectors, 1); } - private synchronized void addWorkerFrequencies(final long[] workerFrequencies) { - assert (frequencies.length == workerFrequencies.length) : "Frequency array length mismatch."; - for (int i = 0; i < frequencies.length; i++) { - frequencies[i] += workerFrequencies[i]; - } - } - public int[] quantizeIntoIndices(final int[][] dataVectors, final int maxWorkerCount) { assert (dataVectors.length > 0 && dataVectors[0].length % vectorSize == 0) : "Wrong vector size"; int[] indices = new int[dataVectors.length]; - Arrays.fill(frequencies, 0); if (maxWorkerCount == 1) { int closestIndex; for (int vectorIndex = 0; vectorIndex < dataVectors.length; vectorIndex++) { closestIndex = findClosestCodebookEntryIndex(dataVectors[vectorIndex], metric); indices[vectorIndex] = closestIndex; - ++frequencies[closestIndex]; } } else { // Cap the worker count on 8 @@ -80,13 +73,10 @@ public class VectorQuantizer { workers[wId] = new Thread(() -> { int closestIndex; - long[] workerFrequencies = new long[codebookVectors.length]; for (int vectorIndex = fromIndex; vectorIndex < toIndex; vectorIndex++) { closestIndex = findClosestCodebookEntryIndex(dataVectors[vectorIndex], metric); indices[vectorIndex] = closestIndex; - ++workerFrequencies[vectorIndex]; } - addWorkerFrequencies(workerFrequencies); }); workers[wId].start(); @@ -168,44 +158,5 @@ public class VectorQuantizer { public long[] getFrequencies() { return frequencies; } - - public long[] calculateFrequencies(int[][] dataVectors, final int maxWorkerCount) { - Arrays.fill(frequencies, 0); - assert (dataVectors.length > 0 && dataVectors[0].length % vectorSize 
== 0) : "Wrong vector size"; - - if (maxWorkerCount == 1) { - for (final int[] dataVector : dataVectors) { - ++frequencies[findClosestCodebookEntryIndex(dataVector, metric)]; - } - } else { - // Cap the worker count on 8 - final int workerCount = Math.min(maxWorkerCount, 8); - Thread[] workers = new Thread[workerCount]; - final int workSize = dataVectors.length / workerCount; - - for (int wId = 0; wId < workerCount; wId++) { - final int fromIndex = wId * workSize; - final int toIndex = (wId == workerCount - 1) ? dataVectors.length : (workSize + (wId * workSize)); - - workers[wId] = new Thread(() -> { - long[] workerFrequencies = new long[codebookVectors.length]; - for (int vectorIndex = fromIndex; vectorIndex < toIndex; vectorIndex++) { - ++workerFrequencies[findClosestCodebookEntryIndex(dataVectors[vectorIndex], metric)]; - } - addWorkerFrequencies(workerFrequencies); - }); - - workers[wId].start(); - } - try { - for (int wId = 0; wId < workerCount; wId++) { - workers[wId].join(); - } - } catch (InterruptedException e) { - e.printStackTrace(); - } - } - return frequencies; - } }