diff --git a/src/main/java/azgracompress/quantization/vector/VectorQuantizer.java b/src/main/java/azgracompress/quantization/vector/VectorQuantizer.java index d00696ea747d5405ee2e1826a95c3a345b1395f9..a0e043ad6bc22f742180348f67d6c6fe0773265c 100644 --- a/src/main/java/azgracompress/quantization/vector/VectorQuantizer.java +++ b/src/main/java/azgracompress/quantization/vector/VectorQuantizer.java @@ -1,5 +1,7 @@ package azgracompress.quantization.vector; +import azgracompress.kdtree.KDTree; +import azgracompress.kdtree.KDTreeBuilder; import azgracompress.utilities.Utils; public class VectorQuantizer { @@ -9,10 +11,14 @@ public class VectorQuantizer { private final int vectorSize; private final long[] frequencies; + private final KDTree kdTree; + public VectorQuantizer(final VQCodebook codebook) { this.codebookVectors = codebook.getVectors(); - vectorSize = codebookVectors[0].getVector().length; + this.vectorSize = codebookVectors[0].getVector().length; this.frequencies = codebook.getVectorFrequencies(); + + kdTree = new KDTreeBuilder(this.vectorSize, 8).buildTree(codebook.getRawVectors()); } public int[] quantize(final int[] dataVector) { @@ -49,8 +55,18 @@ public class VectorQuantizer { return quantizeIntoIndices(dataVectors, 1); } - public int[] quantizeIntoIndices(final int[][] dataVectors, final int maxWorkerCount) { + public int[] quantizeIntoIndicesUsingKDTree(final int[][] dataVectors, final int maxWorkerCount) { + assert (dataVectors.length > 0 && dataVectors[0].length % vectorSize == 0) : "Wrong vector size"; + int[] indices = new int[dataVectors.length]; + + for (int vectorIndex = 0; vectorIndex < dataVectors.length; vectorIndex++) { + indices[vectorIndex] = kdTree.findNearestBBF(dataVectors[vectorIndex], 32); + } + return indices; + } + + public int[] quantizeIntoIndices(final int[][] dataVectors, final int maxWorkerCount) { assert (dataVectors.length > 0 && dataVectors[0].length % vectorSize == 0) : "Wrong vector size"; int[] indices = new int[dataVectors.length];