From bd3bc91c471972e97a1815cf1ed925015a95198e Mon Sep 17 00:00:00 2001
From: Vojtech Moravec <vojtech.moravec.st@vsb.cz>
Date: Thu, 23 Jan 2020 16:42:18 +0100
Subject: [PATCH] Found the slow places in LBG.

---
 .../CompressorDecompressorBase.java           |  6 ++
 .../compression/VQImageCompressor.java        | 88 ++++++++++++++++---
 src/main/java/azgracompress/data/Chunk2D.java | 44 +++++-----
 src/main/java/azgracompress/data/V3i.java     |  9 ++
 .../vector/LBGVectorQuantizer.java            | 53 ++++++++---
 .../azgracompress/utilities/Stopwatch.java    | 30 ++++++-
 6 files changed, 187 insertions(+), 43 deletions(-)

diff --git a/src/main/java/azgracompress/compression/CompressorDecompressorBase.java b/src/main/java/azgracompress/compression/CompressorDecompressorBase.java
index 1dd3e9b..8caad0b 100644
--- a/src/main/java/azgracompress/compression/CompressorDecompressorBase.java
+++ b/src/main/java/azgracompress/compression/CompressorDecompressorBase.java
@@ -46,6 +46,12 @@ public abstract class CompressorDecompressorBase {
         }
     }
 
+    /** Prints the formatted message to standard output, but only when the verbose option is set. */
+    protected void Log(final String format, final Object... args) {
+        if (!options.isVerbose()) { return; }
+        System.out.println(String.format(format, args));
+    }
+
     protected void DebugLog(final String message) {
         System.out.println(message);
     }
diff --git a/src/main/java/azgracompress/compression/VQImageCompressor.java b/src/main/java/azgracompress/compression/VQImageCompressor.java
index 3bc8ee5..3ad1972 100644
--- a/src/main/java/azgracompress/compression/VQImageCompressor.java
+++ b/src/main/java/azgracompress/compression/VQImageCompressor.java
@@ -4,9 +4,9 @@ import azgracompress.cli.ParsedCliOptions;
 import azgracompress.compression.exception.ImageCompressionException;
 import azgracompress.data.Chunk2D;
 import azgracompress.data.ImageU16;
-import azgracompress.data.V2i;
 import azgracompress.io.OutBitStream;
 import azgracompress.io.RawDataIO;
+import azgracompress.quantization.QuantizationValueCache;
 import azgracompress.quantization.vector.CodebookEntry;
 import azgracompress.quantization.vector.LBGResult;
 import azgracompress.quantization.vector.LBGVectorQuantizer;
@@ -29,15 +29,7 @@ public class VQImageCompressor extends CompressorDecompressorBase implements IIm
      * @return Image vectors.
      */
     private int[][] getPlaneVectors(final ImageU16 plane) {
-        final V2i qVector = options.getVectorDimension();
-
-        if (qVector.getY() > 1) {
-            // 2D Quantization, return `matrices`.
-            return Chunk2D.chunksAsImageVectors(plane.as2dChunk().divideIntoChunks(qVector));
-        } else {
-            // 1D Quantization, return row vectors.
-            return plane.as2dChunk().divideInto1DVectors(qVector.getX());
-        }
+        return plane.toQuantizationVectors(options.getVectorDimension());
     }
 
     /**
@@ -147,9 +139,83 @@ public class VQImageCompressor extends CompressorDecompressorBase implements IIm
         }
     }
 
+    /**
+     * Read one plane through RawDataIO and split it into quantization vectors.
+     *
+     * @param planeIndex Zero based index of the plane to read.
+     * @return Vectors of the plane, produced for the configured vector dimension.
+     * @throws IOException When the plane cannot be read.
+     */
+    private int[][] loadPlaneQuantizationVectors(final int planeIndex) throws IOException {
+        final ImageU16 plane = RawDataIO.loadImageU16(options.getInputFile(),
+                                                      options.getImageDimension(),
+                                                      planeIndex);
+        final int[][] planeVectors = plane.toQuantizationVectors(options.getVectorDimension());
+        return planeVectors;
+    }
+
+    /** Loads quantization vectors of the configured plane, or of every plane selected for compression. */
+    private int[][] loadConfiguredPlanesData() throws ImageCompressionException {
+        int[][] trainData = null;
+        Stopwatch s = new Stopwatch();
+        s.start();
+        if (options.hasPlaneIndexSet()) {
+            Log("VQ: Loading single plane data.");
+            try {
+                trainData = loadPlaneQuantizationVectors(options.getPlaneIndex());
+            } catch (IOException e) {
+                throw new ImageCompressionException("Failed to load reference image data.", e);
+            }
+        } else {
+            Log(options.hasPlaneRangeSet() ? "Loading plane range data." : "Loading all planes data.");
+            final int[] planeIndices = getPlaneIndicesForCompression();
+            final int chunkCountPerPlane = Chunk2D.calculateRequiredChunkCountPerPlane(
+                    options.getImageDimension().toV2i(),
+                    options.getVectorDimension());
+            final int totalChunkCount = chunkCountPerPlane * planeIndices.length;
+
+            // The rows are copied in by reference below, so only the outer array is allocated;
+            // pre-allocating every inner vector would just create garbage that is overwritten.
+            trainData = new int[totalChunkCount][];
+            int planeCounter = 0;
+            for (final int planeIndex : planeIndices) {
+                Log("Loading plane %d vectors", planeIndex);
+                int[][] planeVectors;
+                try {
+                    planeVectors = loadPlaneQuantizationVectors(planeIndex);
+                    assert (planeVectors.length == chunkCountPerPlane) : "Wrong chunk count per plane";
+                } catch (IOException e) {
+                    throw new ImageCompressionException(String.format("Failed to load plane %d image data.",
+                                                                      planeIndex), e);
+                }
+
+                System.arraycopy(planeVectors, 0, trainData, (planeCounter * chunkCountPerPlane), chunkCountPerPlane);
+                ++planeCounter;
+            }
+        }
+        s.stop();
+        Log("Quantization vector load took: %s", s.getElapsedTimeString());
+        return trainData;
+    }
+
     @Override
     public void trainAndSaveCodebook() throws ImageCompressionException {
-        throw new ImageCompressionException("Not implemented yet");
+        final int[][] trainingData = loadConfiguredPlanesData();
+
+        // Run the LBG quantizer over the loaded vectors to obtain the codebook.
+        final LBGVectorQuantizer quantizer = new LBGVectorQuantizer(trainingData, codebookSize);
+        Log("Starting LBG optimization.");
+        final LBGResult trained = quantizer.findOptimalCodebook(options.isVerbose());
+        Log("Learned the optimal codebook.");
+
+        Log("Saving cache file to %s", options.getOutputFile());
+        final QuantizationValueCache valueCache = new QuantizationValueCache(options.getOutputFile());
+        try {
+            valueCache.saveQuantizationValues(options.getInputFile(), trained.getCodebook());
+        } catch (IOException e) {
+            throw new ImageCompressionException("Unable to write cache.", e);
+        }
+        Log("Operation completed.");
     }
 
 
diff --git a/src/main/java/azgracompress/data/Chunk2D.java b/src/main/java/azgracompress/data/Chunk2D.java
index d6dd97a..858c66b 100644
--- a/src/main/java/azgracompress/data/Chunk2D.java
+++ b/src/main/java/azgracompress/data/Chunk2D.java
@@ -74,7 +74,7 @@ public class Chunk2D {
         final int xSize = dims.getX();
         final int ySize = dims.getY();
 
-        final int chunkCount = getRequiredChunkCount(chunkDims);
+        final int chunkCount = calculateRequiredChunkCountPerPlane(chunkDims);
 
         Chunk2D[] chunks = new Chunk2D[chunkCount];
         int chunkIndex = 0;
@@ -95,7 +95,7 @@ public class Chunk2D {
         final int chunkYSize = qVectorDims.getY();
         final int chunkSize = chunkXSize * chunkYSize;
 
-        final int chunkCount = getRequiredChunkCount(qVectorDims);
+        final int chunkCount = calculateRequiredChunkCountPerPlane(qVectorDims);
 
         int[][] vectors = new int[chunkCount][chunkSize];
         int vecIndex = 0;
@@ -108,9 +108,13 @@ public class Chunk2D {
         return vectors;
     }
 
-    private int getRequiredChunkCount(final V2i chunkDims) {
-        final int xChunkCount = (int) Math.ceil(dims.getX() / (double) chunkDims.getX());
-        final int yChunkCount = (int) Math.ceil(dims.getY() / (double) chunkDims.getY());
+    private int calculateRequiredChunkCountPerPlane(final V2i chunkDims) {
+        return calculateRequiredChunkCountPerPlane(dims, chunkDims);
+    }
+
+    public static int calculateRequiredChunkCountPerPlane(final V2i imageDims, final V2i chunkDims) {
+        final int xChunkCount = (int) Math.ceil(imageDims.getX() / (double) chunkDims.getX());
+        final int yChunkCount = (int) Math.ceil(imageDims.getY() / (double) chunkDims.getY());
 
         return (xChunkCount * yChunkCount);
     }
@@ -119,7 +123,7 @@ public class Chunk2D {
         assert (chunks.length > 0) : "No chunks in reconstruct";
         final V2i chunkDims = chunks[0].getDims();
 
-        assert (getRequiredChunkCount(chunkDims) == chunks.length) : "Wrong chunk count in reconstruct";
+        assert (calculateRequiredChunkCountPerPlane(chunkDims) == chunks.length) : "Wrong chunk count in reconstruct";
         for (final Chunk2D chunk : chunks) {
             copyFromChunk(chunk);
         }
@@ -265,20 +269,20 @@ public class Chunk2D {
         data = newData;
     }
 
-//    public static int[][] chunksAsImageVectors(final Chunk2D[] chunks) {
-//        if (chunks.length == 0) {
-//            return new int[0][0];
-//        }
-//        final int vectorCount = chunks.length;
-//        final int vectorSize = chunks[0].data.length;
-//        int[][] imageVectors = new int[vectorCount][vectorSize];
-//
-//        for (int i = 0; i < vectorCount; i++) {
-//            assert (chunks[i].data.length == vectorSize);
-//            System.arraycopy(chunks[i].data, 0, imageVectors[i], 0, vectorSize);
-//        }
-//        return imageVectors;
-//    }
+    //    public static int[][] chunksAsImageVectors(final Chunk2D[] chunks) {
+    //        if (chunks.length == 0) {
+    //            return new int[0][0];
+    //        }
+    //        final int vectorCount = chunks.length;
+    //        final int vectorSize = chunks[0].data.length;
+    //        int[][] imageVectors = new int[vectorCount][vectorSize];
+    //
+    //        for (int i = 0; i < vectorCount; i++) {
+    //            assert (chunks[i].data.length == vectorSize);
+    //            System.arraycopy(chunks[i].data, 0, imageVectors[i], 0, vectorSize);
+    //        }
+    //        return imageVectors;
+    //    }
 
     public static void updateChunkData(Chunk2D[] chunks, final int[][] newData) {
         assert (chunks.length == newData.length) : "chunks len newData len mismatch.";
diff --git a/src/main/java/azgracompress/data/V3i.java b/src/main/java/azgracompress/data/V3i.java
index 5c343be..db37d74 100644
--- a/src/main/java/azgracompress/data/V3i.java
+++ b/src/main/java/azgracompress/data/V3i.java
@@ -54,4 +54,13 @@ public class V3i {
     public V3l toV3l() {
         return new V3l(x, y, z);
     }
+
+    /**
+     * Create a 2D vector from this vector's X and Y components; the Z component is discarded.
+     *
+     * @return New V2i holding the X and Y values of this vector.
+     */
+    public V2i toV2i() {
+        return new V2i(x, y);
+    }
 }
diff --git a/src/main/java/azgracompress/quantization/vector/LBGVectorQuantizer.java b/src/main/java/azgracompress/quantization/vector/LBGVectorQuantizer.java
index 0c056f4..d851c6e 100644
--- a/src/main/java/azgracompress/quantization/vector/LBGVectorQuantizer.java
+++ b/src/main/java/azgracompress/quantization/vector/LBGVectorQuantizer.java
@@ -1,6 +1,7 @@
 package azgracompress.quantization.vector;
 
 import azgracompress.U16;
+import azgracompress.utilities.Stopwatch;
 import azgracompress.utilities.Utils;
 
 import java.util.ArrayList;
@@ -15,6 +16,8 @@ public class LBGVectorQuantizer {
     private final int[][] trainingVectors;
     private final VectorDistanceMetric metric = VectorDistanceMetric.Euclidean;
 
+    boolean verbose = false;
+
     public LBGVectorQuantizer(final int[][] trainingVectors, final int codebookSize) {
 
         assert (trainingVectors.length > 0) : "No training vectors provided";
@@ -28,12 +31,13 @@ public class LBGVectorQuantizer {
         return findOptimalCodebook(true);
     }
 
-    public LBGResult findOptimalCodebook(boolean verbose) {
-        ArrayList<LearningCodebookEntry> codebook = initializeCodebook(verbose);
+    public LBGResult findOptimalCodebook(boolean isVerbose) {
+        this.verbose = isVerbose;
+        ArrayList<LearningCodebookEntry> codebook = initializeCodebook();
         if (verbose) {
             System.out.println("Got initial codebook. Improving codebook...");
         }
-        LBG(codebook, EPSILON * 0.01, verbose);
+        LBG(codebook, EPSILON * 0.01);
         final double finalMse = averageMse(codebook);
         final double psnr = Utils.calculatePsnr(finalMse, U16.Max);
         if (verbose) {
@@ -100,7 +104,7 @@ public class LBGVectorQuantizer {
     }
 
 
-    private ArrayList<LearningCodebookEntry> initializeCodebook(final boolean verbose) {
+    private ArrayList<LearningCodebookEntry> initializeCodebook() {
         ArrayList<LearningCodebookEntry> codebook = new ArrayList<>(codebookSize);
         // Initialize first codebook entry as average of training vectors
         int k = 1;
@@ -182,7 +186,7 @@ public class LBGVectorQuantizer {
             if (verbose) {
                 System.out.println("Improving current codebook...");
             }
-            LBG(codebook, verbose);
+            LBG(codebook);
 
 
             if (verbose) {
@@ -194,11 +198,13 @@ public class LBGVectorQuantizer {
     }
 
 
-    private void LBG(ArrayList<LearningCodebookEntry> codebook, final boolean verbose) {
-        LBG(codebook, EPSILON, verbose);
+    private void LBG(ArrayList<LearningCodebookEntry> codebook) {
+        LBG(codebook, EPSILON);
     }
 
-    private void LBG(ArrayList<LearningCodebookEntry> codebook, final double epsilon, final boolean verbose) {
+    private void LBG(ArrayList<LearningCodebookEntry> codebook, final double epsilon) {
+        Stopwatch totalLbgFun = Stopwatch.startNew("Whole LBG function");
+
         codebook.forEach(entry -> {
             entry.clearTrainingData();
             assert (entry.getTrainingVectors().size() == 0) : "Using entries which are not cleared.";
@@ -207,9 +213,17 @@ public class LBGVectorQuantizer {
         double previousDistortion = Double.POSITIVE_INFINITY;
 
         int iteration = 1;
+        Stopwatch innerLoopStopwatch = new Stopwatch("LBG inner loop");
+        Stopwatch findingClosestEntryStopwatch = new Stopwatch("FindingClosestEntry");
+        Stopwatch distCalcStopwatch = new Stopwatch("DistortionCalc");
+        Stopwatch fixEmptyStopwatch = new Stopwatch("FixEmpty");
         while (true) {
+            System.out.println("================");
+            innerLoopStopwatch.restart();
 
             // Step 1
+            // Speedup - speed up the search for the closest codebook entry.
+            findingClosestEntryStopwatch.restart();
             for (final int[] trainingVec : trainingVectors) {
                 double minDist = Double.POSITIVE_INFINITY;
                 LearningCodebookEntry closestEntry = null;
@@ -232,20 +246,29 @@ public class LBGVectorQuantizer {
                     System.err.println("Did not found closest entry.");
                 }
             }
+            findingClosestEntryStopwatch.stop();
+            System.out.println(findingClosestEntryStopwatch);
 
+            fixEmptyStopwatch.restart();
             fixEmptyEntries(codebook, verbose);
+            fixEmptyStopwatch.stop();
+            System.out.println(fixEmptyStopwatch);
 
             // Step 2
+            distCalcStopwatch.restart();
             double avgDistortion = 0;
             for (LearningCodebookEntry entry : codebook) {
                 avgDistortion += entry.getAverageDistortion();
             }
             avgDistortion /= (double) codebook.size();
+            distCalcStopwatch.stop();
+
+            System.out.println(distCalcStopwatch);
 
             // Step 3
             double dist = (previousDistortion - avgDistortion) / avgDistortion;
             if (verbose) {
-                System.out.println(String.format("It: %d Distortion: %.5f", iteration++, dist));
+                //                System.out.println(String.format("It: %d Distortion: %.5f", iteration++, dist));
             }
 
             if (dist < epsilon) {
@@ -260,7 +283,14 @@ public class LBGVectorQuantizer {
                     entry.clearTrainingData();
                 }
             }
+            innerLoopStopwatch.stop();
+
+            System.out.println(innerLoopStopwatch);
+            System.out.println("================");
         }
+
+        totalLbgFun.stop();
+        System.out.println(totalLbgFun);
     }
 
 
@@ -324,6 +354,7 @@ public class LBGVectorQuantizer {
         biggestPartition.clearTrainingData();
         newEntry.clearTrainingData();
 
+        // Speedup - speed up the lookup of the closest entry.
         for (final int[] trVec : partitionVectors) {
             double originalPartitionDist = VectorQuantizer.distanceBetweenVectors(biggestPartition.getVector(),
                                                                                   trVec,
@@ -337,8 +368,8 @@ public class LBGVectorQuantizer {
             }
         }
 
-//        assert (biggestPartition.getTrainingVectors().size() > 0) : "Biggest partition is empty";
-//        assert (newEntry.getTrainingVectors().size() > 0) : "New entry is empty";
+        //        assert (biggestPartition.getTrainingVectors().size() > 0) : "Biggest partition is empty";
+        //        assert (newEntry.getTrainingVectors().size() > 0) : "New entry is empty";
     }
 
     public int getVectorSize() {
diff --git a/src/main/java/azgracompress/utilities/Stopwatch.java b/src/main/java/azgracompress/utilities/Stopwatch.java
index fc695d0..9554171 100644
--- a/src/main/java/azgracompress/utilities/Stopwatch.java
+++ b/src/main/java/azgracompress/utilities/Stopwatch.java
@@ -1,15 +1,35 @@
 package azgracompress.utilities;
 
+import org.jetbrains.annotations.NotNull;
+
 import java.time.Duration;
 import java.time.Instant;
 
 public class Stopwatch {
 
+    private final String name;
     private Instant start;
     private Instant end;
     Duration elapsed;
 
+    @NotNull
+    public static Stopwatch startNew(final String name) {
+        Stopwatch stopwatch = new Stopwatch(name);
+        stopwatch.start();
+        return stopwatch;
+    }
+
+    @NotNull
+    public static Stopwatch startNew() {
+        return startNew(null);
+    }
+
+    public Stopwatch(final String name) {
+        this.name = name;
+    }
+
     public Stopwatch() {
+        name = null;
     }
 
     public void start() {
@@ -41,12 +61,20 @@ public class Stopwatch {
         if (elapsed == null) {
             return "No time measured yet.";
         }
-        return String.format("%dH %dmin %dsec %dms %dns", elapsed.toHoursPart(), elapsed.toMinutesPart(), elapsed.toSecondsPart(), elapsed.toMillisPart(), elapsed.toNanosPart());
+        return String.format("%dH %dmin %dsec %dms %dns",
+                             elapsed.toHours(), // total hours; toHoursPart() wraps at 24 and drops whole days
+                             elapsed.toMinutesPart(),
+                             elapsed.toSecondsPart(),
+                             elapsed.toMillisPart(),
+                             elapsed.toNanosPart() % 1_000_000); // sub-millisecond remainder, not nanos-of-second
 
     }
 
     @Override
     public String toString() {
+        if (name != null) {
+            return name + ": " + getElapsedTimeString();
+        }
         return getElapsedTimeString();
     }
 }
-- 
GitLab