diff --git a/src/main/java/azgracompress/compression/VQImageCompressor.java b/src/main/java/azgracompress/compression/VQImageCompressor.java index de00719fdb44827f1c1a77c4ebad57d7762a0698..33cadffa9d106ab1ef7a559f583a5df9a44e3ca7 100644 --- a/src/main/java/azgracompress/compression/VQImageCompressor.java +++ b/src/main/java/azgracompress/compression/VQImageCompressor.java @@ -143,7 +143,7 @@ public class VQImageCompressor extends CompressorDecompressorBase implements IIm assert (quantizer != null); Log("Compression plane..."); - final int[] indices = quantizer.quantizeIntoIndices(planeVectors); + final int[] indices = quantizer.quantizeIntoIndices(planeVectors, options.getWorkerCount()); try (OutBitStream outBitStream = new OutBitStream(compressStream, options.getBitsPerPixel(), 2048)) { outBitStream.write(indices); diff --git a/src/main/java/azgracompress/quantization/vector/VectorQuantizer.java b/src/main/java/azgracompress/quantization/vector/VectorQuantizer.java index 72b4f1e09eb33c1f7c5ccd52c994c74d98fbd766..adb9b5217244cb8a10b31b7bf0ec62b065c59c80 100644 --- a/src/main/java/azgracompress/quantization/vector/VectorQuantizer.java +++ b/src/main/java/azgracompress/quantization/vector/VectorQuantizer.java @@ -31,13 +31,47 @@ public class VectorQuantizer { } public int[] quantizeIntoIndices(final int[][] dataVectors) { + return quantizeIntoIndices(dataVectors, 1); + } + + + public int[] quantizeIntoIndices(final int[][] dataVectors, final int maxWorkerCount) { + assert (dataVectors.length > 0 && dataVectors[0].length % vectorSize == 0) : "Wrong vector size"; int[] indices = new int[dataVectors.length]; - // Speedup - for (int vectorIndex = 0; vectorIndex < dataVectors.length; vectorIndex++) { - indices[vectorIndex] = findClosestCodebookEntryIndex(dataVectors[vectorIndex], - VectorDistanceMetric.Euclidean); + if (maxWorkerCount == 1) { + for (int vectorIndex = 0; vectorIndex < dataVectors.length; vectorIndex++) { + indices[vectorIndex] = findClosestCodebookEntryIndex(dataVectors[vectorIndex], + VectorDistanceMetric.Euclidean); + } + } else { + // Cap the worker count on 8 + final int workerCount = Math.min(maxWorkerCount, 8); + Thread[] workers = new Thread[workerCount]; + final int workSize = dataVectors.length / workerCount; + + for (int wId = 0; wId < workerCount; wId++) { + final int fromIndex = wId * workSize; + final int toIndex = (wId == workerCount - 1) ? dataVectors.length : (workSize + (wId * workSize)); + + workers[wId] = new Thread(() -> { + for (int vectorIndex = fromIndex; vectorIndex < toIndex; vectorIndex++) { + indices[vectorIndex] = findClosestCodebookEntryIndex(dataVectors[vectorIndex], + VectorDistanceMetric.Euclidean); + } + }); + + workers[wId].start(); + } + try { + for (int wId = 0; wId < workerCount; wId++) { + workers[wId].join(); + } + } catch (InterruptedException e) { + e.printStackTrace(); + } + } return indices; }