From ce71677316437bf19ead39293538060a80750e42 Mon Sep 17 00:00:00 2001
From: Vojtech Moravec <vojtech.moravec.st@vsb.cz>
Date: Fri, 22 Nov 2019 12:10:28 +0100
Subject: [PATCH] Caching calculated quantization values

---
 src/main/java/bdv/server/BigDataServer.java | 28 ++++++++++++++++++---
 src/main/java/bdv/server/CellHandler.java   | 10 ++++----
 2 files changed, 29 insertions(+), 9 deletions(-)

diff --git a/src/main/java/bdv/server/BigDataServer.java b/src/main/java/bdv/server/BigDataServer.java
index 76ee54a..53c43a1 100644
--- a/src/main/java/bdv/server/BigDataServer.java
+++ b/src/main/java/bdv/server/BigDataServer.java
@@ -1,6 +1,9 @@
 package bdv.server;
 
+import compression.U16;
+import compression.quantization.QuantizationValueCache;
 import compression.quantization.scalar.LloydMaxU16ScalarQuantization;
+import compression.quantization.scalar.ScalarQuantizer;
 import mpicbg.spim.data.SpimDataException;
 
 import org.apache.commons.cli.*;
@@ -16,6 +19,7 @@ import org.eclipse.jetty.server.handler.StatisticsHandler;
 import org.eclipse.jetty.util.log.Log;
 import org.eclipse.jetty.util.thread.QueuedThreadPool;
 
+import java.io.File;
 import java.io.IOException;
 import java.net.InetAddress;
 import java.net.UnknownHostException;
@@ -55,7 +59,7 @@ import java.util.Map.Entry;
 public class BigDataServer {
     private static final org.eclipse.jetty.util.log.Logger LOG = Log.getLogger(BigDataServer.class);
 
-    private static LloydMaxU16ScalarQuantization quantizer;
+    private static ScalarQuantizer quantizer;
 
     static Parameters getDefaultParameters() {
         final int port = 8080;
@@ -96,8 +100,24 @@ public class BigDataServer {
         if (compParams.shouldCompressData() || compParams.renderDifference()) {
             //TODO(Moravec): Replace LloydMaxU16ScalarQuantization with some ICompressor.
 
-            quantizer = new LloydMaxU16ScalarQuantization(compParams.getTrainFile(), compParams.getBitTarget());
-            quantizer.train(true);
+            QuantizationValueCache quantizationCache = new QuantizationValueCache("D:\\tmp\\bdv_cache");
+            final int quantizationValueCount = (int) Math.pow(2, compParams.getBitTarget());
+
+            final String trainFilename = new File(compParams.getTrainFile()).getName();
+            if (quantizationCache.areQuantizationValueCached(trainFilename, quantizationValueCount)) {
+                LOG.info("Found cached quantization values...");
+                final int[] centroids = quantizationCache.readCachedValues(trainFilename, quantizationValueCount);
+                assert (centroids.length == quantizationValueCount) : "Cache is corrupted";
+                quantizer = new ScalarQuantizer(U16.Min, U16.Max, centroids);
+                LOG.info("Initialized quantizer...");
+            } else {
+                LOG.info("Calculating quantization values...");
+                LloydMaxU16ScalarQuantization lloydMax = new LloydMaxU16ScalarQuantization(compParams.getTrainFile(), compParams.getBitTarget());
+                lloydMax.train(false);
+                quantizationCache.saveQuantizationValue(trainFilename, lloydMax.getCentroids());
+                LOG.info("Saving quantization values...");
+                quantizer = new ScalarQuantizer(U16.Min, U16.Max, lloydMax.getCentroids());
+            }
         }
 
 
@@ -196,7 +216,7 @@ public class BigDataServer {
     @SuppressWarnings("static-access")
     static private Parameters processOptions(final String[] args, final Parameters defaultParameters) throws IOException {
         final String BIT_TARGET = "bits";
-        final String ENABLE_COMPRESSION = "compression";
+        final String ENABLE_COMPRESSION = "compress";
         final String ENABLE_COMPRESSION_DIFF = "diff";
         final String DUMP_FILE = "dump";
         final String TRAIN_FILE = "train";
diff --git a/src/main/java/bdv/server/CellHandler.java b/src/main/java/bdv/server/CellHandler.java
index b294558..a70d84a 100644
--- a/src/main/java/bdv/server/CellHandler.java
+++ b/src/main/java/bdv/server/CellHandler.java
@@ -10,7 +10,7 @@ import javax.imageio.ImageIO;
 import javax.servlet.http.HttpServletRequest;
 import javax.servlet.http.HttpServletResponse;
 
-import compression.quantization.scalar.LloydMaxU16ScalarQuantization;
+import compression.quantization.scalar.ScalarQuantizer;
 import compression.utilities.Utils;
 import org.eclipse.jetty.server.Request;
 import org.eclipse.jetty.server.handler.ContextHandler;
@@ -91,11 +91,11 @@ public class CellHandler extends ContextHandler {
      */
     private final String thumbnailFilename;
     final CustomCompressionParameters compressionParams;
-    private LloydMaxU16ScalarQuantization quantizer;
+    private ScalarQuantizer quantizer;
 
     public CellHandler(final String baseUrl, final String xmlFilename, final String datasetName, final String thumbnailsDirectory,
                        final CustomCompressionParameters compressionParams,
-                       final LloydMaxU16ScalarQuantization quantizer) throws SpimDataException, IOException {
+                       final ScalarQuantizer quantizer) throws SpimDataException, IOException {
 
         final XmlIoSpimDataMinimal io = new XmlIoSpimDataMinimal();
         final SpimDataMinimal spimData = io.load(xmlFilename);
@@ -172,9 +172,9 @@ public class CellHandler extends ContextHandler {
 
                 for (int i = 0; i < data.length; i++) {
                     // Original - Compressed
-                    data[i] = Utils.u16BitsToShort(data[i]-compressedData[i]);
+                    //data[i] = Utils.u16BitsToShort(data[i]-compressedData[i]);
                     // Compressed - Original
-                    //data[i] = Utils.u16BitsToShort(compressedData[i]-data[i]);
+                    data[i] = Utils.u16BitsToShort(compressedData[i]-data[i]);
                 }
 
                 //LOG.warn("Not yet implemented.");
-- 
GitLab