Skip to content
Snippets Groups Projects
Commit bd3bc91c authored by Vojtech Moravec's avatar Vojtech Moravec
Browse files

Found the slow places in LBG.

parent 459e85c5
No related branches found
No related tags found
No related merge requests found
......@@ -46,6 +46,12 @@ public abstract class CompressorDecompressorBase {
}
}
protected void Log(final String format, final Object... args) {
if (options.isVerbose()) {
System.out.println(String.format(format, args));
}
}
protected void DebugLog(final String message) {
System.out.println(message);
}
......
......@@ -4,9 +4,9 @@ import azgracompress.cli.ParsedCliOptions;
import azgracompress.compression.exception.ImageCompressionException;
import azgracompress.data.Chunk2D;
import azgracompress.data.ImageU16;
import azgracompress.data.V2i;
import azgracompress.io.OutBitStream;
import azgracompress.io.RawDataIO;
import azgracompress.quantization.QuantizationValueCache;
import azgracompress.quantization.vector.CodebookEntry;
import azgracompress.quantization.vector.LBGResult;
import azgracompress.quantization.vector.LBGVectorQuantizer;
......@@ -29,15 +29,7 @@ public class VQImageCompressor extends CompressorDecompressorBase implements IIm
* @return Image vectors.
*/
private int[][] getPlaneVectors(final ImageU16 plane) {
final V2i qVector = options.getVectorDimension();
if (qVector.getY() > 1) {
// 2D Quantization, return `matrices`.
return Chunk2D.chunksAsImageVectors(plane.as2dChunk().divideIntoChunks(qVector));
} else {
// 1D Quantization, return row vectors.
return plane.as2dChunk().divideInto1DVectors(qVector.getX());
}
return plane.toQuantizationVectors(options.getVectorDimension());
}
/**
......@@ -147,9 +139,83 @@ public class VQImageCompressor extends CompressorDecompressorBase implements IIm
}
}
/**
* Load plane and convert the plane into quantization vectors.
*
* @param planeIndex Zero based plane index.
* @return Quantization vectors of configured quantization.
* @throws IOException When reading fails.
*/
private int[][] loadPlaneQuantizationVectors(final int planeIndex) throws IOException {
ImageU16 refPlane = RawDataIO.loadImageU16(options.getInputFile(),
options.getImageDimension(),
planeIndex);
return refPlane.toQuantizationVectors(options.getVectorDimension());
}
private int[][] loadConfiguredPlanesData() throws ImageCompressionException {
final int vectorSize = options.getVectorDimension().getX() * options.getVectorDimension().getY();
int[][] trainData = null;
Stopwatch s = new Stopwatch();
s.start();
if (options.hasPlaneIndexSet()) {
Log("VQ: Loading single plane data.");
try {
trainData = loadPlaneQuantizationVectors(options.getPlaneIndex());
} catch (IOException e) {
throw new ImageCompressionException("Failed to load reference image data.", e);
}
} else {
Log(options.hasPlaneRangeSet() ? "Loading plane range data." : "Loading all planes data.");
final int[] planeIndices = getPlaneIndicesForCompression();
final int chunkCountPerPlane = Chunk2D.calculateRequiredChunkCountPerPlane(
options.getImageDimension().toV2i(),
options.getVectorDimension());
final int totalChunkCount = chunkCountPerPlane * planeIndices.length;
trainData = new int[totalChunkCount][vectorSize];
int[][] planeVectors;
int planeCounter = 0;
for (final int planeIndex : planeIndices) {
Log("Loading plane %d vectors", planeIndex);
try {
planeVectors = loadPlaneQuantizationVectors(planeIndex);
assert (planeVectors.length == chunkCountPerPlane) : "Wrong chunk count per plane";
} catch (IOException e) {
throw new ImageCompressionException(String.format("Failed to load plane %d image data.",
planeIndex), e);
}
System.arraycopy(planeVectors, 0, trainData, (planeCounter * chunkCountPerPlane), chunkCountPerPlane);
++planeCounter;
}
}
s.stop();
Log("Quantization vector load took: " + s.getElapsedTimeString());
return trainData;
}
@Override
public void trainAndSaveCodebook() throws ImageCompressionException {
throw new ImageCompressionException("Not implemented yet");
final int[][] trainingData = loadConfiguredPlanesData();
LBGVectorQuantizer vqInitializer = new LBGVectorQuantizer(trainingData, codebookSize);
Log("Starting LBG optimization.");
LBGResult lbgResult = vqInitializer.findOptimalCodebook(options.isVerbose());
Log("Learned the optimal codebook.");
Log("Saving cache file to %s", options.getOutputFile());
QuantizationValueCache cache = new QuantizationValueCache(options.getOutputFile());
try {
cache.saveQuantizationValues(options.getInputFile(), lbgResult.getCodebook());
} catch (IOException e) {
throw new ImageCompressionException("Unable to write cache.", e);
}
Log("Operation completed.");
}
......
......@@ -74,7 +74,7 @@ public class Chunk2D {
final int xSize = dims.getX();
final int ySize = dims.getY();
final int chunkCount = getRequiredChunkCount(chunkDims);
final int chunkCount = calculateRequiredChunkCountPerPlane(chunkDims);
Chunk2D[] chunks = new Chunk2D[chunkCount];
int chunkIndex = 0;
......@@ -95,7 +95,7 @@ public class Chunk2D {
final int chunkYSize = qVectorDims.getY();
final int chunkSize = chunkXSize * chunkYSize;
final int chunkCount = getRequiredChunkCount(qVectorDims);
final int chunkCount = calculateRequiredChunkCountPerPlane(qVectorDims);
int[][] vectors = new int[chunkCount][chunkSize];
int vecIndex = 0;
......@@ -108,9 +108,13 @@ public class Chunk2D {
return vectors;
}
private int getRequiredChunkCount(final V2i chunkDims) {
final int xChunkCount = (int) Math.ceil(dims.getX() / (double) chunkDims.getX());
final int yChunkCount = (int) Math.ceil(dims.getY() / (double) chunkDims.getY());
private int calculateRequiredChunkCountPerPlane(final V2i chunkDims) {
return calculateRequiredChunkCountPerPlane(dims, chunkDims);
}
public static int calculateRequiredChunkCountPerPlane(final V2i imageDims, final V2i chunkDims) {
final int xChunkCount = (int) Math.ceil(imageDims.getX() / (double) chunkDims.getX());
final int yChunkCount = (int) Math.ceil(imageDims.getY() / (double) chunkDims.getY());
return (xChunkCount * yChunkCount);
}
......@@ -119,7 +123,7 @@ public class Chunk2D {
assert (chunks.length > 0) : "No chunks in reconstruct";
final V2i chunkDims = chunks[0].getDims();
assert (getRequiredChunkCount(chunkDims) == chunks.length) : "Wrong chunk count in reconstruct";
assert (calculateRequiredChunkCountPerPlane(chunkDims) == chunks.length) : "Wrong chunk count in reconstruct";
for (final Chunk2D chunk : chunks) {
copyFromChunk(chunk);
}
......
......@@ -54,4 +54,13 @@ public class V3i {
public V3l toV3l() {
return new V3l(x, y, z);
}
/**
* Convert this vector to V2i by dropping the Z value.
*
* @return V2i vector with X and Y values.
*/
public V2i toV2i() {
return new V2i(x, y);
}
}
package azgracompress.quantization.vector;
import azgracompress.U16;
import azgracompress.utilities.Stopwatch;
import azgracompress.utilities.Utils;
import java.util.ArrayList;
......@@ -15,6 +16,8 @@ public class LBGVectorQuantizer {
private final int[][] trainingVectors;
private final VectorDistanceMetric metric = VectorDistanceMetric.Euclidean;
boolean verbose = false;
public LBGVectorQuantizer(final int[][] trainingVectors, final int codebookSize) {
assert (trainingVectors.length > 0) : "No training vectors provided";
......@@ -28,12 +31,13 @@ public class LBGVectorQuantizer {
return findOptimalCodebook(true);
}
public LBGResult findOptimalCodebook(boolean verbose) {
ArrayList<LearningCodebookEntry> codebook = initializeCodebook(verbose);
public LBGResult findOptimalCodebook(boolean isVerbose) {
this.verbose = isVerbose;
ArrayList<LearningCodebookEntry> codebook = initializeCodebook();
if (verbose) {
System.out.println("Got initial codebook. Improving codebook...");
}
LBG(codebook, EPSILON * 0.01, verbose);
LBG(codebook, EPSILON * 0.01);
final double finalMse = averageMse(codebook);
final double psnr = Utils.calculatePsnr(finalMse, U16.Max);
if (verbose) {
......@@ -100,7 +104,7 @@ public class LBGVectorQuantizer {
}
private ArrayList<LearningCodebookEntry> initializeCodebook(final boolean verbose) {
private ArrayList<LearningCodebookEntry> initializeCodebook() {
ArrayList<LearningCodebookEntry> codebook = new ArrayList<>(codebookSize);
// Initialize first codebook entry as average of training vectors
int k = 1;
......@@ -182,7 +186,7 @@ public class LBGVectorQuantizer {
if (verbose) {
System.out.println("Improving current codebook...");
}
LBG(codebook, verbose);
LBG(codebook);
if (verbose) {
......@@ -194,11 +198,13 @@ public class LBGVectorQuantizer {
}
private void LBG(ArrayList<LearningCodebookEntry> codebook, final boolean verbose) {
LBG(codebook, EPSILON, verbose);
private void LBG(ArrayList<LearningCodebookEntry> codebook) {
LBG(codebook, EPSILON);
}
private void LBG(ArrayList<LearningCodebookEntry> codebook, final double epsilon, final boolean verbose) {
private void LBG(ArrayList<LearningCodebookEntry> codebook, final double epsilon) {
Stopwatch totalLbgFun = Stopwatch.startNew("Whole LBG function");
codebook.forEach(entry -> {
entry.clearTrainingData();
assert (entry.getTrainingVectors().size() == 0) : "Using entries which are not cleared.";
......@@ -207,9 +213,17 @@ public class LBGVectorQuantizer {
double previousDistortion = Double.POSITIVE_INFINITY;
int iteration = 1;
Stopwatch innerLoopStopwatch = new Stopwatch("LBG inner loop");
Stopwatch findingClosestEntryStopwatch = new Stopwatch("FindingClosestEntry");
Stopwatch distCalcStopwatch = new Stopwatch("DistortionCalc");
Stopwatch fixEmptyStopwatch = new Stopwatch("FixEmpty");
while (true) {
System.out.println("================");
innerLoopStopwatch.restart();
// Step 1
// Speedup - speed the finding of the closest codebook entry.
findingClosestEntryStopwatch.restart();
for (final int[] trainingVec : trainingVectors) {
double minDist = Double.POSITIVE_INFINITY;
LearningCodebookEntry closestEntry = null;
......@@ -232,20 +246,29 @@ public class LBGVectorQuantizer {
System.err.println("Did not found closest entry.");
}
}
findingClosestEntryStopwatch.stop();
System.out.println(findingClosestEntryStopwatch);
fixEmptyStopwatch.restart();
fixEmptyEntries(codebook, verbose);
fixEmptyStopwatch.stop();
System.out.println(fixEmptyStopwatch);
// Step 2
distCalcStopwatch.restart();
double avgDistortion = 0;
for (LearningCodebookEntry entry : codebook) {
avgDistortion += entry.getAverageDistortion();
}
avgDistortion /= (double) codebook.size();
distCalcStopwatch.stop();
System.out.println(distCalcStopwatch);
// Step 3
double dist = (previousDistortion - avgDistortion) / avgDistortion;
if (verbose) {
System.out.println(String.format("It: %d Distortion: %.5f", iteration++, dist));
// System.out.println(String.format("It: %d Distortion: %.5f", iteration++, dist));
}
if (dist < epsilon) {
......@@ -260,7 +283,14 @@ public class LBGVectorQuantizer {
entry.clearTrainingData();
}
}
innerLoopStopwatch.stop();
System.out.println(innerLoopStopwatch);
System.out.println("================");
}
totalLbgFun.stop();
System.out.println(totalLbgFun);
}
......@@ -324,6 +354,7 @@ public class LBGVectorQuantizer {
biggestPartition.clearTrainingData();
newEntry.clearTrainingData();
// Speedup - speed the look for closest entry.
for (final int[] trVec : partitionVectors) {
double originalPartitionDist = VectorQuantizer.distanceBetweenVectors(biggestPartition.getVector(),
trVec,
......
package azgracompress.utilities;
import org.jetbrains.annotations.NotNull;
import java.time.Duration;
import java.time.Instant;
public class Stopwatch {
private final String name;
private Instant start;
private Instant end;
Duration elapsed;
@NotNull
public static Stopwatch startNew(final String name) {
Stopwatch stopwatch = new Stopwatch(name);
stopwatch.start();
return stopwatch;
}
@NotNull
public static Stopwatch startNew() {
return startNew(null);
}
public Stopwatch(final String name) {
this.name = name;
}
public Stopwatch() {
name = null;
}
public void start() {
......@@ -41,12 +61,20 @@ public class Stopwatch {
if (elapsed == null) {
return "No time measured yet.";
}
return String.format("%dH %dmin %dsec %dms %dns", elapsed.toHoursPart(), elapsed.toMinutesPart(), elapsed.toSecondsPart(), elapsed.toMillisPart(), elapsed.toNanosPart());
return String.format("%dH %dmin %dsec %dms %dns",
elapsed.toHoursPart(),
elapsed.toMinutesPart(),
elapsed.toSecondsPart(),
elapsed.toMillisPart(),
elapsed.toNanosPart());
}
@Override
public String toString() {
if (name != null) {
return name + ": " + getElapsedTimeString();
}
return getElapsedTimeString();
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment