Skip to content
Snippets Groups Projects
Commit b8c0c186 authored by Vojtech Moravec's avatar Vojtech Moravec
Browse files

Merge branch 'cli_app'

parents b2fe96a2 9c7f6340
No related branches found
No related tags found
No related merge requests found
...@@ -17,85 +17,152 @@ public class MeasurePlaneErrorFunction extends CustomFunctionBase { ...@@ -17,85 +17,152 @@ public class MeasurePlaneErrorFunction extends CustomFunctionBase {
super(options); super(options);
} }
private final String COMP_FILE_ch0 = "D:\\biology\\tiff_data\\benchmark\\fused_tp_10_ch_0_16bit.raw";
private final String COMP_FILE_ch1 = "D:\\biology\\tiff_data\\benchmark\\fused_tp_10_ch_1_16bit.raw";
@Override @Override
public boolean run() { public boolean run() {
// if (reportPlaneDifference( final int channel = 0;
// "D:\\biology\\tiff_data\\quantized\\middle_frame\\fused_tp_10_ch_1_16bit_sq_cb256.raw", assert (channel == 0 || channel == 1);
// "D:\\biology\\tiff_data\\quantized\\middle_frame\\plane_diff_ch1\\sq_cb256_plane_log.data")) { final String comp_file = channel == 0 ? COMP_FILE_ch0 : COMP_FILE_ch1;
// return false; final String method = "sq";
// } final String type = "plane_codebook";
final String folder = "D:\\biology\\tiff_data\\quantized";
// if (reportPlaneDifference(
// "D:\\biology\\tiff_data\\quantized\\middle_frame\\fused_tp_10_ch_1_16bit_vq3x3_cb128.raw", if (reportPlaneDifference(
// "D:\\biology\\tiff_data\\quantized\\middle_frame\\plane_diff_ch1\\vq3x3_cb128_plane_log.data")) { String.format("%s\\%s\\fused_tp_10_ch_%d_16bit_%s_cb256.raw", folder, type, channel, method),
// return false; String.format("%s\\%s\\plane_diff_ch%d\\%s_cb256_plane_log.data", folder, type, channel, method),
// } comp_file)) {
// return false;
// if (reportPlaneDifference( }
// "D:\\biology\\tiff_data\\quantized\\middle_frame\\fused_tp_10_ch_1_16bit_vq3x3_cb64.raw",
// "D:\\biology\\tiff_data\\quantized\\middle_frame\\plane_diff_ch1\\vq3x3_cb64_plane_log.data")) { if (reportPlaneDifference(
// return false; String.format("%s\\%s\\fused_tp_10_ch_%d_16bit_%s_cb128.raw", folder, type, channel, method),
// } String.format("%s\\%s\\plane_diff_ch%d\\%s_cb128_plane_log.data", folder, type, channel, method),
// comp_file)) {
// if (reportPlaneDifference( return false;
// "D:\\biology\\tiff_data\\quantized\\middle_frame\\fused_tp_10_ch_1_16bit_vq3x3_cb32.raw", }
// "D:\\biology\\tiff_data\\quantized\\middle_frame\\plane_diff_ch1\\vq3x3_cb32_plane_log.data")) {
// return false;
// }
//
// if (reportPlaneDifference(
// "D:\\biology\\tiff_data\\quantized\\middle_frame\\fused_tp_10_ch_1_16bit_vq3x3_cb16.raw",
// "D:\\biology\\tiff_data\\quantized\\middle_frame\\plane_diff_ch1\\vq3x3_cb16_plane_log.data")) {
// return false;
// }
//
// if (reportPlaneDifference(
// "D:\\biology\\tiff_data\\quantized\\middle_frame\\fused_tp_10_ch_1_16bit_vq3x3_cb8.raw",
// "D:\\biology\\tiff_data\\quantized\\middle_frame\\plane_diff_ch1\\vq3x3_cb8_plane_log.data")) {
// return false;
// }
//
if (reportPlaneDifference( if (reportPlaneDifference(
"D:\\biology\\tiff_data\\quantized\\middle_frame\\fused_tp_10_ch_1_16bit_sq_cb4.raw", String.format("%s\\%s\\fused_tp_10_ch_%d_16bit_%s_cb64.raw", folder, type, channel, method),
"D:\\biology\\tiff_data\\quantized\\middle_frame\\plane_diff_ch1\\sq_cb4_plane_log.data")) { String.format("%s\\%s\\plane_diff_ch%d\\%s_cb64_plane_log.data", folder, type, channel, method),
comp_file)) {
return false; return false;
} }
if (reportPlaneDifference(
String.format("%s\\%s\\fused_tp_10_ch_%d_16bit_%s_cb32.raw", folder, type, channel, method),
String.format("%s\\%s\\plane_diff_ch%d\\%s_cb32_plane_log.data", folder, type, channel, method),
comp_file)) {
return false;
}
if (reportPlaneDifference(
String.format("%s\\%s\\fused_tp_10_ch_%d_16bit_%s_cb16.raw", folder, type, channel, method),
String.format("%s\\%s\\plane_diff_ch%d\\%s_cb16_plane_log.data", folder, type, channel, method),
comp_file)) {
return false;
}
if (reportPlaneDifference(
String.format("%s\\%s\\fused_tp_10_ch_%d_16bit_%s_cb8.raw", folder, type, channel, method),
String.format("%s\\%s\\plane_diff_ch%d\\%s_cb8_plane_log.data", folder, type, channel, method),
comp_file)) {
return false;
}
if (reportPlaneDifference(
String.format("%s\\%s\\fused_tp_10_ch_%d_16bit_%s_cb4.raw", folder, type, channel, method),
String.format("%s\\%s\\plane_diff_ch%d\\%s_cb4_plane_log.data", folder, type, channel, method),
comp_file)) {
return false;
}
// if (reportPlaneDifference(
// "D:\\biology\\tiff_data\\quantized\\plane_codebook\\fused_tp_10_ch_1_16bit_sq_cb4.raw",
// "D:\\biology\\tiff_data\\quantized\\plane_codebook\\plane_diff_ch1\\sq_cb4_plane_log.data",
// comp_file)) {
// return false;
// }
return true; return true;
} }
private boolean reportPlaneDifference(final String compressedFile, final String reportFile) { private boolean reportPlaneDifference(final String compressedFile, final String reportFile, final String compFile) {
final String referenceFile = "D:\\biology\\tiff_data\\benchmark\\fused_tp_10_ch_1_16bit.raw"; final String referenceFile = compFile;
final int workerCount = 8;
final V3i dims = new V3i(1041, 996, 946); final V3i dims = new V3i(1041, 996, 946);
final int planePixelCount = dims.getX() * dims.getY(); final int planePixelCount = dims.getX() * dims.getY();
System.out.println(options.report()); System.out.println(options.report());
System.out.println("Run custom function."); System.out.println("Run custom function.");
ImageU16 compressedPlane = null; // ImageU16 compressedPlane = null;
ImageU16 originalPlane = null; // ImageU16 originalPlane = null;
ImageU16 differencePlane = null; // ImageU16 differencePlane = null;
PlaneError[] planeErrors = new PlaneError[dims.getZ()]; PlaneError[] planeErrors = new PlaneError[dims.getZ()];
for (int planeIndex = 0; planeIndex < dims.getZ(); planeIndex++) { Thread[] workers = new Thread[workerCount];
try { final int workSize = dims.getZ() / workerCount;
originalPlane = RawDataIO.loadImageU16(referenceFile, dims, planeIndex);
compressedPlane = RawDataIO.loadImageU16(compressedFile, dims, planeIndex); for (int wId = 0; wId < workerCount; wId++) {
} catch (IOException e) { final int fromIndex = wId * workSize;
e.printStackTrace(); final int toIndex = (wId == workerCount - 1) ? dims.getZ() : (workSize + (wId * workSize));
return true;
}
final int[] diffData = Utils.getDifference(originalPlane.getData(), compressedPlane.getData());
Utils.applyAbsFunction(diffData);
workers[wId] = new Thread(() -> {
final double absDiffSum = Arrays.stream(diffData).mapToDouble(v -> v).sum(); ImageU16 originalPlane, compressedPlane, differencePlane;
final double meanPixelError = absDiffSum / (double) planePixelCount; for (int planeIndex = fromIndex; planeIndex < toIndex; planeIndex++) {
try {
originalPlane = RawDataIO.loadImageU16(referenceFile, dims, planeIndex);
compressedPlane = RawDataIO.loadImageU16(compressedFile, dims, planeIndex);
} catch (IOException e) {
e.printStackTrace();
break;
}
planeErrors[planeIndex] = new PlaneError(planeIndex, absDiffSum, meanPixelError);
// System.out.println("Finished plane: " + planeIndex); final int[] diffData = Utils.getDifference(originalPlane.getData(), compressedPlane.getData());
Utils.applyAbsFunction(diffData);
final double absDiffSum = Arrays.stream(diffData).mapToDouble(v -> v).sum();
final double meanPixelError = absDiffSum / (double) planePixelCount;
planeErrors[planeIndex] = new PlaneError(planeIndex, absDiffSum, meanPixelError);
}
});
workers[wId].start();
} }
try {
for (int wId = 0; wId < workerCount; wId++) {
workers[wId].join();
}
} catch (InterruptedException e) {
e.printStackTrace();
}
// for (int planeIndex = 0; planeIndex < dims.getZ(); planeIndex++) {
// try {
// originalPlane = RawDataIO.loadImageU16(referenceFile, dims, planeIndex);
// compressedPlane = RawDataIO.loadImageU16(compressedFile, dims, planeIndex);
// } catch (IOException e) {
// e.printStackTrace();
// return true;
// }
//
//
// final int[] diffData = Utils.getDifference(originalPlane.getData(), compressedPlane.getData());
// Utils.applyAbsFunction(diffData);
//
//
// final double absDiffSum = Arrays.stream(diffData).mapToDouble(v -> v).sum();
// final double meanPixelError = absDiffSum / (double) planePixelCount;
//
// planeErrors[planeIndex] = new PlaneError(planeIndex, absDiffSum, meanPixelError);
// // System.out.println("Finished plane: " + planeIndex);
// }
try (FileOutputStream fos = new FileOutputStream(reportFile, false); try (FileOutputStream fos = new FileOutputStream(reportFile, false);
OutputStreamWriter writer = new OutputStreamWriter(fos)) { OutputStreamWriter writer = new OutputStreamWriter(fos)) {
......
...@@ -2,6 +2,7 @@ package azgracompress.quantization.scalar; ...@@ -2,6 +2,7 @@ package azgracompress.quantization.scalar;
import azgracompress.U16; import azgracompress.U16;
import azgracompress.quantization.QTrainIteration; import azgracompress.quantization.QTrainIteration;
import azgracompress.utilities.MinMaxResult;
import azgracompress.utilities.Stopwatch; import azgracompress.utilities.Stopwatch;
import azgracompress.utilities.Utils; import azgracompress.utilities.Utils;
...@@ -12,6 +13,9 @@ public class LloydMaxU16ScalarQuantization { ...@@ -12,6 +13,9 @@ public class LloydMaxU16ScalarQuantization {
private final int[] trainingData; private final int[] trainingData;
private int codebookSize; private int codebookSize;
private int dataMin;
private int dataMax;
private int dataSpan;
private int[] centroids; private int[] centroids;
private int[] boundaryPoints; private int[] boundaryPoints;
private double[] pdf; private double[] pdf;
...@@ -32,13 +36,18 @@ public class LloydMaxU16ScalarQuantization { ...@@ -32,13 +36,18 @@ public class LloydMaxU16ScalarQuantization {
private void initialize() { private void initialize() {
centroids = new int[codebookSize]; centroids = new int[codebookSize];
centroids[0] = 0;
boundaryPoints = new int[codebookSize + 1]; boundaryPoints = new int[codebookSize + 1];
boundaryPoints[0] = U16.Min; MinMaxResult<Integer> minMax = Utils.getMinAndMax(trainingData);
boundaryPoints[codebookSize] = U16.Max; dataMin = minMax.getMin();
dataMax = minMax.getMax();
dataSpan = dataMax - dataMin;
centroids[0] = dataMin;
double intervalSize = (double) (U16.Max - U16.Min) / (double) codebookSize; boundaryPoints[0] = dataMin;
boundaryPoints[codebookSize] = dataMax;
double intervalSize = (double) (dataSpan) / (double) codebookSize;
for (int i = 0; i < codebookSize; i++) { for (int i = 0; i < codebookSize; i++) {
centroids[i] = (int) Math.floor(((double) i + 0.5) * intervalSize); centroids[i] = (int) Math.floor(((double) i + 0.5) * intervalSize);
} }
...@@ -49,9 +58,11 @@ public class LloydMaxU16ScalarQuantization { ...@@ -49,9 +58,11 @@ public class LloydMaxU16ScalarQuantization {
// Speedup - for now it is fast enough // Speedup - for now it is fast enough
Stopwatch s = new Stopwatch(); Stopwatch s = new Stopwatch();
s.start(); s.start();
for (int i = 0; i < trainingData.length; i++) { for (int i = 0; i < trainingData.length; i++) {
pdf[trainingData[i]] += 1; pdf[trainingData[i]] += 1.0;
} }
s.stop(); s.stop();
if (verbose) { if (verbose) {
System.out.println("Init_PDF: " + s.getElapsedTimeString()); System.out.println("Init_PDF: " + s.getElapsedTimeString());
...@@ -60,7 +71,24 @@ public class LloydMaxU16ScalarQuantization { ...@@ -60,7 +71,24 @@ public class LloydMaxU16ScalarQuantization {
private void recalculateBoundaryPoints() { private void recalculateBoundaryPoints() {
for (int j = 1; j < codebookSize; j++) { for (int j = 1; j < codebookSize; j++) {
boundaryPoints[j] = (centroids[j] + centroids[j - 1]) / 2; boundaryPoints[j] = Math.min(dataMax,
(int) Math.floor(((double) centroids[j] + (double) centroids[j - 1]) / 2.0));
}
}
private void initializeCentroids() {
int lowerBound, upperBound;
double[] centroidPdf = new double[codebookSize];
for (int centroidIndex = 0; centroidIndex < codebookSize; centroidIndex++) {
lowerBound = boundaryPoints[centroidIndex];
upperBound = boundaryPoints[centroidIndex + 1];
for (int rangeValue = lowerBound; rangeValue <= upperBound; rangeValue++) {
if (pdf[rangeValue] > centroidPdf[centroidIndex]) {
centroidPdf[centroidIndex] = pdf[rangeValue];
centroids[centroidIndex] = rangeValue;
}
}
} }
} }
...@@ -70,13 +98,13 @@ public class LloydMaxU16ScalarQuantization { ...@@ -70,13 +98,13 @@ public class LloydMaxU16ScalarQuantization {
int lowerBound, upperBound; int lowerBound, upperBound;
for (int j = 0; j < codebookSize; j++) { for (int centroidIndex = 0; centroidIndex < codebookSize; centroidIndex++) {
numerator = 0.0; numerator = 0.0;
denominator = 0.0; denominator = 0.0;
lowerBound = boundaryPoints[j]; lowerBound = boundaryPoints[centroidIndex];
upperBound = boundaryPoints[j + 1]; upperBound = boundaryPoints[centroidIndex + 1];
for (int n = lowerBound; n <= upperBound; n++) { for (int n = lowerBound; n <= upperBound; n++) {
numerator += (double) n * pdf[n]; numerator += (double) n * pdf[n];
...@@ -84,7 +112,7 @@ public class LloydMaxU16ScalarQuantization { ...@@ -84,7 +112,7 @@ public class LloydMaxU16ScalarQuantization {
} }
if (denominator > 0) { if (denominator > 0) {
centroids[j] = (int) Math.floor(numerator / denominator); centroids[centroidIndex] = (int) Math.floor(numerator / denominator);
} }
} }
} }
...@@ -98,6 +126,16 @@ public class LloydMaxU16ScalarQuantization { ...@@ -98,6 +126,16 @@ public class LloydMaxU16ScalarQuantization {
throw new RuntimeException("Value couldn't be quantized!"); throw new RuntimeException("Value couldn't be quantized!");
} }
private double calculateMAE() {
double mae = 0.0;
for (final int trainingDatum : trainingData) {
int quantizedValue = quantize(trainingDatum);
mae += Math.abs((double) trainingDatum - (double) quantizedValue);
}
return (mae / (double) trainingData.length);
}
private double getCurrentMse() { private double getCurrentMse() {
double mse = 0.0; double mse = 0.0;
...@@ -149,6 +187,8 @@ public class LloydMaxU16ScalarQuantization { ...@@ -149,6 +187,8 @@ public class LloydMaxU16ScalarQuantization {
public QTrainIteration[] train(final boolean shouldBeVerbose) { public QTrainIteration[] train(final boolean shouldBeVerbose) {
this.verbose = shouldBeVerbose; this.verbose = shouldBeVerbose;
final int RECALCULATE_N_TIMES = 10; final int RECALCULATE_N_TIMES = 10;
final int PATIENCE = 1;
int noImprovementCounter = 0;
if (verbose) { if (verbose) {
System.out.println("Training data count: " + trainingData.length); System.out.println("Training data count: " + trainingData.length);
} }
...@@ -156,6 +196,8 @@ public class LloydMaxU16ScalarQuantization { ...@@ -156,6 +196,8 @@ public class LloydMaxU16ScalarQuantization {
initialize(); initialize();
initializeProbabilityDensityFunction(); initializeProbabilityDensityFunction();
double currMAE = 1.0;
double prevMse = 1.0; double prevMse = 1.0;
double currentMse = 1.0; double currentMse = 1.0;
double psnr; double psnr;
...@@ -163,7 +205,8 @@ public class LloydMaxU16ScalarQuantization { ...@@ -163,7 +205,8 @@ public class LloydMaxU16ScalarQuantization {
ArrayList<QTrainIteration> solutionHistory = new ArrayList<>(); ArrayList<QTrainIteration> solutionHistory = new ArrayList<>();
recalculateBoundaryPoints(); recalculateBoundaryPoints();
recalculateCentroids(); // recalculateCentroids();
initializeCentroids();
currentMse = getCurrentMse(); currentMse = getCurrentMse();
psnr = Utils.calculatePsnr(currentMse, U16.Max); psnr = Utils.calculatePsnr(currentMse, U16.Max);
...@@ -174,27 +217,59 @@ public class LloydMaxU16ScalarQuantization { ...@@ -174,27 +217,59 @@ public class LloydMaxU16ScalarQuantization {
solutionHistory.add(new QTrainIteration(0, currentMse, currentMse, psnr, psnr)); solutionHistory.add(new QTrainIteration(0, currentMse, currentMse, psnr, psnr));
double dist = 1; double mseImprovement = 1;
int iteration = 0; int iteration = 0;
do { do {
for (int i = 0; i < RECALCULATE_N_TIMES; i++) { for (int i = 0; i < RECALCULATE_N_TIMES; i++) {
recalculateBoundaryPoints(); recalculateBoundaryPoints();
recalculateCentroids(); recalculateCentroids();
} }
// TODO(Moravec): Check if we are improving MSE.
// Save the best centroids, the lowest MSE.
currMAE = calculateMAE();
prevMse = currentMse; prevMse = currentMse;
currentMse = getCurrentMse(); currentMse = getCurrentMse();
mseImprovement = prevMse - currentMse;
// System.out.println(String.format("Improvement: %.4f", mseImprovement));
// if ((prevMAE < currMAE) && (iteration != 0)) {
// System.err.println(String.format(
// "MAE = +%.5f",
// currMAE - prevMAE));
// }
psnr = Utils.calculatePsnr(currentMse, U16.Max); psnr = Utils.calculatePsnr(currentMse, U16.Max);
solutionHistory.add(new QTrainIteration(++iteration, currentMse, currentMse, psnr, psnr)); solutionHistory.add(new QTrainIteration(++iteration, currentMse, currentMse, psnr, psnr));
dist = (prevMse - currentMse) / currentMse; // dist = (prevMse - currentMse) / currentMse;
if (verbose) { if (verbose) {
System.out.println(String.format("Current MSE: %.4f PSNR: %.4f dB", currentMse, psnr)); System.out.println(String.format("Current MAE: %.4f MSE: %.4f PSNR: %.4f dB",
currMAE,
currentMse,
psnr));
} }
} while (dist > 0.001); //0.0005 // if (mseImprovement < 1.0 && mseImprovement > 0.0005) {
// System.out.println("----- low improvement " + mseImprovement);
// }
if (mseImprovement < 1.0) {
if ((++noImprovementCounter) >= PATIENCE) {
break;
}
}
} while (true); //0.001 //0.0005// || currMAE > 1500//mseImprovement > 0.0005
if (verbose) { if (verbose) {
System.out.println("\nFinished training."); System.out.println("\nFinished training.");
} }
System.out.println(String.format("Final MAE: %.4f after %d iterations", currMAE, iteration));
return solutionHistory.toArray(new QTrainIteration[0]); return solutionHistory.toArray(new QTrainIteration[0]);
} }
......
...@@ -24,12 +24,12 @@ public class LBGVectorQuantizer { ...@@ -24,12 +24,12 @@ public class LBGVectorQuantizer {
assert (vectors.length > 0) : "No training vectors provided"; assert (vectors.length > 0) : "No training vectors provided";
this.vectorSize = vectors[0].length; this.vectorSize = vectors[0].length;
final int[][] vectorsCopy = new int[vectors.length][vectorSize]; // final int[][] vectorsCopy = new int[vectors.length][vectorSize];
System.arraycopy(vectors, 0, vectorsCopy, 0, vectors.length); // System.arraycopy(vectors, 0, vectorsCopy, 0, vectors.length);
this.trainingVectors = new TrainingVector[vectors.length]; this.trainingVectors = new TrainingVector[vectors.length];
for (int i = 0; i < vectorsCopy.length; i++) { for (int i = 0; i < vectors.length; i++) {
trainingVectors[i] = new TrainingVector(vectorsCopy[i]); trainingVectors[i] = new TrainingVector(Arrays.copyOf(vectors[i],vectors[i].length));
} }
this.codebookSize = codebookSize; this.codebookSize = codebookSize;
......
package azgracompress.utilities;
public class MinMaxResult<T> {
private final T min;
private final T max;
MinMaxResult(T min, T max) {
this.min = min;
this.max = max;
}
public T getMin() {
return min;
}
public T getMax() {
return max;
}
}
\ No newline at end of file
...@@ -5,6 +5,7 @@ import java.io.FileNotFoundException; ...@@ -5,6 +5,7 @@ import java.io.FileNotFoundException;
import java.io.IOException; import java.io.IOException;
import java.util.ArrayList; import java.util.ArrayList;
public class Utils { public class Utils {
public static double calculatePsnr(final double mse, final int signalMax) { public static double calculatePsnr(final double mse, final int signalMax) {
...@@ -83,6 +84,22 @@ public class Utils { ...@@ -83,6 +84,22 @@ public class Utils {
} }
public static MinMaxResult<Integer> getMinAndMax(final int[] data) {
int min = Integer.MAX_VALUE;
int max = Integer.MIN_VALUE;
for (int i = 0; i < data.length; i++) {
if (data[i] < min) {
min = data[i];
}
if (data[i] > max) {
max = data[i];
}
}
return new MinMaxResult<Integer>(min, max);
}
public static double calculateMse(final int[] difference) { public static double calculateMse(final int[] difference) {
double sum = 0.0; double sum = 0.0;
for (final int val : difference) { for (final int val : difference) {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment