diff --git a/src/main/java/azgracompress/compression/VQImageCompressor.java b/src/main/java/azgracompress/compression/VQImageCompressor.java index 8b522aabf37e0e93e4a98c1023672e78196a5d87..b32a93d272039b770bbfbdc2094085df45a116b5 100644 --- a/src/main/java/azgracompress/compression/VQImageCompressor.java +++ b/src/main/java/azgracompress/compression/VQImageCompressor.java @@ -159,6 +159,9 @@ public class VQImageCompressor extends CompressorDecompressorBase implements IIm assert (quantizer != null); + // Use BestBinFirst KDTree for codebook lookup. +// final int[] indices = quantizer.quantizeIntoIndicesUsingKDTree(planeVectors, options.getWorkerCount()); + // Use BruteForce for codebook lookup. final int[] indices = quantizer.quantizeIntoIndices(planeVectors, options.getWorkerCount()); planeDataSizes[planeCounter++] = writeHuffmanEncodedIndices(compressStream, huffman, indices); diff --git a/src/main/java/azgracompress/quantization/vector/VectorQuantizer.java b/src/main/java/azgracompress/quantization/vector/VectorQuantizer.java index a0e043ad6bc22f742180348f67d6c6fe0773265c..4fcc8caa6eb4d1b9f53f6989c58c305325091627 100644 --- a/src/main/java/azgracompress/quantization/vector/VectorQuantizer.java +++ b/src/main/java/azgracompress/quantization/vector/VectorQuantizer.java @@ -6,6 +6,10 @@ import azgracompress.utilities.Utils; public class VectorQuantizer { + private interface QuantizeVectorMethod { + int call(final int[] vector); + } + private final VectorDistanceMetric metric = VectorDistanceMetric.Euclidean; private final CodebookEntry[] codebookVectors; private final int vectorSize; @@ -55,26 +59,16 @@ public class VectorQuantizer { return quantizeIntoIndices(dataVectors, 1); } - public int[] quantizeIntoIndicesUsingKDTree(final int[][] dataVectors, final int maxWorkerCount) { - assert (dataVectors.length > 0 && dataVectors[0].length % vectorSize == 0) : "Wrong vector size"; - int[] indices = new int[dataVectors.length]; - - for (int vectorIndex = 0; vectorIndex < dataVectors.length; vectorIndex++) { + private int[] quantizeIntoIndicesImpl(final int[][] dataVectors, + final int maxWorkerCount, + final QuantizeVectorMethod method) { - indices[vectorIndex] = kdTree.findNearestBBF(dataVectors[vectorIndex], 32); - } - return indices; - } - - public int[] quantizeIntoIndices(final int[][] dataVectors, final int maxWorkerCount) { - assert (dataVectors.length > 0 && dataVectors[0].length % vectorSize == 0) : "Wrong vector size"; + assert (dataVectors.length > 0 && dataVectors[0].length == vectorSize) : "Wrong vector size"; int[] indices = new int[dataVectors.length]; if (maxWorkerCount == 1) { - int closestIndex; for (int vectorIndex = 0; vectorIndex < dataVectors.length; vectorIndex++) { - closestIndex = findClosestCodebookEntryIndex(dataVectors[vectorIndex], metric); - indices[vectorIndex] = closestIndex; + indices[vectorIndex] = method.call(dataVectors[vectorIndex]); } } else { // Cap the worker count on 8 @@ -86,12 +80,9 @@ public class VectorQuantizer { final int fromIndex = wId * workSize; final int toIndex = (wId == workerCount - 1) ? dataVectors.length : (workSize + (wId * workSize)); - workers[wId] = new Thread(() -> { - int closestIndex; for (int vectorIndex = fromIndex; vectorIndex < toIndex; vectorIndex++) { - closestIndex = findClosestCodebookEntryIndex(dataVectors[vectorIndex], metric); - indices[vectorIndex] = closestIndex; + indices[vectorIndex] = method.call(dataVectors[vectorIndex]); } }); @@ -104,11 +95,23 @@ public class VectorQuantizer { } catch (InterruptedException e) { e.printStackTrace(); } - } + return indices; } + public int[] quantizeIntoIndicesUsingKDTree(final int[][] dataVectors, final int maxWorkerCount) { + + return quantizeIntoIndicesImpl(dataVectors, maxWorkerCount, (final int[] queryVector) -> + kdTree.findNearestBBF(queryVector, 8)); + } + + public int[] quantizeIntoIndices(final int[][] dataVectors, final int maxWorkerCount) { + + return quantizeIntoIndicesImpl(dataVectors, maxWorkerCount, (final int[] queryVector) -> + findClosestCodebookEntryIndex(queryVector, metric)); + } + public static double distanceBetweenVectors(final int[] originalDataVector, final int[] codebookEntry, final VectorDistanceMetric metric) {