diff --git a/src/main/java/azgracompress/cli/functions/MeasurePlaneErrorFunction.java b/src/main/java/azgracompress/cli/functions/MeasurePlaneErrorFunction.java index 2fac49bdcd488f1d9ecf5fe300f50660e3709239..003a9d141ada3433c59a7d6060b59753fbcf3719 100644 --- a/src/main/java/azgracompress/cli/functions/MeasurePlaneErrorFunction.java +++ b/src/main/java/azgracompress/cli/functions/MeasurePlaneErrorFunction.java @@ -17,85 +17,152 @@ public class MeasurePlaneErrorFunction extends CustomFunctionBase { super(options); } + private final String COMP_FILE_ch0 = "D:\\biology\\tiff_data\\benchmark\\fused_tp_10_ch_0_16bit.raw"; + private final String COMP_FILE_ch1 = "D:\\biology\\tiff_data\\benchmark\\fused_tp_10_ch_1_16bit.raw"; + @Override public boolean run() { -// if (reportPlaneDifference( -// "D:\\biology\\tiff_data\\quantized\\middle_frame\\fused_tp_10_ch_1_16bit_sq_cb256.raw", -// "D:\\biology\\tiff_data\\quantized\\middle_frame\\plane_diff_ch1\\sq_cb256_plane_log.data")) { -// return false; -// } - -// if (reportPlaneDifference( -// "D:\\biology\\tiff_data\\quantized\\middle_frame\\fused_tp_10_ch_1_16bit_vq3x3_cb128.raw", -// "D:\\biology\\tiff_data\\quantized\\middle_frame\\plane_diff_ch1\\vq3x3_cb128_plane_log.data")) { -// return false; -// } -// -// if (reportPlaneDifference( -// "D:\\biology\\tiff_data\\quantized\\middle_frame\\fused_tp_10_ch_1_16bit_vq3x3_cb64.raw", -// "D:\\biology\\tiff_data\\quantized\\middle_frame\\plane_diff_ch1\\vq3x3_cb64_plane_log.data")) { -// return false; -// } -// -// if (reportPlaneDifference( -// "D:\\biology\\tiff_data\\quantized\\middle_frame\\fused_tp_10_ch_1_16bit_vq3x3_cb32.raw", -// "D:\\biology\\tiff_data\\quantized\\middle_frame\\plane_diff_ch1\\vq3x3_cb32_plane_log.data")) { -// return false; -// } -// -// if (reportPlaneDifference( -// "D:\\biology\\tiff_data\\quantized\\middle_frame\\fused_tp_10_ch_1_16bit_vq3x3_cb16.raw", -// "D:\\biology\\tiff_data\\quantized\\middle_frame\\plane_diff_ch1\\vq3x3_cb16_plane_log.data")) { -// return false; -// } -// -// if (reportPlaneDifference( -// "D:\\biology\\tiff_data\\quantized\\middle_frame\\fused_tp_10_ch_1_16bit_vq3x3_cb8.raw", -// "D:\\biology\\tiff_data\\quantized\\middle_frame\\plane_diff_ch1\\vq3x3_cb8_plane_log.data")) { -// return false; -// } -// + final int channel = 0; + assert (channel == 0 || channel == 1); + final String comp_file = channel == 0 ? COMP_FILE_ch0 : COMP_FILE_ch1; + final String method = "sq"; + final String type = "plane_codebook"; + final String folder = "D:\\biology\\tiff_data\\quantized"; + + if (reportPlaneDifference( + String.format("%s\\%s\\fused_tp_10_ch_%d_16bit_%s_cb256.raw", folder, type, channel, method), + String.format("%s\\%s\\plane_diff_ch%d\\%s_cb256_plane_log.data", folder, type, channel, method), + comp_file)) { + return false; + } + + if (reportPlaneDifference( + String.format("%s\\%s\\fused_tp_10_ch_%d_16bit_%s_cb128.raw", folder, type, channel, method), + String.format("%s\\%s\\plane_diff_ch%d\\%s_cb128_plane_log.data", folder, type, channel, method), + comp_file)) { + return false; + } + if (reportPlaneDifference( - "D:\\biology\\tiff_data\\quantized\\middle_frame\\fused_tp_10_ch_1_16bit_sq_cb4.raw", - "D:\\biology\\tiff_data\\quantized\\middle_frame\\plane_diff_ch1\\sq_cb4_plane_log.data")) { + String.format("%s\\%s\\fused_tp_10_ch_%d_16bit_%s_cb64.raw", folder, type, channel, method), + String.format("%s\\%s\\plane_diff_ch%d\\%s_cb64_plane_log.data", folder, type, channel, method), + comp_file)) { return false; } + if (reportPlaneDifference( + String.format("%s\\%s\\fused_tp_10_ch_%d_16bit_%s_cb32.raw", folder, type, channel, method), + String.format("%s\\%s\\plane_diff_ch%d\\%s_cb32_plane_log.data", folder, type, channel, method), + comp_file)) { + return false; + } + + if (reportPlaneDifference( + String.format("%s\\%s\\fused_tp_10_ch_%d_16bit_%s_cb16.raw", folder, type, channel, method), + String.format("%s\\%s\\plane_diff_ch%d\\%s_cb16_plane_log.data", folder, type, channel, method), + comp_file)) { + return false; + } + + if (reportPlaneDifference( + String.format("%s\\%s\\fused_tp_10_ch_%d_16bit_%s_cb8.raw", folder, type, channel, method), + String.format("%s\\%s\\plane_diff_ch%d\\%s_cb8_plane_log.data", folder, type, channel, method), + comp_file)) { + return false; + } + + if (reportPlaneDifference( + String.format("%s\\%s\\fused_tp_10_ch_%d_16bit_%s_cb4.raw", folder, type, channel, method), + String.format("%s\\%s\\plane_diff_ch%d\\%s_cb4_plane_log.data", folder, type, channel, method), + comp_file)) { + return false; + } + + // if (reportPlaneDifference( + // "D:\\biology\\tiff_data\\quantized\\plane_codebook\\fused_tp_10_ch_1_16bit_sq_cb4.raw", + // "D:\\biology\\tiff_data\\quantized\\plane_codebook\\plane_diff_ch1\\sq_cb4_plane_log.data", + // comp_file)) { + // return false; + // } + return true; } - private boolean reportPlaneDifference(final String compressedFile, final String reportFile) { - final String referenceFile = "D:\\biology\\tiff_data\\benchmark\\fused_tp_10_ch_1_16bit.raw"; - + private boolean reportPlaneDifference(final String compressedFile, final String reportFile, final String compFile) { + final String referenceFile = compFile; + final int workerCount = 8; final V3i dims = new V3i(1041, 996, 946); final int planePixelCount = dims.getX() * dims.getY(); System.out.println(options.report()); System.out.println("Run custom function."); - ImageU16 compressedPlane = null; - ImageU16 originalPlane = null; - ImageU16 differencePlane = null; + // ImageU16 compressedPlane = null; + // ImageU16 originalPlane = null; + // ImageU16 differencePlane = null; PlaneError[] planeErrors = new PlaneError[dims.getZ()]; - for (int planeIndex = 0; planeIndex < dims.getZ(); planeIndex++) { - try { - originalPlane = RawDataIO.loadImageU16(referenceFile, dims, planeIndex); - compressedPlane = RawDataIO.loadImageU16(compressedFile, dims, planeIndex); - } catch (IOException e) { - e.printStackTrace(); - return true; - } - final int[] diffData = Utils.getDifference(originalPlane.getData(), compressedPlane.getData()); - Utils.applyAbsFunction(diffData); + Thread[] workers = new Thread[workerCount]; + final int workSize = dims.getZ() / workerCount; + + for (int wId = 0; wId < workerCount; wId++) { + final int fromIndex = wId * workSize; + final int toIndex = (wId == workerCount - 1) ? dims.getZ() : (workSize + (wId * workSize)); + workers[wId] = new Thread(() -> { - final double absDiffSum = Arrays.stream(diffData).mapToDouble(v -> v).sum(); - final double meanPixelError = absDiffSum / (double) planePixelCount; + ImageU16 originalPlane, compressedPlane, differencePlane; + for (int planeIndex = fromIndex; planeIndex < toIndex; planeIndex++) { + try { + originalPlane = RawDataIO.loadImageU16(referenceFile, dims, planeIndex); + compressedPlane = RawDataIO.loadImageU16(compressedFile, dims, planeIndex); + } catch (IOException e) { + e.printStackTrace(); + break; + } - planeErrors[planeIndex] = new PlaneError(planeIndex, absDiffSum, meanPixelError); - // System.out.println("Finished plane: " + planeIndex); + + final int[] diffData = Utils.getDifference(originalPlane.getData(), compressedPlane.getData()); + Utils.applyAbsFunction(diffData); + + final double absDiffSum = Arrays.stream(diffData).mapToDouble(v -> v).sum(); + final double meanPixelError = absDiffSum / (double) planePixelCount; + + planeErrors[planeIndex] = new PlaneError(planeIndex, absDiffSum, meanPixelError); + } + }); + + workers[wId].start(); } + try { + for (int wId = 0; wId < workerCount; wId++) { + workers[wId].join(); + } + } catch (InterruptedException e) { + e.printStackTrace(); + } + + + // for (int planeIndex = 0; planeIndex < dims.getZ(); planeIndex++) { + // try { + // originalPlane = RawDataIO.loadImageU16(referenceFile, dims, planeIndex); + // compressedPlane = RawDataIO.loadImageU16(compressedFile, dims, planeIndex); + // } catch (IOException e) { + // e.printStackTrace(); + // return true; + // } + // + // + // final int[] diffData = Utils.getDifference(originalPlane.getData(), compressedPlane.getData()); + // Utils.applyAbsFunction(diffData); + // + // + // final double absDiffSum = Arrays.stream(diffData).mapToDouble(v -> v).sum(); + // final double meanPixelError = absDiffSum / (double) planePixelCount; + // + // planeErrors[planeIndex] = new PlaneError(planeIndex, absDiffSum, meanPixelError); + // // System.out.println("Finished plane: " + planeIndex); + // } try (FileOutputStream fos = new FileOutputStream(reportFile, false); OutputStreamWriter writer = new OutputStreamWriter(fos)) { diff --git a/src/main/java/azgracompress/quantization/scalar/LloydMaxU16ScalarQuantization.java b/src/main/java/azgracompress/quantization/scalar/LloydMaxU16ScalarQuantization.java index dc463c38a3f4d999cf272f8b680635ff9a4dd6cd..681823e09d7334190814598b9bb2e485434d6287 100644 --- a/src/main/java/azgracompress/quantization/scalar/LloydMaxU16ScalarQuantization.java +++ b/src/main/java/azgracompress/quantization/scalar/LloydMaxU16ScalarQuantization.java @@ -2,6 +2,7 @@ package azgracompress.quantization.scalar; import azgracompress.U16; import azgracompress.quantization.QTrainIteration; +import azgracompress.utilities.MinMaxResult; import azgracompress.utilities.Stopwatch; import azgracompress.utilities.Utils; @@ -12,6 +13,9 @@ public class LloydMaxU16ScalarQuantization { private final int[] trainingData; private int codebookSize; + private int dataMin; + private int dataMax; + private int dataSpan; private int[] centroids; private int[] boundaryPoints; private double[] pdf; @@ -32,13 +36,18 @@ public class LloydMaxU16ScalarQuantization { private void initialize() { centroids = new int[codebookSize]; - centroids[0] = 0; + boundaryPoints = new int[codebookSize + 1]; - boundaryPoints[0] = U16.Min; - boundaryPoints[codebookSize] = U16.Max; + MinMaxResult<Integer> minMax = Utils.getMinAndMax(trainingData); + dataMin = minMax.getMin(); + dataMax = minMax.getMax(); + dataSpan = dataMax - dataMin; + centroids[0] = dataMin; - double intervalSize = (double) (U16.Max - U16.Min) / (double) codebookSize; + boundaryPoints[0] = dataMin; + boundaryPoints[codebookSize] = dataMax; + double intervalSize = (double) (dataSpan) / (double) codebookSize; for (int i = 0; i < codebookSize; i++) { centroids[i] = (int) Math.floor(((double) i + 0.5) * intervalSize); } @@ -49,9 +58,11 @@ public class LloydMaxU16ScalarQuantization { // Speedup - for now it is fast enough Stopwatch s = new Stopwatch(); s.start(); + for (int i = 0; i < trainingData.length; i++) { - pdf[trainingData[i]] += 1; + pdf[trainingData[i]] += 1.0; } + s.stop(); if (verbose) { System.out.println("Init_PDF: " + s.getElapsedTimeString()); @@ -60,7 +71,24 @@ public class LloydMaxU16ScalarQuantization { private void recalculateBoundaryPoints() { for (int j = 1; j < codebookSize; j++) { - boundaryPoints[j] = (centroids[j] + centroids[j - 1]) / 2; + boundaryPoints[j] = Math.min(dataMax, + (int) Math.floor(((double) centroids[j] + (double) centroids[j - 1]) / 2.0)); + } + } + + private void initializeCentroids() { + int lowerBound, upperBound; + double[] centroidPdf = new double[codebookSize]; + for (int centroidIndex = 0; centroidIndex < codebookSize; centroidIndex++) { + lowerBound = boundaryPoints[centroidIndex]; + upperBound = boundaryPoints[centroidIndex + 1]; + + for (int rangeValue = lowerBound; rangeValue <= upperBound; rangeValue++) { + if (pdf[rangeValue] > centroidPdf[centroidIndex]) { + centroidPdf[centroidIndex] = pdf[rangeValue]; + centroids[centroidIndex] = rangeValue; + } + } } } @@ -70,13 +98,13 @@ public class LloydMaxU16ScalarQuantization { int lowerBound, upperBound; - for (int j = 0; j < codebookSize; j++) { + for (int centroidIndex = 0; centroidIndex < codebookSize; centroidIndex++) { numerator = 0.0; denominator = 0.0; - lowerBound = boundaryPoints[j]; - upperBound = boundaryPoints[j + 1]; + lowerBound = boundaryPoints[centroidIndex]; + upperBound = boundaryPoints[centroidIndex + 1]; for (int n = lowerBound; n <= upperBound; n++) { numerator += (double) n * pdf[n]; @@ -84,7 +112,7 @@ public class LloydMaxU16ScalarQuantization { } if (denominator > 0) { - centroids[j] = (int) Math.floor(numerator / denominator); + centroids[centroidIndex] = (int) Math.floor(numerator / denominator); } } } @@ -98,6 +126,16 @@ public class LloydMaxU16ScalarQuantization { throw new RuntimeException("Value couldn't be quantized!"); } + private double calculateMAE() { + double mae = 0.0; + for (final int trainingDatum : trainingData) { + int quantizedValue = quantize(trainingDatum); + mae += Math.abs((double) trainingDatum - (double) quantizedValue); + } + return (mae / (double) trainingData.length); + } + + private double getCurrentMse() { double mse = 0.0; @@ -149,6 +187,8 @@ public class LloydMaxU16ScalarQuantization { public QTrainIteration[] train(final boolean shouldBeVerbose) { this.verbose = shouldBeVerbose; final int RECALCULATE_N_TIMES = 10; + final int PATIENCE = 1; + int noImprovementCounter = 0; if (verbose) { System.out.println("Training data count: " + trainingData.length); } @@ -156,6 +196,8 @@ public class LloydMaxU16ScalarQuantization { initialize(); initializeProbabilityDensityFunction(); + + double currMAE = 1.0; double prevMse = 1.0; double currentMse = 1.0; double psnr; @@ -163,7 +205,8 @@ public class LloydMaxU16ScalarQuantization { ArrayList<QTrainIteration> solutionHistory = new ArrayList<>(); recalculateBoundaryPoints(); - recalculateCentroids(); + // recalculateCentroids(); + initializeCentroids(); currentMse = getCurrentMse(); psnr = Utils.calculatePsnr(currentMse, U16.Max); @@ -174,27 +217,59 @@ public class LloydMaxU16ScalarQuantization { solutionHistory.add(new QTrainIteration(0, currentMse, currentMse, psnr, psnr)); - double dist = 1; + double mseImprovement = 1; int iteration = 0; do { for (int i = 0; i < RECALCULATE_N_TIMES; i++) { recalculateBoundaryPoints(); recalculateCentroids(); } + + // TODO(Moravec): Check if we are improving MSE. + // Save the best centroids, the lowest MSE. + + currMAE = calculateMAE(); + prevMse = currentMse; currentMse = getCurrentMse(); + mseImprovement = prevMse - currentMse; + + // System.out.println(String.format("Improvement: %.4f", mseImprovement)); + + // if ((prevMAE < currMAE) && (iteration != 0)) { + // System.err.println(String.format( + // "MAE = +%.5f", + // currMAE - prevMAE)); + // } + psnr = Utils.calculatePsnr(currentMse, U16.Max); solutionHistory.add(new QTrainIteration(++iteration, currentMse, currentMse, psnr, psnr)); - dist = (prevMse - currentMse) / currentMse; + // dist = (prevMse - currentMse) / currentMse; if (verbose) { - System.out.println(String.format("Current MSE: %.4f PSNR: %.4f dB", currentMse, psnr)); + System.out.println(String.format("Current MAE: %.4f MSE: %.4f PSNR: %.4f dB", + currMAE, + currentMse, + psnr)); } - } while (dist > 0.001); //0.0005 + // if (mseImprovement < 1.0 && mseImprovement > 0.0005) { + // System.out.println("----- low improvement " + mseImprovement); + // } + + if (mseImprovement < 1.0) { + if ((++noImprovementCounter) >= PATIENCE) { + break; + } + + } + + + } while (true); //0.001 //0.0005// || currMAE > 1500//mseImprovement > 0.0005 if (verbose) { System.out.println("\nFinished training."); } + System.out.println(String.format("Final MAE: %.4f after %d iterations", currMAE, iteration)); return solutionHistory.toArray(new QTrainIteration[0]); } diff --git a/src/main/java/azgracompress/quantization/vector/LBGVectorQuantizer.java b/src/main/java/azgracompress/quantization/vector/LBGVectorQuantizer.java index 283947c9a73d588459f3e89b74f9bd51ba6359ca..4adf3f0599be3ecae5b00ee40f5cb7fa19084ea3 100644 --- a/src/main/java/azgracompress/quantization/vector/LBGVectorQuantizer.java +++ b/src/main/java/azgracompress/quantization/vector/LBGVectorQuantizer.java @@ -24,12 +24,12 @@ public class LBGVectorQuantizer { assert (vectors.length > 0) : "No training vectors provided"; this.vectorSize = vectors[0].length; - final int[][] vectorsCopy = new int[vectors.length][vectorSize]; - System.arraycopy(vectors, 0, vectorsCopy, 0, vectors.length); +// final int[][] vectorsCopy = new int[vectors.length][vectorSize]; +// System.arraycopy(vectors, 0, vectorsCopy, 0, vectors.length); this.trainingVectors = new TrainingVector[vectors.length]; - for (int i = 0; i < vectorsCopy.length; i++) { - trainingVectors[i] = new TrainingVector(vectorsCopy[i]); + for (int i = 0; i < vectors.length; i++) { + trainingVectors[i] = new TrainingVector(Arrays.copyOf(vectors[i],vectors[i].length)); } this.codebookSize = codebookSize; diff --git a/src/main/java/azgracompress/utilities/MinMaxResult.java b/src/main/java/azgracompress/utilities/MinMaxResult.java new file mode 100644 index 0000000000000000000000000000000000000000..5a37e757d13c921fe9b301b9e9f8189742bd46f4 --- /dev/null +++ b/src/main/java/azgracompress/utilities/MinMaxResult.java @@ -0,0 +1,19 @@ +package azgracompress.utilities; + +public class MinMaxResult<T> { + private final T min; + private final T max; + + MinMaxResult(T min, T max) { + this.min = min; + this.max = max; + } + + public T getMin() { + return min; + } + + public T getMax() { + return max; + } +} \ No newline at end of file diff --git a/src/main/java/azgracompress/utilities/Utils.java b/src/main/java/azgracompress/utilities/Utils.java index af619162e4b6903bc1136377492bd21b5993509a..a249eea5893dd0609e88cb4d0395b92677567c69 100644 --- a/src/main/java/azgracompress/utilities/Utils.java +++ b/src/main/java/azgracompress/utilities/Utils.java @@ -5,6 +5,7 @@ import java.io.FileNotFoundException; import java.io.IOException; import java.util.ArrayList; + public class Utils { public static double calculatePsnr(final double mse, final int signalMax) { @@ -83,6 +84,22 @@ public class Utils { } + public static MinMaxResult<Integer> getMinAndMax(final int[] data) { + int min = Integer.MAX_VALUE; + int max = Integer.MIN_VALUE; + + for (int i = 0; i < data.length; i++) { + if (data[i] < min) { + min = data[i]; + } + if (data[i] > max) { + max = data[i]; + } + } + return new MinMaxResult<Integer>(min, max); + } + + public static double calculateMse(final int[] difference) { double sum = 0.0; for (final int val : difference) {