Skip to content
Snippets Groups Projects
Commit 243fd105 authored by Vojtech Moravec's avatar Vojtech Moravec
Browse files

Added worker count to LBG and first parallel function.

parent 83fc59ef
No related branches found
No related tags found
No related merge requests found
...@@ -85,7 +85,7 @@ public class VectorQuantizationBenchmark extends BenchmarkBase { ...@@ -85,7 +85,7 @@ public class VectorQuantizationBenchmark extends BenchmarkBase {
} }
final int[][] refPlaneData = getPlaneVectors(plane, qVector); final int[][] refPlaneData = getPlaneVectors(plane, qVector);
LBGVectorQuantizer vqInitializer = new LBGVectorQuantizer(refPlaneData, codebookSize); LBGVectorQuantizer vqInitializer = new LBGVectorQuantizer(refPlaneData, codebookSize, 1);
final LBGResult vqResult = vqInitializer.findOptimalCodebook(); final LBGResult vqResult = vqInitializer.findOptimalCodebook();
quantizer = new VectorQuantizer(vqResult.getCodebook()); quantizer = new VectorQuantizer(vqResult.getCodebook());
System.out.println("Created reference quantizer."); System.out.println("Created reference quantizer.");
...@@ -106,7 +106,7 @@ public class VectorQuantizationBenchmark extends BenchmarkBase { ...@@ -106,7 +106,7 @@ public class VectorQuantizationBenchmark extends BenchmarkBase {
if (!hasGeneralQuantizer) { if (!hasGeneralQuantizer) {
LBGVectorQuantizer vqInitializer = new LBGVectorQuantizer(planeData, codebookSize); LBGVectorQuantizer vqInitializer = new LBGVectorQuantizer(planeData, codebookSize,1);
LBGResult vqResult = vqInitializer.findOptimalCodebook(); LBGResult vqResult = vqInitializer.findOptimalCodebook();
quantizer = new VectorQuantizer(vqResult.getCodebook()); quantizer = new VectorQuantizer(vqResult.getCodebook());
System.out.println("Created plane quantizer."); System.out.println("Created plane quantizer.");
......
...@@ -29,7 +29,7 @@ public class VQImageCompressor extends CompressorDecompressorBase implements IIm ...@@ -29,7 +29,7 @@ public class VQImageCompressor extends CompressorDecompressorBase implements IIm
* @return Trained vector quantizer with codebook of set size. * @return Trained vector quantizer with codebook of set size.
*/ */
private VectorQuantizer trainVectorQuantizerFromPlaneVectors(final int[][] planeVectors) { private VectorQuantizer trainVectorQuantizerFromPlaneVectors(final int[][] planeVectors) {
LBGVectorQuantizer vqInitializer = new LBGVectorQuantizer(planeVectors, codebookSize); LBGVectorQuantizer vqInitializer = new LBGVectorQuantizer(planeVectors, codebookSize, options.getWorkerCount());
LBGResult vqResult = vqInitializer.findOptimalCodebook(false); LBGResult vqResult = vqInitializer.findOptimalCodebook(false);
return new VectorQuantizer(vqResult.getCodebook()); return new VectorQuantizer(vqResult.getCodebook());
} }
...@@ -185,7 +185,7 @@ public class VQImageCompressor extends CompressorDecompressorBase implements IIm ...@@ -185,7 +185,7 @@ public class VQImageCompressor extends CompressorDecompressorBase implements IIm
throw new ImageCompressionException("Failed to load reference image data.", e); throw new ImageCompressionException("Failed to load reference image data.", e);
} }
} else { } else {
Log(options.hasPlaneRangeSet() ? "Loading plane range data." : "Loading all planes data."); Log(options.hasPlaneRangeSet() ? "VQ: Loading plane range data." : "VQ: Loading all planes data.");
final int[] planeIndices = getPlaneIndicesForCompression(); final int[] planeIndices = getPlaneIndicesForCompression();
final int chunkCountPerPlane = Chunk2D.calculateRequiredChunkCountPerPlane( final int chunkCountPerPlane = Chunk2D.calculateRequiredChunkCountPerPlane(
...@@ -198,7 +198,6 @@ public class VQImageCompressor extends CompressorDecompressorBase implements IIm ...@@ -198,7 +198,6 @@ public class VQImageCompressor extends CompressorDecompressorBase implements IIm
int[][] planeVectors; int[][] planeVectors;
int planeCounter = 0; int planeCounter = 0;
for (final int planeIndex : planeIndices) { for (final int planeIndex : planeIndices) {
Log("Loading plane %d vectors", planeIndex);
try { try {
planeVectors = loadPlaneQuantizationVectors(planeIndex); planeVectors = loadPlaneQuantizationVectors(planeIndex);
assert (planeVectors.length == chunkCountPerPlane) : "Wrong chunk count per plane"; assert (planeVectors.length == chunkCountPerPlane) : "Wrong chunk count per plane";
...@@ -220,7 +219,7 @@ public class VQImageCompressor extends CompressorDecompressorBase implements IIm ...@@ -220,7 +219,7 @@ public class VQImageCompressor extends CompressorDecompressorBase implements IIm
public void trainAndSaveCodebook() throws ImageCompressionException { public void trainAndSaveCodebook() throws ImageCompressionException {
final int[][] trainingData = loadConfiguredPlanesData(); final int[][] trainingData = loadConfiguredPlanesData();
LBGVectorQuantizer vqInitializer = new LBGVectorQuantizer(trainingData, codebookSize); LBGVectorQuantizer vqInitializer = new LBGVectorQuantizer(trainingData, codebookSize, options.getWorkerCount());
Log("Starting LBG optimization."); Log("Starting LBG optimization.");
LBGResult lbgResult = vqInitializer.findOptimalCodebook(options.isVerbose()); LBGResult lbgResult = vqInitializer.findOptimalCodebook(options.isVerbose());
Log("Learned the optimal codebook."); Log("Learned the optimal codebook.");
......
...@@ -122,13 +122,8 @@ public class LloydMaxU16ScalarQuantization { ...@@ -122,13 +122,8 @@ public class LloydMaxU16ScalarQuantization {
boundaryPoints, boundaryPoints,
codebookSize); codebookSize);
workers[wId] = new Thread(runnables[wId]); workers[wId] = new Thread(runnables[wId]);
}
for (int wId = 0; wId < workerCount; wId++) {
workers[wId].start(); workers[wId].start();
} }
try { try {
for (int wId = 0; wId < workerCount; wId++) { for (int wId = 0; wId < workerCount; wId++) {
workers[wId].join(); workers[wId].join();
......
...@@ -17,14 +17,16 @@ public class LBGVectorQuantizer { ...@@ -17,14 +17,16 @@ public class LBGVectorQuantizer {
private final VectorDistanceMetric metric = VectorDistanceMetric.Euclidean; private final VectorDistanceMetric metric = VectorDistanceMetric.Euclidean;
boolean verbose = false; boolean verbose = false;
private final int workerCount;
public LBGVectorQuantizer(final int[][] trainingVectors, final int codebookSize) { public LBGVectorQuantizer(final int[][] trainingVectors, final int codebookSize, final int workerCount) {
assert (trainingVectors.length > 0) : "No training vectors provided"; assert (trainingVectors.length > 0) : "No training vectors provided";
this.trainingVectors = trainingVectors; this.trainingVectors = trainingVectors;
this.vectorSize = trainingVectors[0].length; this.vectorSize = trainingVectors[0].length;
this.codebookSize = codebookSize; this.codebookSize = codebookSize;
this.workerCount = workerCount;
} }
public LBGResult findOptimalCodebook() { public LBGResult findOptimalCodebook() {
...@@ -212,7 +214,6 @@ public class LBGVectorQuantizer { ...@@ -212,7 +214,6 @@ public class LBGVectorQuantizer {
double previousDistortion = Double.POSITIVE_INFINITY; double previousDistortion = Double.POSITIVE_INFINITY;
int iteration = 1;
Stopwatch innerLoopStopwatch = new Stopwatch("LBG inner loop"); Stopwatch innerLoopStopwatch = new Stopwatch("LBG inner loop");
Stopwatch findingClosestEntryStopwatch = new Stopwatch("FindingClosestEntry"); Stopwatch findingClosestEntryStopwatch = new Stopwatch("FindingClosestEntry");
Stopwatch distCalcStopwatch = new Stopwatch("DistortionCalc"); Stopwatch distCalcStopwatch = new Stopwatch("DistortionCalc");
...@@ -222,37 +223,20 @@ public class LBGVectorQuantizer { ...@@ -222,37 +223,20 @@ public class LBGVectorQuantizer {
innerLoopStopwatch.restart(); innerLoopStopwatch.restart();
// Step 1 // Step 1
// Speedup - speed the finding of the closest codebook entry.
findingClosestEntryStopwatch.restart();
for (final int[] trainingVec : trainingVectors) {
double minDist = Double.POSITIVE_INFINITY;
LearningCodebookEntry closestEntry = null;
for (LearningCodebookEntry entry : codebook) {
double entryDistance = VectorQuantizer.distanceBetweenVectors(entry.getVector(),
trainingVec,
metric);
if (entryDistance < minDist) { findingClosestEntryStopwatch.restart();
minDist = entryDistance;
closestEntry = entry; assignVectorsToClosestEntry(codebook);
}
}
if (closestEntry != null) {
closestEntry.addTrainingVector(trainingVec, minDist);
} else {
assert (false) : "Did not found closest entry.";
System.err.println("Did not found closest entry.");
}
}
findingClosestEntryStopwatch.stop(); findingClosestEntryStopwatch.stop();
System.out.println(findingClosestEntryStopwatch); System.out.println(findingClosestEntryStopwatch);
fixEmptyStopwatch.restart(); // fixEmptyStopwatch.restart();
fixEmptyEntries(codebook, verbose); fixEmptyEntries(codebook, verbose);
fixEmptyStopwatch.stop(); // fixEmptyStopwatch.stop();
System.out.println(fixEmptyStopwatch); // System.out.println(fixEmptyStopwatch);
// Step 2 // Step 2
distCalcStopwatch.restart(); distCalcStopwatch.restart();
...@@ -285,7 +269,7 @@ public class LBGVectorQuantizer { ...@@ -285,7 +269,7 @@ public class LBGVectorQuantizer {
} }
innerLoopStopwatch.stop(); innerLoopStopwatch.stop();
System.out.println(innerLoopStopwatch); // System.out.println(innerLoopStopwatch);
System.out.println("================"); System.out.println("================");
} }
...@@ -293,6 +277,83 @@ public class LBGVectorQuantizer { ...@@ -293,6 +277,83 @@ public class LBGVectorQuantizer {
System.out.println(totalLbgFun); System.out.println(totalLbgFun);
} }
private void assignVectorsToClosestEntry(ArrayList<LearningCodebookEntry> codebook) {
if (workerCount > 1) {
Thread[] workers = new Thread[workerCount];
final int workSize = trainingVectors.length / workerCount;
for (int wId = 0; wId < workerCount; wId++) {
final int fromIndex = wId * workSize;
final int toIndex = (wId == workerCount - 1) ? trainingVectors.length : (workSize + (wId * workSize));
workers[wId] = new Thread(() -> {
double minimalDistance, entryDistance;
for (int vecIndex = fromIndex; vecIndex < toIndex; vecIndex++) {
minimalDistance = Double.POSITIVE_INFINITY;
LearningCodebookEntry closestEntry = null;
for (LearningCodebookEntry entry : codebook) {
entryDistance = VectorQuantizer.distanceBetweenVectors(entry.getVector(),
trainingVectors[vecIndex],
metric);
if (entryDistance < minimalDistance) {
minimalDistance = entryDistance;
closestEntry = entry;
}
}
if (closestEntry != null) {
closestEntry.addTrainingVector(trainingVectors[vecIndex],
minimalDistance);
} else {
assert (false) : "Did not found closest entry.";
System.err.println("Did not found closest entry.");
}
}
});
workers[wId].start();
}
try {
for (int wId = 0; wId < workerCount; wId++) {
workers[wId].join();
}
} catch (InterruptedException e) {
e.printStackTrace();
assert (false) : "Failed parallel join";
}
} else {
//////////////////////////////////////////////////////////////////////////
// Speedup - speed the finding of the closest codebook entry.
for (final int[] trainingVec : trainingVectors) {
double minDist = Double.POSITIVE_INFINITY;
LearningCodebookEntry closestEntry = null;
for (LearningCodebookEntry entry : codebook) {
double entryDistance = VectorQuantizer.distanceBetweenVectors(entry.getVector(),
trainingVec,
metric);
if (entryDistance < minDist) {
minDist = entryDistance;
closestEntry = entry;
}
}
if (closestEntry != null) {
closestEntry.addTrainingVector(trainingVec, minDist);
} else {
assert (false) : "Did not found closest entry.";
System.err.println("Did not found closest entry.");
}
}
}
}
private void fixEmptyEntries(ArrayList<LearningCodebookEntry> codebook, final boolean verbose) { private void fixEmptyEntries(ArrayList<LearningCodebookEntry> codebook, final boolean verbose) {
LearningCodebookEntry emptyEntry = null; LearningCodebookEntry emptyEntry = null;
...@@ -318,9 +379,10 @@ public class LBGVectorQuantizer { ...@@ -318,9 +379,10 @@ public class LBGVectorQuantizer {
private void fixSingleEmptyEntry(ArrayList<LearningCodebookEntry> codebook, private void fixSingleEmptyEntry(ArrayList<LearningCodebookEntry> codebook,
final LearningCodebookEntry emptyEntry, final LearningCodebookEntry emptyEntry,
final boolean verbose) { final boolean verbose) {
if (verbose) { // if (verbose) {
System.out.println("******** FOUND EMPTY ENTRY ********"); // System.out.println("******** FOUND EMPTY ENTRY ********");
} // }
// Remove empty entry from codebook. // Remove empty entry from codebook.
codebook.remove(emptyEntry); codebook.remove(emptyEntry);
......
...@@ -42,7 +42,7 @@ public class LearningCodebookEntry extends CodebookEntry { ...@@ -42,7 +42,7 @@ public class LearningCodebookEntry extends CodebookEntry {
this.trainingVectors = trainingVectors; this.trainingVectors = trainingVectors;
} }
public void addTrainingVector(final int[] trainingVec, final double vecDist) { public synchronized void addTrainingVector(final int[] trainingVec, final double vecDist) {
trainingVectors.add(trainingVec); trainingVectors.add(trainingVec);
trainingVectorsDistances.add(vecDist); trainingVectorsDistances.add(vecDist);
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment