diff --git a/src/main/java/azgracompress/compression/CompressorDecompressorBase.java b/src/main/java/azgracompress/compression/CompressorDecompressorBase.java index 1dd3e9bde4bb1b2ba0e9405f81b1fe8d1791480a..8caad0b80551e89d8f49e48a8897ad03c48bd260 100644 --- a/src/main/java/azgracompress/compression/CompressorDecompressorBase.java +++ b/src/main/java/azgracompress/compression/CompressorDecompressorBase.java @@ -46,6 +46,12 @@ public abstract class CompressorDecompressorBase { } } + protected void Log(final String format, final Object... args) { + if (options.isVerbose()) { + System.out.println(String.format(format, args)); + } + } + protected void DebugLog(final String message) { System.out.println(message); } diff --git a/src/main/java/azgracompress/compression/VQImageCompressor.java b/src/main/java/azgracompress/compression/VQImageCompressor.java index 3bc8ee5bcfb39aaae90bb8cc17aeb1e7706b8a5e..3ad19727644fee3e66fb36a5fc890a27cea0bb9a 100644 --- a/src/main/java/azgracompress/compression/VQImageCompressor.java +++ b/src/main/java/azgracompress/compression/VQImageCompressor.java @@ -4,9 +4,9 @@ import azgracompress.cli.ParsedCliOptions; import azgracompress.compression.exception.ImageCompressionException; import azgracompress.data.Chunk2D; import azgracompress.data.ImageU16; -import azgracompress.data.V2i; import azgracompress.io.OutBitStream; import azgracompress.io.RawDataIO; +import azgracompress.quantization.QuantizationValueCache; import azgracompress.quantization.vector.CodebookEntry; import azgracompress.quantization.vector.LBGResult; import azgracompress.quantization.vector.LBGVectorQuantizer; @@ -29,15 +29,7 @@ public class VQImageCompressor extends CompressorDecompressorBase implements IIm * @return Image vectors. */ private int[][] getPlaneVectors(final ImageU16 plane) { - final V2i qVector = options.getVectorDimension(); - - if (qVector.getY() > 1) { - // 2D Quantization, return `matrices`. - return Chunk2D.chunksAsImageVectors(plane.as2dChunk().divideIntoChunks(qVector)); - } else { - // 1D Quantization, return row vectors. - return plane.as2dChunk().divideInto1DVectors(qVector.getX()); - } + return plane.toQuantizationVectors(options.getVectorDimension()); } /** @@ -147,9 +139,83 @@ public class VQImageCompressor extends CompressorDecompressorBase implements IIm } } + /** + * Load plane and convert the plane into quantization vectors. + * + * @param planeIndex Zero based plane index. + * @return Quantization vectors of configured quantization. + * @throws IOException When reading fails. + */ + private int[][] loadPlaneQuantizationVectors(final int planeIndex) throws IOException { + ImageU16 refPlane = RawDataIO.loadImageU16(options.getInputFile(), + options.getImageDimension(), + planeIndex); + + return refPlane.toQuantizationVectors(options.getVectorDimension()); + } + + private int[][] loadConfiguredPlanesData() throws ImageCompressionException { + final int vectorSize = options.getVectorDimension().getX() * options.getVectorDimension().getY(); + int[][] trainData = null; + Stopwatch s = new Stopwatch(); + s.start(); + if (options.hasPlaneIndexSet()) { + Log("VQ: Loading single plane data."); + try { + trainData = loadPlaneQuantizationVectors(options.getPlaneIndex()); + } catch (IOException e) { + throw new ImageCompressionException("Failed to load reference image data.", e); + } + } else { + Log(options.hasPlaneRangeSet() ? "Loading plane range data." : "Loading all planes data."); + final int[] planeIndices = getPlaneIndicesForCompression(); + + final int chunkCountPerPlane = Chunk2D.calculateRequiredChunkCountPerPlane( + options.getImageDimension().toV2i(), + options.getVectorDimension()); + final int totalChunkCount = chunkCountPerPlane * planeIndices.length; + + trainData = new int[totalChunkCount][vectorSize]; + + int[][] planeVectors; + int planeCounter = 0; + for (final int planeIndex : planeIndices) { + Log("Loading plane %d vectors", planeIndex); + try { + planeVectors = loadPlaneQuantizationVectors(planeIndex); + assert (planeVectors.length == chunkCountPerPlane) : "Wrong chunk count per plane"; + } catch (IOException e) { + throw new ImageCompressionException(String.format("Failed to load plane %d image data.", + planeIndex), e); + } + + System.arraycopy(planeVectors, 0, trainData, (planeCounter * chunkCountPerPlane), chunkCountPerPlane); + ++planeCounter; + } + } + s.stop(); + Log("Quantization vector load took: " + s.getElapsedTimeString()); + return trainData; + } + @Override public void trainAndSaveCodebook() throws ImageCompressionException { - throw new ImageCompressionException("Not implemented yet"); + final int[][] trainingData = loadConfiguredPlanesData(); + + LBGVectorQuantizer vqInitializer = new LBGVectorQuantizer(trainingData, codebookSize); + Log("Starting LBG optimization."); + LBGResult lbgResult = vqInitializer.findOptimalCodebook(options.isVerbose()); + Log("Learned the optimal codebook."); + + + Log("Saving cache file to %s", options.getOutputFile()); + QuantizationValueCache cache = new QuantizationValueCache(options.getOutputFile()); + try { + cache.saveQuantizationValues(options.getInputFile(), lbgResult.getCodebook()); + } catch (IOException e) { + throw new ImageCompressionException("Unable to write cache.", e); + } + Log("Operation completed."); } diff --git a/src/main/java/azgracompress/data/Chunk2D.java b/src/main/java/azgracompress/data/Chunk2D.java index d6dd97a817b4bf58a575142be93036935c1ce7bb..858c66b2fd175ce74d519a8b3f893c9b5fbfaf05 100644 --- a/src/main/java/azgracompress/data/Chunk2D.java +++ b/src/main/java/azgracompress/data/Chunk2D.java @@ -74,7 +74,7 @@ public class Chunk2D { final int xSize = dims.getX(); final int ySize = dims.getY(); - final int chunkCount = getRequiredChunkCount(chunkDims); + final int chunkCount = calculateRequiredChunkCountPerPlane(chunkDims); Chunk2D[] chunks = new Chunk2D[chunkCount]; int chunkIndex = 0; @@ -95,7 +95,7 @@ public class Chunk2D { final int chunkYSize = qVectorDims.getY(); final int chunkSize = chunkXSize * chunkYSize; - final int chunkCount = getRequiredChunkCount(qVectorDims); + final int chunkCount = calculateRequiredChunkCountPerPlane(qVectorDims); int[][] vectors = new int[chunkCount][chunkSize]; int vecIndex = 0; @@ -108,9 +108,13 @@ public class Chunk2D { return vectors; } - private int getRequiredChunkCount(final V2i chunkDims) { - final int xChunkCount = (int) Math.ceil(dims.getX() / (double) chunkDims.getX()); - final int yChunkCount = (int) Math.ceil(dims.getY() / (double) chunkDims.getY()); + private int calculateRequiredChunkCountPerPlane(final V2i chunkDims) { + return calculateRequiredChunkCountPerPlane(dims, chunkDims); + } + + public static int calculateRequiredChunkCountPerPlane(final V2i imageDims, final V2i chunkDims) { + final int xChunkCount = (int) Math.ceil(imageDims.getX() / (double) chunkDims.getX()); + final int yChunkCount = (int) Math.ceil(imageDims.getY() / (double) chunkDims.getY()); return (xChunkCount * yChunkCount); } @@ -119,7 +123,7 @@ public class Chunk2D { assert (chunks.length > 0) : "No chunks in reconstruct"; final V2i chunkDims = chunks[0].getDims(); - assert (getRequiredChunkCount(chunkDims) == chunks.length) : "Wrong chunk count in reconstruct"; + assert (calculateRequiredChunkCountPerPlane(chunkDims) == chunks.length) : "Wrong chunk count in reconstruct"; for (final Chunk2D chunk : chunks) { copyFromChunk(chunk); } @@ -265,20 +269,20 @@ public class Chunk2D { data = newData; } -// public static int[][] chunksAsImageVectors(final Chunk2D[] chunks) { -// if (chunks.length == 0) { -// return new int[0][0]; -// } -// final int vectorCount = chunks.length; -// final int vectorSize = chunks[0].data.length; -// int[][] imageVectors = new int[vectorCount][vectorSize]; -// -// for (int i = 0; i < vectorCount; i++) { -// assert (chunks[i].data.length == vectorSize); -// System.arraycopy(chunks[i].data, 0, imageVectors[i], 0, vectorSize); -// } -// return imageVectors; -// } + // public static int[][] chunksAsImageVectors(final Chunk2D[] chunks) { + // if (chunks.length == 0) { + // return new int[0][0]; + // } + // final int vectorCount = chunks.length; + // final int vectorSize = chunks[0].data.length; + // int[][] imageVectors = new int[vectorCount][vectorSize]; + // + // for (int i = 0; i < vectorCount; i++) { + // assert (chunks[i].data.length == vectorSize); + // System.arraycopy(chunks[i].data, 0, imageVectors[i], 0, vectorSize); + // } + // return imageVectors; + // } public static void updateChunkData(Chunk2D[] chunks, final int[][] newData) { assert (chunks.length == newData.length) : "chunks len newData len mismatch."; diff --git a/src/main/java/azgracompress/data/V3i.java b/src/main/java/azgracompress/data/V3i.java index 5c343be04753bd4700fd800b9c4ea9731c76ab31..db37d74d99d3ec995398639ea3431b98278158cc 100644 --- a/src/main/java/azgracompress/data/V3i.java +++ b/src/main/java/azgracompress/data/V3i.java @@ -54,4 +54,13 @@ public class V3i { public V3l toV3l() { return new V3l(x, y, z); } + + /** + * Convert this vector to V2i by dropping the Z value. + * + * @return V2i vector with X and Y values. + */ + public V2i toV2i() { + return new V2i(x, y); + } } diff --git a/src/main/java/azgracompress/quantization/vector/LBGVectorQuantizer.java b/src/main/java/azgracompress/quantization/vector/LBGVectorQuantizer.java index 0c056f41504cc7a554f8635615dee854d2bc13a0..d851c6e5e2575f2614a9e9954ad17a4adc1b55ce 100644 --- a/src/main/java/azgracompress/quantization/vector/LBGVectorQuantizer.java +++ b/src/main/java/azgracompress/quantization/vector/LBGVectorQuantizer.java @@ -1,6 +1,7 @@ package azgracompress.quantization.vector; import azgracompress.U16; +import azgracompress.utilities.Stopwatch; import azgracompress.utilities.Utils; import java.util.ArrayList; @@ -15,6 +16,8 @@ public class LBGVectorQuantizer { private final int[][] trainingVectors; private final VectorDistanceMetric metric = VectorDistanceMetric.Euclidean; + boolean verbose = false; + public LBGVectorQuantizer(final int[][] trainingVectors, final int codebookSize) { assert (trainingVectors.length > 0) : "No training vectors provided"; @@ -28,12 +31,13 @@ public class LBGVectorQuantizer { return findOptimalCodebook(true); } - public LBGResult findOptimalCodebook(boolean verbose) { - ArrayList<LearningCodebookEntry> codebook = initializeCodebook(verbose); + public LBGResult findOptimalCodebook(boolean isVerbose) { + this.verbose = isVerbose; + ArrayList<LearningCodebookEntry> codebook = initializeCodebook(); if (verbose) { System.out.println("Got initial codebook. Improving codebook..."); } - LBG(codebook, EPSILON * 0.01, verbose); + LBG(codebook, EPSILON * 0.01); final double finalMse = averageMse(codebook); final double psnr = Utils.calculatePsnr(finalMse, U16.Max); if (verbose) { @@ -100,7 +104,7 @@ public class LBGVectorQuantizer { } - private ArrayList<LearningCodebookEntry> initializeCodebook(final boolean verbose) { + private ArrayList<LearningCodebookEntry> initializeCodebook() { ArrayList<LearningCodebookEntry> codebook = new ArrayList<>(codebookSize); // Initialize first codebook entry as average of training vectors int k = 1; @@ -182,7 +186,7 @@ public class LBGVectorQuantizer { if (verbose) { System.out.println("Improving current codebook..."); } - LBG(codebook, verbose); + LBG(codebook); if (verbose) { @@ -194,11 +198,13 @@ public class LBGVectorQuantizer { } - private void LBG(ArrayList<LearningCodebookEntry> codebook, final boolean verbose) { - LBG(codebook, EPSILON, verbose); + private void LBG(ArrayList<LearningCodebookEntry> codebook) { + LBG(codebook, EPSILON); } - private void LBG(ArrayList<LearningCodebookEntry> codebook, final double epsilon, final boolean verbose) { + private void LBG(ArrayList<LearningCodebookEntry> codebook, final double epsilon) { + Stopwatch totalLbgFun = Stopwatch.startNew("Whole LBG function"); + codebook.forEach(entry -> { entry.clearTrainingData(); assert (entry.getTrainingVectors().size() == 0) : "Using entries which are not cleared."; @@ -207,9 +213,17 @@ public class LBGVectorQuantizer { double previousDistortion = Double.POSITIVE_INFINITY; int iteration = 1; + Stopwatch innerLoopStopwatch = new Stopwatch("LBG inner loop"); + Stopwatch findingClosestEntryStopwatch = new Stopwatch("FindingClosestEntry"); + Stopwatch distCalcStopwatch = new Stopwatch("DistortionCalc"); + Stopwatch fixEmptyStopwatch = new Stopwatch("FixEmpty"); while (true) { + System.out.println("================"); + innerLoopStopwatch.restart(); // Step 1 + // Speedup - speed the finding of the closest codebook entry. + findingClosestEntryStopwatch.restart(); for (final int[] trainingVec : trainingVectors) { double minDist = Double.POSITIVE_INFINITY; LearningCodebookEntry closestEntry = null; @@ -232,20 +246,29 @@ public class LBGVectorQuantizer { System.err.println("Did not found closest entry."); } } + findingClosestEntryStopwatch.stop(); + System.out.println(findingClosestEntryStopwatch); + fixEmptyStopwatch.restart(); fixEmptyEntries(codebook, verbose); + fixEmptyStopwatch.stop(); + System.out.println(fixEmptyStopwatch); // Step 2 + distCalcStopwatch.restart(); double avgDistortion = 0; for (LearningCodebookEntry entry : codebook) { avgDistortion += entry.getAverageDistortion(); } avgDistortion /= (double) codebook.size(); + distCalcStopwatch.stop(); + + System.out.println(distCalcStopwatch); // Step 3 double dist = (previousDistortion - avgDistortion) / avgDistortion; if (verbose) { - System.out.println(String.format("It: %d Distortion: %.5f", iteration++, dist)); + // System.out.println(String.format("It: %d Distortion: %.5f", iteration++, dist)); } if (dist < epsilon) { @@ -260,7 +283,14 @@ public class LBGVectorQuantizer { entry.clearTrainingData(); } } + innerLoopStopwatch.stop(); + + System.out.println(innerLoopStopwatch); + System.out.println("================"); } + + totalLbgFun.stop(); + System.out.println(totalLbgFun); } @@ -324,6 +354,7 @@ public class LBGVectorQuantizer { biggestPartition.clearTrainingData(); newEntry.clearTrainingData(); + // Speedup - speed the look for closest entry. for (final int[] trVec : partitionVectors) { double originalPartitionDist = VectorQuantizer.distanceBetweenVectors(biggestPartition.getVector(), trVec, @@ -337,8 +368,8 @@ public class LBGVectorQuantizer { } } -// assert (biggestPartition.getTrainingVectors().size() > 0) : "Biggest partition is empty"; -// assert (newEntry.getTrainingVectors().size() > 0) : "New entry is empty"; + // assert (biggestPartition.getTrainingVectors().size() > 0) : "Biggest partition is empty"; + // assert (newEntry.getTrainingVectors().size() > 0) : "New entry is empty"; } public int getVectorSize() { diff --git a/src/main/java/azgracompress/utilities/Stopwatch.java b/src/main/java/azgracompress/utilities/Stopwatch.java index fc695d084062701fa1ea2a7a83193d8f98b23219..9554171d5497185f12727cf0422493fa89c8e7ec 100644 --- a/src/main/java/azgracompress/utilities/Stopwatch.java +++ b/src/main/java/azgracompress/utilities/Stopwatch.java @@ -1,15 +1,35 @@ package azgracompress.utilities; +import org.jetbrains.annotations.NotNull; + import java.time.Duration; import java.time.Instant; public class Stopwatch { + private final String name; private Instant start; private Instant end; Duration elapsed; + @NotNull + public static Stopwatch startNew(final String name) { + Stopwatch stopwatch = new Stopwatch(name); + stopwatch.start(); + return stopwatch; + } + + @NotNull + public static Stopwatch startNew() { + return startNew(null); + } + + public Stopwatch(final String name) { + this.name = name; + } + public Stopwatch() { + name = null; } public void start() { @@ -41,12 +61,20 @@ public class Stopwatch { if (elapsed == null) { return "No time measured yet."; } - return String.format("%dH %dmin %dsec %dms %dns", elapsed.toHoursPart(), elapsed.toMinutesPart(), elapsed.toSecondsPart(), elapsed.toMillisPart(), elapsed.toNanosPart()); + return String.format("%dH %dmin %dsec %dms %dns", + elapsed.toHoursPart(), + elapsed.toMinutesPart(), + elapsed.toSecondsPart(), + elapsed.toMillisPart(), + elapsed.toNanosPart()); } @Override public String toString() { + if (name != null) { + return name + ": " + getElapsedTimeString(); + } return getElapsedTimeString(); } }