package compression.benchmark;
import compression.U16;
import compression.data.ImageU16;
import compression.data.V3i;
import compression.de.DeException;
import compression.de.shade.ILShadeSolver;
import compression.io.RawDataIO;
import compression.quantization.QTrainIteration;
import compression.quantization.scalar.LloydMaxU16ScalarQuantization;
import compression.quantization.scalar.ScalarQuantizer;
import compression.utilities.TypeConverter;
import compression.utilities.Utils;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStreamWriter;
public class ScalarQuantizationBenchmark {
private final String inputFile;
private final String outputDirectory;
private final int fromPlaneIndex;
private final int toPlaneIndex;
private boolean useDiffEvolution = false;
final V3i rawImageDims;
public ScalarQuantizationBenchmark(final String inputFile, final String outputDirectory,
final int fromPlaneIndex, final int toPlaneIndex,
final V3i rawImageDims) {
this.inputFile = inputFile;
this.outputDirectory = outputDirectory;
this.fromPlaneIndex = fromPlaneIndex;
this.toPlaneIndex = toPlaneIndex;
this.rawImageDims = rawImageDims;
}
private short[] loadPlaneData(final int planeIndex) {
try {
ImageU16 image = RawDataIO.loadImageU16(inputFile, rawImageDims, planeIndex);
return image.getData();
} catch (Exception ex) {
ex.printStackTrace();
}
return new short[0];
}
private boolean saveQuantizedPlaneData(final short[] data, final String filename) {
ImageU16 img = new ImageU16(rawImageDims.getX(), rawImageDims.getY(), data);
try {
RawDataIO.writeImageU16(getFileNamePath(filename), img);
} catch (Exception e) {
e.printStackTrace();
return false;
}
return true;
}
private boolean saveDifference(final short[] originalData, final short[] transformedData, final String filename) {
final int[] differenceData = Utils.getAbsoluteDifference(originalData, transformedData);
final String path = getFileNamePath(filename);
ImageU16 img = new ImageU16(rawImageDims.getX(), rawImageDims.getY(), TypeConverter.intArrayToShortArray(differenceData));
RawDataIO.writeImageU16(path, img);
System.out.println("Saved difference to: " + path);
} catch (Exception e) {
e.printStackTrace();
System.err.println("Failed to save difference.");
return false;
}
return true;
}
public void startBenchmark() {
// Test codebook sizes from 2^2 to 2^8
for (int bitCount = 8; bitCount <= 8; bitCount++) {
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
final int codebookSize = (int) Math.pow(2, bitCount);
System.out.println(String.format("Starting benchmark for codebook of size %d", codebookSize));
for (int planeIndex = fromPlaneIndex; planeIndex <= toPlaneIndex; planeIndex++) {
System.out.println(String.format("Loading plane %d ...", planeIndex));
final short[] planeData = loadPlaneData(planeIndex);
if (planeData.length == 0) {
System.err.println(String.format("Failed to load plane %d data. Skipping plane.", planeIndex));
continue;
}
ScalarQuantizer quantizer = null;
if (useDiffEvolution) {
quantizer = trainDifferentialEvolution(planeData, codebookSize, planeIndex);
} else {
quantizer = trainLloydMaxQuantizer(planeData, codebookSize, planeIndex);
}
if (quantizer == null) {
System.err.println("Failed to initialize scalar quantizer. Skipping plane.");
continue;
}
System.out.println("Scalar quantizer is initialized...");
final String method = useDiffEvolution ? "ilshade" : "lloyd";
final String quantizedFile = String.format("quantized_%s_plane_%d_cb_%d.raw", method, planeIndex, codebookSize);
final String absoluteDiffFile = String.format("absolute_%s_plane_%d_cb_%d.raw", method, planeIndex, codebookSize);
final short[] quantizedData = quantizer.quantize(planeData);
if (saveQuantizedPlaneData(quantizedData, quantizedFile)) {
System.out.println(String.format("Quantized plane %d data and wrote to file...", planeIndex));
} else {
System.err.println("Failed to save quantized plane.");
}
saveDifference(planeData, quantizedData, absoluteDiffFile);
}
}
}
private String getFileNamePath(final String fileName) {
final File file = new File(outputDirectory, fileName);
return file.getAbsolutePath();
}
private ScalarQuantizer trainLloydMaxQuantizer(final short[] data, final int codebookSize, final int planeIndex) {
LloydMaxU16ScalarQuantization lloydMax = new LloydMaxU16ScalarQuantization(data, codebookSize);
QTrainIteration[] trainingReport = lloydMax.train();
saveQTrainLog(getFileNamePath(String.format("lloyd_max_plane_%d_CB_%d.csv", planeIndex, codebookSize)), trainingReport);
return new ScalarQuantizer(U16.Min, U16.Max, lloydMax.getCentroids());
}
private ScalarQuantizer trainDifferentialEvolution(final short[] data, final int codebookSize, final int planeIndex) {
ILShadeSolver ilshade = new ILShadeSolver(codebookSize, 100, 2000, 15);
ilshade.setTrainingData(TypeConverter.shortArrayToIntArray(data));
QTrainIteration[] trainingReport = null;
try {
trainingReport = ilshade.train();
} catch (DeException deEx) {
deEx.printStackTrace();
return null;
}
saveQTrainLog(getFileNamePath(String.format("il_shade_plane_%d_CB_%d.csv", planeIndex, codebookSize)), trainingReport);
return new ScalarQuantizer(U16.Min, U16.Max, ilshade.getBestSolution().getAttributes());
}
private void saveQTrainLog(final String filename, final QTrainIteration[] trainingLog) {
final String CSV_HEADER = "It;AvgMSE;BestMSE;AvgPSNR;BestPSNR\n";
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
try {
FileOutputStream fileStream = new FileOutputStream(filename);
OutputStreamWriter writer = new OutputStreamWriter(fileStream);
writer.write(CSV_HEADER);
for (final QTrainIteration it : trainingLog) {
writer.write(String.format("%d;%.5f;%.5f;%.5f;%.5f\n",
it.getIteration(),
it.getAverageMSE(),
it.getBestMSE(),
it.getAveragePSNR(),
it.getBestPSNR()));
}
writer.flush();
fileStream.flush();
fileStream.close();
} catch (IOException ioE) {
ioE.printStackTrace();
System.err.println("Failed to save QTtrain log.");
}
}
public boolean isUseDiffEvolution() {
return useDiffEvolution;
}
public void setUseDiffEvolution(boolean useDiffEvolution) {
this.useDiffEvolution = useDiffEvolution;
}
}