diff --git a/src/main/java/azgracompress/cli/CliConstants.java b/src/main/java/azgracompress/cli/CliConstants.java
index aa5dd83adeb0a67bc37ae3e1ea69a4c6a6c7c1f0..98f8dbbf623256397b11810c758c1399bb9621f4 100644
--- a/src/main/java/azgracompress/cli/CliConstants.java
+++ b/src/main/java/azgracompress/cli/CliConstants.java
@@ -51,7 +51,7 @@ public class CliConstants {
public static final String VECTOR_QUANTIZATION_SHORT = "vq";
public static final String VECTOR_QUANTIZATION_LONG = "vector-quantization";
- public static final String USE_MIDDLE_PLANE_SHORT = "md";
+ public static final String USE_MIDDLE_PLANE_SHORT = "mp";
public static final String USE_MIDDLE_PLANE_LONG = "middle-plane";
@NotNull
diff --git a/src/main/java/azgracompress/compression/CompressorDecompressorBase.java b/src/main/java/azgracompress/compression/CompressorDecompressorBase.java
index 63094e6bd6901fea148998cac79862eed101ce70..2c5bf47e0c65debd3453637bba06cbb9f1ef3a11 100644
--- a/src/main/java/azgracompress/compression/CompressorDecompressorBase.java
+++ b/src/main/java/azgracompress/compression/CompressorDecompressorBase.java
@@ -12,14 +12,14 @@ public abstract class CompressorDecompressorBase {
public static final String EXTENSION = ".QCMP";
protected final ParsedCliOptions options;
- protected final int codebookSize;
+ private final int codebookSize;
public CompressorDecompressorBase(ParsedCliOptions options) {
this.options = options;
this.codebookSize = (int) Math.pow(2, this.options.getBitsPerPixel());
}
- protected int[] createHuffmanSymbols() {
+ protected int[] createHuffmanSymbols(final int codebookSize) {
int[] symbols = new int[codebookSize];
for (int i = 0; i < codebookSize; i++) {
symbols[i] = i;
@@ -112,4 +112,8 @@ public abstract class CompressorDecompressorBase {
throw new ImageCompressionException("Unable to write indices to OutBitStream.", ex);
}
}
+
+ protected int getCodebookSize() {
+ return codebookSize;
+ }
}
diff --git a/src/main/java/azgracompress/compression/ImageDecompressor.java b/src/main/java/azgracompress/compression/ImageDecompressor.java
index a12f6c83919c36145872e086146cbeaedd560f2e..8c9efe11dba4469d672f334d94a5d8f85b7915f1 100644
--- a/src/main/java/azgracompress/compression/ImageDecompressor.java
+++ b/src/main/java/azgracompress/compression/ImageDecompressor.java
@@ -97,9 +97,13 @@ public class ImageDecompressor extends CompressorDecompressorBase {
break;
}
logBuilder.append("Bits per pixel:\t\t").append(header.getBitsPerPixel()).append('\n');
+
logBuilder.append("Codebook:\t\t").append(header.isCodebookPerPlane() ? "one per plane\n" : "one for " +
"all\n");
+ final int codebookSize = (int)Math.pow(2, header.getBitsPerPixel());
+ logBuilder.append("Codebook size:\t\t").append(codebookSize).append('\n');
+
logBuilder.append("Image size X:\t\t").append(header.getImageSizeX()).append('\n');
logBuilder.append("Image size Y:\t\t").append(header.getImageSizeY()).append('\n');
logBuilder.append("Image size Z:\t\t").append(header.getImageSizeZ()).append('\n');
@@ -121,6 +125,10 @@ public class ImageDecompressor extends CompressorDecompressorBase {
(fileSize / 1000),
((fileSize / 1000) / 1000)));
logBuilder.append("Data size:\t\t").append(dataSize).append(" Bytes ").append(dataSize == expectedDataSize ? "(correct)\n" : "(INVALID)\n");
+
+ final long uncompressedSize = header.getImageDims().multiplyTogether() * 2;
+ final double compressionRatio = (double)fileSize / (double)uncompressedSize;
+ logBuilder.append(String.format("Compression ratio:\t%.5f\n", compressionRatio));
}
}
diff --git a/src/main/java/azgracompress/compression/SQImageCompressor.java b/src/main/java/azgracompress/compression/SQImageCompressor.java
index 0df7772f0699fc2b4b63e243c9b0d26b84ab27b3..48654a8bc30b33e8fc6e85d768d9ef9ea06b2d47 100644
--- a/src/main/java/azgracompress/compression/SQImageCompressor.java
+++ b/src/main/java/azgracompress/compression/SQImageCompressor.java
@@ -28,8 +28,9 @@ public class SQImageCompressor extends CompressorDecompressorBase implements IIm
* @return Trained scalar quantizer.
*/
private ScalarQuantizer trainScalarQuantizerFromData(final int[] planeData) {
+
LloydMaxU16ScalarQuantization lloydMax = new LloydMaxU16ScalarQuantization(planeData,
- codebookSize,
+ getCodebookSize(),
options.getWorkerCount());
lloydMax.train(false);
return new ScalarQuantizer(U16.Min, U16.Max, lloydMax.getCodebook());
@@ -71,7 +72,7 @@ public class SQImageCompressor extends CompressorDecompressorBase implements IIm
private ScalarQuantizer loadQuantizerFromCache() throws ImageCompressionException {
QuantizationCacheManager cacheManager = new QuantizationCacheManager(options.getCodebookCacheFolder());
- final SQCodebook codebook = cacheManager.loadSQCodebook(options.getInputFile(), codebookSize);
+ final SQCodebook codebook = cacheManager.loadSQCodebook(options.getInputFile(), getCodebookSize());
if (codebook == null) {
throw new ImageCompressionException("Failed to read quantization values from cache file.");
}
@@ -89,7 +90,7 @@ public class SQImageCompressor extends CompressorDecompressorBase implements IIm
final boolean hasGeneralQuantizer = options.hasCodebookCacheFolder() || options.shouldUseMiddlePlane();
ScalarQuantizer quantizer = null;
Huffman huffman = null;
- final int[] huffmanSymbols = createHuffmanSymbols();
+ final int[] huffmanSymbols = createHuffmanSymbols(getCodebookSize());
if (options.hasCodebookCacheFolder()) {
Log("Loading codebook from cache file.");
@@ -196,7 +197,7 @@ public class SQImageCompressor extends CompressorDecompressorBase implements IIm
int[] trainData = loadConfiguredPlanesData();
LloydMaxU16ScalarQuantization lloydMax = new LloydMaxU16ScalarQuantization(trainData,
- codebookSize,
+ getCodebookSize(),
options.getWorkerCount());
Log("Starting LloydMax training.");
lloydMax.train(options.isVerbose());
diff --git a/src/main/java/azgracompress/compression/SQImageDecompressor.java b/src/main/java/azgracompress/compression/SQImageDecompressor.java
index 20214473d20739c35784a499b9e327747d4ee78f..1ef5f6ac5cdc349f0605d6977c5fbe94e77484f0 100644
--- a/src/main/java/azgracompress/compression/SQImageDecompressor.java
+++ b/src/main/java/azgracompress/compression/SQImageDecompressor.java
@@ -19,7 +19,8 @@ public class SQImageDecompressor extends CompressorDecompressorBase implements I
super(options);
}
- private SQCodebook readScalarQuantizationValues(DataInputStream compressedStream) throws ImageDecompressionException {
+ private SQCodebook readScalarQuantizationValues(DataInputStream compressedStream,
+ final int codebookSize) throws ImageDecompressionException {
int[] quantizationValues = new int[codebookSize];
long[] symbolFrequencies = new long[codebookSize];
try {
@@ -59,8 +60,8 @@ public class SQImageDecompressor extends CompressorDecompressorBase implements I
DataOutputStream decompressStream,
QCMPFileHeader header) throws ImageDecompressionException {
- final int[] huffmanSymbols = createHuffmanSymbols();
final int codebookSize = (int) Math.pow(2, header.getBitsPerPixel());
+ final int[] huffmanSymbols = createHuffmanSymbols(codebookSize);
final int planeCountForDecompression = header.getImageSizeZ();
final int planePixelCount = header.getImageSizeX() * header.getImageSizeY();
@@ -71,7 +72,7 @@ public class SQImageDecompressor extends CompressorDecompressorBase implements I
if (!header.isCodebookPerPlane()) {
// There is only one codebook.
Log("Loading single codebook and huffman coder.");
- codebook = readScalarQuantizationValues(compressedStream);
+ codebook = readScalarQuantizationValues(compressedStream, codebookSize);
huffman = createHuffmanCoder(huffmanSymbols, codebook.getSymbolFrequencies());
}
@@ -80,7 +81,7 @@ public class SQImageDecompressor extends CompressorDecompressorBase implements I
stopwatch.restart();
if (header.isCodebookPerPlane()) {
Log("Loading plane codebook...");
- codebook = readScalarQuantizationValues(compressedStream);
+ codebook = readScalarQuantizationValues(compressedStream, codebookSize);
huffman = createHuffmanCoder(huffmanSymbols, codebook.getSymbolFrequencies());
}
assert (codebook != null && huffman != null);
diff --git a/src/main/java/azgracompress/compression/VQImageCompressor.java b/src/main/java/azgracompress/compression/VQImageCompressor.java
index 718680f014ceee1b16e655a7a0f76e03f8388825..9960c8d6a6b32e609c99399f2cd9174b037e5d8d 100644
--- a/src/main/java/azgracompress/compression/VQImageCompressor.java
+++ b/src/main/java/azgracompress/compression/VQImageCompressor.java
@@ -28,7 +28,7 @@ public class VQImageCompressor extends CompressorDecompressorBase implements IIm
private VectorQuantizer trainVectorQuantizerFromPlaneVectors(final int[][] planeVectors) {
LBGVectorQuantizer vqInitializer = new LBGVectorQuantizer(planeVectors,
- codebookSize,
+ getCodebookSize(),
options.getWorkerCount(),
options.getVectorDimension().toV3i());
LBGResult vqResult = vqInitializer.findOptimalCodebook(false);
@@ -74,7 +74,7 @@ public class VQImageCompressor extends CompressorDecompressorBase implements IIm
QuantizationCacheManager cacheManager = new QuantizationCacheManager(options.getCodebookCacheFolder());
final VQCodebook codebook = cacheManager.loadVQCodebook(options.getInputFile(),
- codebookSize,
+ getCodebookSize(),
options.getVectorDimension().toV3i());
if (codebook == null) {
throw new ImageCompressionException("Failed to read quantization vectors from cache.");
@@ -94,7 +94,7 @@ public class VQImageCompressor extends CompressorDecompressorBase implements IIm
final boolean hasGeneralQuantizer = options.hasCodebookCacheFolder() || options.shouldUseMiddlePlane();
- final int[] huffmanSymbols = createHuffmanSymbols();
+ final int[] huffmanSymbols = createHuffmanSymbols(getCodebookSize());
VectorQuantizer quantizer = null;
Huffman huffman = null;
@@ -231,7 +231,7 @@ public class VQImageCompressor extends CompressorDecompressorBase implements IIm
final int[][] trainingData = loadConfiguredPlanesData();
LBGVectorQuantizer vqInitializer = new LBGVectorQuantizer(trainingData,
- codebookSize,
+ getCodebookSize(),
options.getWorkerCount(),
options.getVectorDimension().toV3i());
Log("Starting LBG optimization.");
diff --git a/src/main/java/azgracompress/compression/VQImageDecompressor.java b/src/main/java/azgracompress/compression/VQImageDecompressor.java
index 8965733987b0e87e3f1d324b4c8947c85e742014..f37b1192d15f34efe10a620ba06ffc01caf206ca 100644
--- a/src/main/java/azgracompress/compression/VQImageDecompressor.java
+++ b/src/main/java/azgracompress/compression/VQImageDecompressor.java
@@ -4,7 +4,11 @@ import azgracompress.cli.ParsedCliOptions;
import azgracompress.compression.exception.ImageDecompressionException;
import azgracompress.data.*;
import azgracompress.fileformat.QCMPFileHeader;
+import azgracompress.huffman.Huffman;
+import azgracompress.huffman.HuffmanNode;
import azgracompress.io.InBitStream;
+import azgracompress.quantization.vector.CodebookEntry;
+import azgracompress.quantization.vector.VQCodebook;
import azgracompress.utilities.Stopwatch;
import azgracompress.utilities.TypeConverter;
@@ -31,21 +35,29 @@ public class VQImageDecompressor extends CompressorDecompressorBase implements I
return (long) Math.ceil((planeVectorCount * bpp) / 8.0);
}
- private int[][] readCodebookVectors(DataInputStream compressedStream,
- final int codebookSize,
- final int vectorSize) throws ImageDecompressionException {
+ private VQCodebook readCodebook(DataInputStream compressedStream,
+ final int codebookSize,
+ final int vectorSize) throws ImageDecompressionException {
- int[][] codebook = new int[codebookSize][vectorSize];
+ final CodebookEntry[] codebookVectors = new CodebookEntry[codebookSize];
+ final long[] frequencies = new long[codebookSize];
try {
for (int codebookIndex = 0; codebookIndex < codebookSize; codebookIndex++) {
+ final int[] vector = new int[vectorSize];
for (int vecIndex = 0; vecIndex < vectorSize; vecIndex++) {
- codebook[codebookIndex][vecIndex] = compressedStream.readUnsignedShort();
+ vector[vecIndex] = compressedStream.readUnsignedShort();
}
+ codebookVectors[codebookIndex] = new CodebookEntry(vector);
+ }
+ for (int codebookIndex = 0; codebookIndex < codebookSize; codebookIndex++) {
+ frequencies[codebookIndex] = compressedStream.readLong();
}
} catch (IOException ioEx) {
throw new ImageDecompressionException("Unable to read quantization values from compressed stream.", ioEx);
}
- return codebook;
+
+ // Vector dimensions are not needed by the decompressor, so a zero V3i placeholder is used.
+ return new VQCodebook(new V3i(0), codebookVectors, frequencies);
}
@@ -73,19 +85,17 @@ public class VQImageDecompressor extends CompressorDecompressorBase implements I
final int vectorDataSize = 2 * header.getVectorSizeX() * header.getVectorSizeY() * header.getVectorSizeZ();
// Total codebook size in bytes.
- final long codebookDataSize = (codebookSize * vectorDataSize) * (header.isCodebookPerPlane() ?
- header.getImageSizeZ() : 1);
-
- // Number of vectors per plane.
- final long planeVectorCount = calculatePlaneVectorCount(header);
-
- // Data size of single plane indices.
- final long planeDataSize = calculatePlaneDataSize(planeVectorCount, header.getBitsPerPixel());
-
- // All planes data size.
- final long allPlanesDataSize = planeDataSize * header.getImageSizeZ();
+ final long codebookDataSize = (((long) codebookSize * vectorDataSize) + ((long) codebookSize * LONG_BYTES)) *
+ (header.isCodebookPerPlane() ? header.getImageSizeZ() : 1);
+
+ // Indices are encoded using huffman. Plane data size is written in the header.
+ long[] planeDataSizes = header.getPlaneDataSizes();
+ long totalPlaneDataSize = 0;
+ for (final long planeDataSize : planeDataSizes) {
+ totalPlaneDataSize += planeDataSize;
+ }
- return (codebookDataSize + allPlanesDataSize);
+ return (codebookDataSize + totalPlaneDataSize);
}
@Override
@@ -97,15 +107,18 @@ public class VQImageDecompressor extends CompressorDecompressorBase implements I
final int vectorSize = header.getVectorSizeX() * header.getVectorSizeY() * header.getVectorSizeZ();
final int planeCountForDecompression = header.getImageSizeZ();
final long planeVectorCount = calculatePlaneVectorCount(header);
- final long planeDataSize = calculatePlaneDataSize(planeVectorCount, header.getBitsPerPixel());
+ // Plane data sizes are no longer computed; they are read per-plane from the header below.
final V2i qVector = new V2i(header.getVectorSizeX(), header.getVectorSizeY());
+ final int[] huffmanSymbols = createHuffmanSymbols(codebookSize);
- int[][] quantizationVectors = null;
+ VQCodebook codebook = null;
+ Huffman huffman = null;
if (!header.isCodebookPerPlane()) {
// There is only one codebook.
Log("Loading codebook from cache...");
- quantizationVectors = readCodebookVectors(compressedStream, codebookSize, vectorSize);
+ codebook = readCodebook(compressedStream, codebookSize, vectorSize);
+ huffman = createHuffmanCoder(huffmanSymbols, codebook.getVectorFrequencies());
}
Stopwatch stopwatch = new Stopwatch();
@@ -113,25 +126,31 @@ public class VQImageDecompressor extends CompressorDecompressorBase implements I
stopwatch.restart();
if (header.isCodebookPerPlane()) {
Log("Loading plane codebook...");
- quantizationVectors = readCodebookVectors(compressedStream, codebookSize, vectorSize);
+ codebook = readCodebook(compressedStream, codebookSize, vectorSize);
+ huffman = createHuffmanCoder(huffmanSymbols, codebook.getVectorFrequencies());
}
- assert (quantizationVectors != null);
+ assert (codebook != null && huffman != null);
Log(String.format("Decompressing plane %d...", planeIndex));
byte[] decompressedPlaneData = null;
+ final int planeDataSize = (int) header.getPlaneDataSizes()[planeIndex];
try (InBitStream inBitStream = new InBitStream(compressedStream,
header.getBitsPerPixel(),
- (int) planeDataSize)) {
+ planeDataSize)) {
inBitStream.readToBuffer();
inBitStream.setAllowReadFromUnderlyingStream(false);
- final int[] indices = inBitStream.readNValues((int) planeVectorCount);
int[][] decompressedVectors = new int[(int) planeVectorCount][vectorSize];
for (int vecIndex = 0; vecIndex < planeVectorCount; vecIndex++) {
-
- System.arraycopy(quantizationVectors[indices[vecIndex]],
+ HuffmanNode currentHuffmanNode = huffman.getRoot();
+ boolean bit;
+ while (!currentHuffmanNode.isLeaf()) {
+ bit = inBitStream.readBit();
+ currentHuffmanNode = currentHuffmanNode.traverse(bit);
+ }
+ System.arraycopy(codebook.getVectors()[currentHuffmanNode.getSymbol()].getVector(),
0,
decompressedVectors[vecIndex],
0,
diff --git a/src/main/java/azgracompress/data/V3i.java b/src/main/java/azgracompress/data/V3i.java
index cd238d77b39cfe330ad92a5fee25d409f44d867c..9d09bee0445c2216a7968ae374cf18f285b98519 100644
--- a/src/main/java/azgracompress/data/V3i.java
+++ b/src/main/java/azgracompress/data/V3i.java
@@ -12,7 +12,7 @@ public class V3i {
}
public V3i(final int x, final int y) {
- this(x,y,1);
+ this(x, y, 1);
}
public V3i(final int universalValue) {
@@ -67,4 +67,8 @@ public class V3i {
public V2i toV2i() {
return new V2i(x, y);
}
+
+ public long multiplyTogether() {
+ return ((long) x * y * z);
+ }
}
diff --git a/src/main/java/azgracompress/quantization/scalar/ScalarQuantizer.java b/src/main/java/azgracompress/quantization/scalar/ScalarQuantizer.java
index e9fb89b0b0b0723badbfb4cf4979ab317a8d70e7..480b3d99b2e4e577c3e67d470e071169a481d373 100644
--- a/src/main/java/azgracompress/quantization/scalar/ScalarQuantizer.java
+++ b/src/main/java/azgracompress/quantization/scalar/ScalarQuantizer.java
@@ -102,18 +102,4 @@ public class ScalarQuantizer {
public SQCodebook getCodebook() {
return codebook;
}
-
- public long[] calculateFrequencies(int[] trainData) {
- long[] frequencies = new long[codebook.getCodebookSize()];
-
- // Speedup maybe?
- for (final int value : trainData) {
- for (int intervalId = 1; intervalId <= codebook.getCodebookSize(); intervalId++) {
- if ((value >= boundaryPoints[intervalId - 1]) && (value <= boundaryPoints[intervalId])) {
- ++frequencies[intervalId - 1];
- }
- }
- }
- return frequencies;
- }
}
diff --git a/src/main/java/azgracompress/quantization/vector/LBGVectorQuantizer.java b/src/main/java/azgracompress/quantization/vector/LBGVectorQuantizer.java
index a248ec5278c3d1a9d5a650cfc471b4ad165f78c2..070d5d97868b25eb10f95016d19ff13b16170ca2 100644
--- a/src/main/java/azgracompress/quantization/vector/LBGVectorQuantizer.java
+++ b/src/main/java/azgracompress/quantization/vector/LBGVectorQuantizer.java
@@ -45,6 +45,7 @@ public class LBGVectorQuantizer {
this.codebookSize = codebookSize;
this.workerCount = workerCount;
+ frequencies = new long[this.codebookSize];
findUniqueVectors();
}
@@ -167,8 +168,7 @@ public class LBGVectorQuantizer {
}
private synchronized void addWorkerFrequencies(final long[] workerFrequencies) {
- assert (frequencies.length == workerFrequencies.length) : "Frequency array length mismatch.";
- for (int i = 0; i < frequencies.length; i++) {
+ for (int i = 0; i < workerFrequencies.length; i++) {
frequencies[i] += workerFrequencies[i];
}
}
@@ -193,30 +193,30 @@ public class LBGVectorQuantizer {
final int toIndex = (wId == workerCount - 1) ? trainingVectors.length : (workSize + (wId * workSize));
workers[wId] = new Thread(() -> {
+ long[] workerFrequencies = new long[codebook.length];
VectorQuantizer quantizer = new VectorQuantizer(new VQCodebook(vectorDimensions,
codebook,
frequencies));
-
double threadMse = 0.0;
- int cnt = 0;
int[] vector;
- int[] quantizedVector;
+ int qIndex;
+ int[] qVector;
for (int i = fromIndex; i < toIndex; i++) {
- ++cnt;
vector = trainingVectors[i].getVector();
- quantizedVector = quantizer.quantize(vector);
+ qIndex = quantizer.quantizeToIndex(vector);
+ ++workerFrequencies[qIndex];
+ qVector = quantizer.getCodebookVectors()[qIndex].getVector();
for (int vI = 0; vI < vectorSize; vI++) {
- threadMse += Math.pow(((double) vector[vI] - (double) quantizedVector[vI]), 2);
+ threadMse += Math.pow(((double) vector[vI] - (double) qVector[vI]), 2);
}
}
- assert (cnt == toIndex - fromIndex);
threadMse /= (double) (toIndex - fromIndex);
// Update global mse, updateMse function is synchronized.
updateMse(threadMse);
- addWorkerFrequencies(quantizer.getFrequencies());
+ addWorkerFrequencies(workerFrequencies);
});
workers[wId].start();
@@ -234,11 +234,14 @@ public class LBGVectorQuantizer {
VectorQuantizer quantizer = new VectorQuantizer(new VQCodebook(vectorDimensions,
codebook,
frequencies));
+ int qIndex;
+ int[] qVector;
for (final TrainingVector trV : trainingVectors) {
- int[] quantizedV = quantizer.quantize(trV.getVector());
-
+ qIndex = quantizer.quantizeToIndex(trV.getVector());
+ qVector = quantizer.getCodebookVectors()[qIndex].getVector();
+ ++frequencies[qIndex];
for (int i = 0; i < vectorSize; i++) {
- mse += Math.pow(((double) trV.getVector()[i] - (double) quantizedV[i]), 2);
+ mse += Math.pow(((double) trV.getVector()[i] - (double) qVector[i]), 2);
}
}
mse /= (double) trainingVectors.length;
diff --git a/src/main/java/azgracompress/quantization/vector/VQCodebook.java b/src/main/java/azgracompress/quantization/vector/VQCodebook.java
index 25548891ba0e4ebaafcffdbf01d47f3dfd50388c..c2a53f64643a45d33cb568980c6a9f5deaf41037 100644
--- a/src/main/java/azgracompress/quantization/vector/VQCodebook.java
+++ b/src/main/java/azgracompress/quantization/vector/VQCodebook.java
@@ -27,7 +27,7 @@ public class VQCodebook {
private final V3i vectorDims;
public VQCodebook(final V3i vectorDims, final CodebookEntry[] vectors, final long[] vectorFrequencies) {
- assert (vectors.length == vectorFrequencies.length);
+ // NOTE(review): length assert removed — presumably because during LBG training the codebook array is smaller than the fixed-size frequencies array; confirm and re-enable if possible.
this.vectorDims = vectorDims;
this.vectors = vectors;
this.vectorFrequencies = vectorFrequencies;
diff --git a/src/main/java/azgracompress/quantization/vector/VectorQuantizer.java b/src/main/java/azgracompress/quantization/vector/VectorQuantizer.java
index afbb6b64994f89c97509bc47c10e0e7dc6bfb5b8..94a44241eaa50c7b1cc184ffa7272b024bc11efa 100644
--- a/src/main/java/azgracompress/quantization/vector/VectorQuantizer.java
+++ b/src/main/java/azgracompress/quantization/vector/VectorQuantizer.java
@@ -1,21 +1,18 @@
package azgracompress.quantization.vector;
-import java.util.Arrays;
-
public class VectorQuantizer {
private final VectorDistanceMetric metric = VectorDistanceMetric.Euclidean;
private final VQCodebook codebook;
private final CodebookEntry[] codebookVectors;
private final int vectorSize;
-
- private long[] frequencies;
+ private final long[] frequencies;
public VectorQuantizer(final VQCodebook codebook) {
this.codebook = codebook;
this.codebookVectors = codebook.getVectors();
vectorSize = codebookVectors[0].getVector().length;
- frequencies = codebook.getVectorFrequencies();
+ this.frequencies = codebook.getVectorFrequencies();
}
public int[] quantize(final int[] dataVector) {
@@ -24,6 +21,11 @@ public class VectorQuantizer {
return closestEntry.getVector();
}
+ public int quantizeToIndex(final int[] dataVector) {
+ assert (dataVector.length > 0 && dataVector.length % vectorSize == 0) : "Wrong vector size";
+ return findClosestCodebookEntryIndex(dataVector, metric);
+ }
+
public int[][] quantize(final int[][] dataVectors, final int workerCount) {
assert (dataVectors.length > 0 && dataVectors[0].length % vectorSize == 0) : "Wrong vector size";
int[][] result = new int[dataVectors.length][vectorSize];
@@ -47,25 +49,16 @@ public class VectorQuantizer {
return quantizeIntoIndices(dataVectors, 1);
}
- private synchronized void addWorkerFrequencies(final long[] workerFrequencies) {
- assert (frequencies.length == workerFrequencies.length) : "Frequency array length mismatch.";
- for (int i = 0; i < frequencies.length; i++) {
- frequencies[i] += workerFrequencies[i];
- }
- }
-
public int[] quantizeIntoIndices(final int[][] dataVectors, final int maxWorkerCount) {
assert (dataVectors.length > 0 && dataVectors[0].length % vectorSize == 0) : "Wrong vector size";
int[] indices = new int[dataVectors.length];
- Arrays.fill(frequencies, 0);
if (maxWorkerCount == 1) {
int closestIndex;
for (int vectorIndex = 0; vectorIndex < dataVectors.length; vectorIndex++) {
closestIndex = findClosestCodebookEntryIndex(dataVectors[vectorIndex], metric);
indices[vectorIndex] = closestIndex;
- ++frequencies[closestIndex];
}
} else {
// Cap the worker count on 8
@@ -80,13 +73,10 @@ public class VectorQuantizer {
workers[wId] = new Thread(() -> {
int closestIndex;
- long[] workerFrequencies = new long[codebookVectors.length];
for (int vectorIndex = fromIndex; vectorIndex < toIndex; vectorIndex++) {
closestIndex = findClosestCodebookEntryIndex(dataVectors[vectorIndex], metric);
indices[vectorIndex] = closestIndex;
- ++workerFrequencies[vectorIndex];
}
- addWorkerFrequencies(workerFrequencies);
});
workers[wId].start();
@@ -168,44 +158,5 @@ public class VectorQuantizer {
public long[] getFrequencies() {
return frequencies;
}
-
- public long[] calculateFrequencies(int[][] dataVectors, final int maxWorkerCount) {
- Arrays.fill(frequencies, 0);
- assert (dataVectors.length > 0 && dataVectors[0].length % vectorSize == 0) : "Wrong vector size";
-
- if (maxWorkerCount == 1) {
- for (final int[] dataVector : dataVectors) {
- ++frequencies[findClosestCodebookEntryIndex(dataVector, metric)];
- }
- } else {
- // Cap the worker count on 8
- final int workerCount = Math.min(maxWorkerCount, 8);
- Thread[] workers = new Thread[workerCount];
- final int workSize = dataVectors.length / workerCount;
-
- for (int wId = 0; wId < workerCount; wId++) {
- final int fromIndex = wId * workSize;
- final int toIndex = (wId == workerCount - 1) ? dataVectors.length : (workSize + (wId * workSize));
-
- workers[wId] = new Thread(() -> {
- long[] workerFrequencies = new long[codebookVectors.length];
- for (int vectorIndex = fromIndex; vectorIndex < toIndex; vectorIndex++) {
- ++workerFrequencies[findClosestCodebookEntryIndex(dataVectors[vectorIndex], metric)];
- }
- addWorkerFrequencies(workerFrequencies);
- });
-
- workers[wId].start();
- }
- try {
- for (int wId = 0; wId < workerCount; wId++) {
- workers[wId].join();
- }
- } catch (InterruptedException e) {
- e.printStackTrace();
- }
- }
- return frequencies;
- }
}