Skip to content
Snippets Groups Projects
Commit e9141d5d authored by Vojtech Moravec's avatar Vojtech Moravec
Browse files

Use BenchmarkBase in Scalar benchmark

parent ca373387
No related branches found
No related tags found
No related merge requests found
package compression.benchmark; package compression.benchmark;
import compression.U16; import compression.U16;
import compression.data.ImageU16;
import compression.data.V3i; import compression.data.V3i;
import compression.de.DeException; import compression.de.DeException;
import compression.de.shade.ILShadeSolver; import compression.de.shade.ILShadeSolver;
...@@ -10,81 +9,19 @@ import compression.quantization.QTrainIteration; ...@@ -10,81 +9,19 @@ import compression.quantization.QTrainIteration;
import compression.quantization.scalar.LloydMaxU16ScalarQuantization; import compression.quantization.scalar.LloydMaxU16ScalarQuantization;
import compression.quantization.scalar.ScalarQuantizer; import compression.quantization.scalar.ScalarQuantizer;
import compression.utilities.TypeConverter; import compression.utilities.TypeConverter;
import compression.utilities.Utils;
import java.io.File; public class ScalarQuantizationBenchmark extends BenchmarkBase {
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStreamWriter;
public class ScalarQuantizationBenchmark {
private final String inputFile;
private final String outputDirectory;
private final int[] planes;
private boolean useDiffEvolution = false; private boolean useDiffEvolution = false;
final V3i rawImageDims;
public ScalarQuantizationBenchmark(final String inputFile, public ScalarQuantizationBenchmark(final String inputFile,
final String outputDirectory, final String outputDirectory,
final int[] planes, final int[] planes,
final V3i rawImageDims) { final V3i rawImageDims) {
this.inputFile = inputFile; super(inputFile, outputDirectory, planes, rawImageDims);
this.outputDirectory = outputDirectory;
this.planes = planes;
this.rawImageDims = rawImageDims;
}
private short[] loadPlaneData(final int planeIndex) {
try {
ImageU16 image = RawDataIO.loadImageU16(inputFile, rawImageDims, planeIndex);
return image.getData();
} catch (Exception ex) {
ex.printStackTrace();
}
return new short[0];
} }
private boolean saveQuantizedPlaneData(final short[] data, final String filename) {
ImageU16 img = new ImageU16(rawImageDims.getX(), rawImageDims.getY(), data);
try {
// NOTE(Moravec): Use big endian so that FIJI can read the image.
RawDataIO.writeImageU16(getFileNamePath(filename), img, false);
System.out.println(String.format("Saved %s", filename));
} catch (Exception e) {
e.printStackTrace();
return false;
}
return true;
}
private boolean saveDifference(final short[] originalData,
final short[] transformedData,
final String diffFile,
final String absDiffFile) {
final int[] differenceData = Utils.getDifference(originalData, transformedData);
final int[] absDifferenceData = Utils.applyAbsToValues(differenceData);
final String diffFilePath = getFileNamePath(diffFile);
final String absDiffFilePath = getFileNamePath(absDiffFile);
ImageU16 img = new ImageU16(rawImageDims.getX(),
rawImageDims.getY(),
TypeConverter.intArrayToShortArray(absDifferenceData));
try {
// NOTE(Moravec): Use little endian so that gnuplot can read the array.
RawDataIO.writeImageU16(absDiffFilePath, img, true);
System.out.println("Saved absolute difference to: " + absDiffFilePath);
RawDataIO.writeDataI32(diffFilePath, differenceData, true);
System.out.println("Saved difference to: " + absDiffFilePath);
} catch (Exception e) {
e.printStackTrace();
System.err.println("Failed to save difference.");
return false;
}
return true;
}
@Override
public void startBenchmark() { public void startBenchmark() {
for (final int planeIndex : planes) { for (final int planeIndex : planes) {
...@@ -114,10 +51,10 @@ public class ScalarQuantizationBenchmark { ...@@ -114,10 +51,10 @@ public class ScalarQuantizationBenchmark {
System.out.println("Scalar quantizer ready."); System.out.println("Scalar quantizer ready.");
final String method = useDiffEvolution ? "ilshade" : "lloyd"; final String method = useDiffEvolution ? "ilshade" : "lloyd";
final String centroidsFile = getFileNamePath(String.format("p%d_cb%d%s_centroids.raw", final String centroidsFile = getFileNamePathIntoOutDir(String.format("p%d_cb%d%s_centroids.raw",
(planeIndex + 1), (planeIndex + 1),
codebookSize, codebookSize,
method)); method));
// NOTE(Moravec): Centroids are saved in little endian order. // NOTE(Moravec): Centroids are saved in little endian order.
if (!RawDataIO.writeDataI32(centroidsFile, quantizer.getCentroids(), true)) { if (!RawDataIO.writeDataI32(centroidsFile, quantizer.getCentroids(), true)) {
...@@ -142,16 +79,11 @@ public class ScalarQuantizationBenchmark { ...@@ -142,16 +79,11 @@ public class ScalarQuantizationBenchmark {
} }
} }
private String getFileNamePath(final String fileName) {
final File file = new File(outputDirectory, fileName);
return file.getAbsolutePath();
}
private ScalarQuantizer trainLloydMaxQuantizer(final short[] data, final int codebookSize, final int planeIndex) { private ScalarQuantizer trainLloydMaxQuantizer(final short[] data, final int codebookSize, final int planeIndex) {
LloydMaxU16ScalarQuantization lloydMax = new LloydMaxU16ScalarQuantization(data, codebookSize); LloydMaxU16ScalarQuantization lloydMax = new LloydMaxU16ScalarQuantization(data, codebookSize);
QTrainIteration[] trainingReport = lloydMax.train(); QTrainIteration[] trainingReport = lloydMax.train();
saveQTrainLog(getFileNamePath(String.format("p%d_cb_%d_lloyd.csv", planeIndex, codebookSize)), trainingReport); saveQTrainLog(String.format("p%d_cb_%d_lloyd.csv", planeIndex, codebookSize), trainingReport);
return new ScalarQuantizer(U16.Min, U16.Max, lloydMax.getCentroids()); return new ScalarQuantizer(U16.Min, U16.Max, lloydMax.getCentroids());
} }
...@@ -169,35 +101,10 @@ public class ScalarQuantizationBenchmark { ...@@ -169,35 +101,10 @@ public class ScalarQuantizationBenchmark {
deEx.printStackTrace(); deEx.printStackTrace();
return null; return null;
} }
saveQTrainLog(getFileNamePath(String.format("p%d_cb_%d_il_shade.csv", planeIndex, codebookSize)), saveQTrainLog(String.format("p%d_cb_%d_il_shade.csv", planeIndex, codebookSize), trainingReport);
trainingReport);
return new ScalarQuantizer(U16.Min, U16.Max, ilshade.getBestSolution().getAttributes()); return new ScalarQuantizer(U16.Min, U16.Max, ilshade.getBestSolution().getAttributes());
} }
private void saveQTrainLog(final String filename, final QTrainIteration[] trainingLog) {
final String CSV_HEADER = "It;AvgMSE;BestMSE;AvgPSNR;BestPSNR\n";
try {
FileOutputStream fileStream = new FileOutputStream(filename);
OutputStreamWriter writer = new OutputStreamWriter(fileStream);
writer.write(CSV_HEADER);
for (final QTrainIteration it : trainingLog) {
writer.write(String.format("%d;%.5f;%.5f;%.5f;%.5f\n",
it.getIteration(),
it.getAverageMSE(),
it.getBestMSE(),
it.getAveragePSNR(),
it.getBestPSNR()));
}
writer.flush();
fileStream.flush();
fileStream.close();
} catch (IOException ioE) {
ioE.printStackTrace();
System.err.println("Failed to save QTtrain log.");
}
}
public boolean isUseDiffEvolution() { public boolean isUseDiffEvolution() {
return useDiffEvolution; return useDiffEvolution;
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment