Skip to content
Snippets Groups Projects
Commit 578a24ef authored by Vojtech Moravec's avatar Vojtech Moravec
Browse files

Implemented scalar quantization benchmark

parent e82cc1aa
Branches
No related tags found
No related merge requests found
import compression.data.ImageU16;
import compression.benchmark.ScalarQuantizationBenchmark;
import compression.data.V3i;
import compression.de.DeException;
import compression.de.DeHistory;
import compression.de.IDESolver;
import compression.de.jade.JadeSolver;
import compression.de.shade.ILShadeSolver;
import compression.de.shade.LShadeSolver;
import compression.io.RawDataIO;
import compression.quantization.scalar.LloydMaxIteration;
import compression.quantization.scalar.LloydMaxU16ScalarQuantization;
import compression.quantization.vector.LBGResult;
import compression.quantization.vector.LBGVectorQuantizer;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.util.Random;
public class DataCompressor {
static int[] getRandomData(int len) {
Random r = new Random();
int[] data = new int[len];
for (int i = 0; i < data.length; i++) {
data[i] = r.nextInt(20000);
}
return data;
}
// static int[] getRandomData(int len) {
// Random r = new Random();
// int[] data = new int[len];
// for (int i = 0; i < data.length; i++) {
// data[i] = r.nextInt(20000);
// }
// return data;
// }
public static void main(String[] args) throws IOException {
ImageU16 img = null;
try {
img = RawDataIO.loadImageU16("D:\\tmp\\tiff_data\\fused_tp_10_ch_1_16bit.raw", new V3i(1041, 996, 946), 359 - 1);
RawDataIO.writeImageU16("D:\\tmp\\tiff_data\\fused_tp_10_ch_1_16bit_p359.raw", img);
} catch (Exception e) {
e.printStackTrace();
}
ScalarQuantizationBenchmark sqBenchmark = new ScalarQuantizationBenchmark(
"D:\\biology\\tiff_data\\fused_tp_10_ch_1_16bit.raw",
"D:\\biology\\benchmark\\scalar",
358,
358,
new V3i(1041,996,946));
sqBenchmark.startBenchmark();
/*
......@@ -91,7 +75,7 @@ public class DataCompressor {
//ilshade(values, Dimension, 100, 800, "iL-SHADE-2bits-800it.csv");
*/
}
/*
private static void appendLineToFile(final String fileName, final String line) {
try {
FileOutputStream os = new FileOutputStream(fileName, true);
......@@ -239,5 +223,5 @@ public class DataCompressor {
}
}
*/
}
package compression.benchmark;
import compression.U16;
import compression.data.ImageU16;
import compression.data.V3i;
import compression.de.DeException;
import compression.de.shade.ILShadeSolver;
import compression.io.RawDataIO;
import compression.quantization.QTrainIteration;
import compression.quantization.scalar.LloydMaxU16ScalarQuantization;
import compression.quantization.scalar.ScalarQuantizer;
import compression.utilities.Utils;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStreamWriter;
public class ScalarQuantizationBenchmark {
private final String inputFile;
private final String outputDirectory;
private final int fromPlaneIndex;
private final int toPlaneIndex;
private boolean useDiffEvolution = false;
final V3i rawImageDims;
public ScalarQuantizationBenchmark(final String inputFile, final String outputDirectory,
final int fromPlaneIndex, final int toPlaneIndex,
final V3i rawImageDims) {
this.inputFile = inputFile;
this.outputDirectory = outputDirectory;
this.fromPlaneIndex = fromPlaneIndex;
this.toPlaneIndex = toPlaneIndex;
this.rawImageDims = rawImageDims;
}
private short[] loadPlaneData(final int planeIndex) {
try {
ImageU16 image = RawDataIO.loadImageU16(inputFile, rawImageDims, planeIndex);
return image.getData();
} catch (Exception ex) {
ex.printStackTrace();
}
return new short[0];
}
private boolean saveQuantizedPlaneData(final short[] data, final String filename) {
ImageU16 img = new ImageU16(rawImageDims.getX(), rawImageDims.getY(), data);
try {
RawDataIO.writeImageU16(getFileNamePath(filename), img);
} catch (Exception e) {
e.printStackTrace();
return false;
}
return true;
}
private boolean saveDifference(final short[] originalData, final short[] transformedData, final String filename) {
final int[] differenceData = Utils.getAbsoluteDifference(originalData, transformedData);
try {
RawDataIO.writeDataI32(getFileNamePath(filename), differenceData);
} catch (Exception e) {
e.printStackTrace();
return false;
}
return true;
}
public void startBenchmark() {
// Test codebook sizes from 2^2 to 2^8
for (int bitCount = 2; bitCount <= 8; bitCount++) {
final int codebookSize = (int) Math.pow(2, bitCount);
System.out.println(String.format("Starting benchmark for codebook of size %d", codebookSize));
for (int planeIndex = fromPlaneIndex; planeIndex <= toPlaneIndex; planeIndex++) {
System.out.println(String.format("Loading plane %d ...", planeIndex));
final short[] planeData = loadPlaneData(planeIndex);
if (planeData.length == 0) {
System.err.println(String.format("Failed to load plane %d data. Skipping plane.", planeIndex));
continue;
}
ScalarQuantizer quantizer = null;
if (useDiffEvolution) {
quantizer = trainDifferentialEvolution(planeData, codebookSize, planeIndex);
} else {
quantizer = trainLloydMaxQuantizer(planeData, codebookSize, planeIndex);
}
if (quantizer == null) {
System.err.println("Failed to initialize scalar quantizer. Skipping plane.");
continue;
}
System.out.println("Scalar quantizer is initialized...");
final String method = useDiffEvolution ? "ilshade" : "lloyd";
final String quantizedFile = String.format("quantized_%s_plane_%d_cb_%d.raw", method, planeIndex, codebookSize);
final String absoluteDiffFile = String.format("absolute_%s_plane_%d_cb_%d.raw", method, planeIndex, codebookSize);
final short[] quantizedData = quantizer.quantize(planeData);
if (saveQuantizedPlaneData(quantizedData, quantizedFile)) {
System.out.println(String.format("Quantized plane %d data and wrote to file...", planeIndex));
} else {
System.err.println("Failed to save quantized plane.");
}
if (saveDifference(planeData, quantizedData, absoluteDiffFile)) {
System.out.println("Saved difference.");
} else {
System.err.println("Failed to save difference.s");
}
}
}
}
private String getFileNamePath(final String fileName) {
final File file = new File(outputDirectory, fileName);
return file.getAbsolutePath();
}
private ScalarQuantizer trainLloydMaxQuantizer(final short[] data, final int codebookSize, final int planeIndex) {
LloydMaxU16ScalarQuantization lloydMax = new LloydMaxU16ScalarQuantization(data, codebookSize);
QTrainIteration[] trainingReport = lloydMax.train();
saveQTrainLog(getFileNamePath(String.format("lloyd_max_plane_%d_CB_%d.csv", planeIndex, codebookSize)), trainingReport);
return new ScalarQuantizer(U16.Min, U16.Max, lloydMax.getCentroids());
}
private ScalarQuantizer trainDifferentialEvolution(final short[] data, final int codebookSize, final int planeIndex) {
ILShadeSolver ilshade = new ILShadeSolver(codebookSize, 100, 2000, 15);
ilshade.setTrainingData(Utils.convertShortArrayToIntArray(data));
QTrainIteration[] trainingReport = null;
try {
trainingReport = ilshade.train();
} catch (DeException deEx) {
deEx.printStackTrace();
return null;
}
saveQTrainLog(getFileNamePath(String.format("il_shade_plane_%d_CB_%d.csv", planeIndex, codebookSize)), trainingReport);
return new ScalarQuantizer(U16.Min, U16.Max, ilshade.getBestSolution().getAttributes());
}
private void saveQTrainLog(final String filename, final QTrainIteration[] trainingLog) {
final String CSV_HEADER = "It;AvgMSE;BestMSE;AvgPSNR;BestPSNR";
try {
FileOutputStream fileStream = new FileOutputStream(filename);
OutputStreamWriter writer = new OutputStreamWriter(fileStream);
writer.write(CSV_HEADER);
for (final QTrainIteration it : trainingLog) {
writer.write(String.format("%d;%.5f;%.5f;%.5f;%.5f\n",
it.getIteration(),
it.getAverageMSE(),
it.getBestMSE(),
it.getAveragePSNR(),
it.getBestPSNR()));
}
writer.flush();
fileStream.flush();
fileStream.close();
} catch (IOException ioE) {
ioE.printStackTrace();
System.err.println("Failed to save QTtrain log.");
}
}
public boolean isUseDiffEvolution() {
return useDiffEvolution;
}
public void setUseDiffEvolution(boolean useDiffEvolution) {
this.useDiffEvolution = useDiffEvolution;
}
}
......@@ -6,6 +6,7 @@ import compression.utilities.Utils;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
public class RawDataIO {
/**
......@@ -44,9 +45,18 @@ public class RawDataIO {
return image;
}
public static void writeImageU16(final String rawFile, final ImageU16 image) throws Exception {
FileOutputStream fileStream = new FileOutputStream(rawFile, false);
public static void writeImageU16(final String rawFile, final ImageU16 image) throws IOException {
byte[] buffer = Utils.convertShortArrayToByteArray(image.getData());
writeBytesToFile(rawFile, buffer);
}
public static void writeDataI32(String rawFile, int[] differenceData) throws IOException {
byte[] buffer = Utils.convertIntArrayToByteArray(differenceData);
writeBytesToFile(rawFile, buffer);
}
private static void writeBytesToFile(String rawFile, byte[] buffer) throws IOException {
FileOutputStream fileStream = new FileOutputStream(rawFile, false);
fileStream.write(buffer, 0, buffer.length);
fileStream.flush();
fileStream.close();
......
......@@ -130,4 +130,38 @@ public class Utils {
}
return result;
}
public static int[] getAbsoluteDifference(final short[] original, final short[] transformed) {
assert (original.length == transformed.length) : "Array lengths doesn't match";
int[] difference = new int[original.length];
for (int i = 0; i < original.length; i++) {
difference[i] = Math.abs((int) original[i] - (int) transformed[i]);
}
return difference;
}
public static int[] getSquaredDifference(final short[] original, final short[] transformed) {
assert (original.length == transformed.length) : "Array lengths doesn't match";
int[] difference = new int[original.length];
for (int i = 0; i < original.length; i++) {
difference[i] = (int) Math.pow(((int) original[i] - (int) transformed[i]), 2);
}
return difference;
}
public static byte[] convertIntArrayToByteArray(final int[] data) {
byte[] buffer = new byte[data.length * 4];
int j = 0;
for (final int v : data) {
buffer[j++] = (byte) ((v >>> 24) & 0xFF);
buffer[j++] = (byte) ((v >>> 16) & 0xFF);
buffer[j++] = (byte) ((v >>> 8) & 0xFF);
buffer[j++] = (byte) (v & 0xFF);
}
return buffer;
}
}
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment