Skip to content
Snippets Groups Projects
Commit 6dee5fbc authored by Vojtech Moravec's avatar Vojtech Moravec
Browse files

Added parallel MSE calculation to LloydMax.

parent c067942d
No related branches found
No related tags found
No related merge requests found
......@@ -5,6 +5,8 @@ import azgracompress.cli.CliConstants;
import azgracompress.cli.ParsedCliOptions;
import azgracompress.compression.ImageCompressor;
import azgracompress.compression.ImageDecompressor;
import azgracompress.quantization.QuantizationValueCache;
import azgracompress.quantization.vector.CodebookEntry;
import org.apache.commons.cli.*;
import org.jetbrains.annotations.NotNull;
......
......@@ -27,7 +27,7 @@ public class ChunkIO {
byte[] chunkData = new byte[dataLen];
buffer.get(chunkData);
chunks.add(new Chunk3D(chunkDims, chunkOffset, TypeConverter.shortBytesToIntArray(chunkData)));
chunks.add(new Chunk3D(chunkDims, chunkOffset, TypeConverter.unsignedShortBytesToIntArray(chunkData)));
}
} catch (IOException e) {
......
......@@ -4,9 +4,7 @@ import azgracompress.data.ImageU16;
import azgracompress.data.V3i;
import azgracompress.utilities.TypeConverter;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.*;
public class RawDataIO {
/**
......@@ -19,34 +17,55 @@ public class RawDataIO {
*/
public static ImageU16 loadImageU16(final String rawFile,
final V3i rawDataDimension,
final int plane) throws Exception {
FileInputStream fileStream = new FileInputStream(rawFile);
final int plane) throws IOException {
final long planeSize = (long) rawDataDimension.getX() * (long) rawDataDimension.getY() * 2;
final long expectedFileSize = planeSize * rawDataDimension.getZ();
final long fileSize = fileStream.getChannel().size();
byte[] buffer;
try (FileInputStream fileStream = new FileInputStream(rawFile)) {
final long planeSize = (long) rawDataDimension.getX() * (long) rawDataDimension.getY() * 2;
final long expectedFileSize = planeSize * rawDataDimension.getZ();
final long fileSize = fileStream.getChannel().size();
if (expectedFileSize != fileSize) {
throw new Exception(
"File specified by `rawFile` doesn't contains raw data for image of dimensions `rawDataDimension`");
}
if (expectedFileSize != fileSize) {
throw new IOException(
"File specified by `rawFile` doesn't contains raw data for image of dimensions " +
"`rawDataDimension`");
}
final long planeOffset = plane * planeSize;
final long planeOffset = plane * planeSize;
byte[] buffer = new byte[(int) planeSize];
if (fileStream.skip(planeOffset) != planeOffset) {
throw new Exception("Failed to skip.");
}
if (fileStream.read(buffer, 0, (int) planeSize) != planeSize) {
throw new Exception("Read wrong number of bytes.");
buffer = new byte[(int) planeSize];
if (fileStream.skip(planeOffset) != planeOffset) {
throw new IOException("Failed to skip.");
}
if (fileStream.read(buffer, 0, (int) planeSize) != planeSize) {
throw new IOException("Read wrong number of bytes.");
}
}
fileStream.close();
return new ImageU16(rawDataDimension.getX(),
rawDataDimension.getY(),
TypeConverter.shortBytesToIntArray(buffer));
TypeConverter.unsignedShortBytesToIntArray(buffer));
}
public static int[] loadAllPlanesData(final String rawFile, final V3i imageDims) throws IOException {
final long dataSize = (long) imageDims.getX() * (long) imageDims.getY() * (long) imageDims.getZ();
int[] values = new int[(int) dataSize];
if (dataSize > (long) Integer.MAX_VALUE) {
throw new IOException("RawFile size is too big.");
}
try (FileInputStream fileStream = new FileInputStream(rawFile);
DataInputStream dis = new DataInputStream(new BufferedInputStream(fileStream, 8192))) {
for (int i = 0; i < (int) dataSize; i++) {
values[i] = dis.readUnsignedShort();
}
}
return values;
}
public static void writeImageU16(final String rawFile,
......
......@@ -2,7 +2,7 @@ package azgracompress.quantization.scalar;
import azgracompress.U16;
import azgracompress.quantization.QTrainIteration;
import azgracompress.utilities.TypeConverter;
import azgracompress.utilities.Stopwatch;
import azgracompress.utilities.Utils;
import java.util.ArrayList;
......@@ -16,13 +16,16 @@ public class LloydMaxU16ScalarQuantization {
private int[] boundaryPoints;
private double[] pdf;
public LloydMaxU16ScalarQuantization(final int[] trainData, final int codebookSize) {
private final int workerCount;
public LloydMaxU16ScalarQuantization(final int[] trainData, final int codebookSize, final int workerCount) {
trainingData = trainData;
this.codebookSize = codebookSize;
this.workerCount = workerCount;
}
public LloydMaxU16ScalarQuantization(final short[] trainData, final int codebookSize) {
this(TypeConverter.shortArrayToIntArray(trainData), codebookSize);
public LloydMaxU16ScalarQuantization(final int[] trainData, final int codebookSize) {
this(trainData, codebookSize,1);
}
private void initialize() {
......@@ -41,9 +44,14 @@ public class LloydMaxU16ScalarQuantization {
private void initializeProbabilityDensityFunction() {
pdf = new double[U16.Max + 1];
// Speedup
Stopwatch s = new Stopwatch();
s.start();
for (int i = 0; i < trainingData.length; i++) {
pdf[trainingData[i]] += 1;
}
s.stop();
System.out.println("Init_PDF: " + s.getElapsedTimeString());
}
private void recalculateBoundaryPoints() {
......@@ -53,14 +61,11 @@ public class LloydMaxU16ScalarQuantization {
}
private void recalculateCentroids() {
// NOTE(Moravec): We cann't create floating points in here because we are trying to quantize to integer values.
double numerator = 0.0;
double denominator = 0.0;
int lowerBound, upperBound;
// NOTE(Moravec): Leave the first centroid at zero.
for (int j = 0; j < codebookSize; j++) {
numerator = 0.0;
......@@ -75,7 +80,6 @@ public class LloydMaxU16ScalarQuantization {
}
if (denominator > 0) {
// NOTE: Maybe try ceil instead of floor.
centroids[j] = (int) Math.floor(numerator / denominator);
}
}
......@@ -92,15 +96,59 @@ public class LloydMaxU16ScalarQuantization {
private double getCurrentMse() {
double mse = 0.0;
for (int i = 0; i < trainingData.length; i++) {
int quantizedValue = quantize(trainingData[i]);
mse += Math.pow((double) trainingData[i] - (double) quantizedValue, 2);
if (workerCount > 1) {
Stopwatch s = new Stopwatch();
s.start();
// Speedup
final int workSize = trainingData.length / workerCount;
RunnableLloydMseCalc[] runnables = new RunnableLloydMseCalc[workerCount];
Thread[] workers = new Thread[workerCount];
for (int wId = 0; wId < workerCount; wId++) {
final int fromIndex = wId * workSize;
final int toIndex = (wId == workerCount - 1) ? trainingData.length : (workSize + (wId * workSize));
runnables[wId] = new RunnableLloydMseCalc(trainingData,
fromIndex,
toIndex,
centroids,
boundaryPoints,
codebookSize);
workers[wId] = new Thread(runnables[wId]);
}
for (int wId = 0; wId < workerCount; wId++) {
workers[wId].start();
}
try {
for (int wId = 0; wId < workerCount; wId++) {
workers[wId].join();
mse += runnables[wId].getMse();
}
} catch (InterruptedException e) {
e.printStackTrace();
}
s.stop();
System.out.println("\ngetCurrentMse time: " + s.getElapsedTimeString());
} else {
for (final int trainingDatum : trainingData) {
int quantizedValue = quantize(trainingDatum);
mse += Math.pow((double) trainingDatum - (double) quantizedValue, 2);
}
}
mse /= (double) trainingData.length;
return mse;
}
public QTrainIteration[] train(final boolean verbose) {
System.out.println("Data len: " + trainingData.length);
initialize();
initializeProbabilityDensityFunction();
......@@ -108,15 +156,14 @@ public class LloydMaxU16ScalarQuantization {
double currentMse = 1.0;
double psnr;
ArrayList<QTrainIteration> solutionHistory = new ArrayList<QTrainIteration>();
ArrayList<QTrainIteration> solutionHistory = new ArrayList<>();
recalculateBoundaryPoints();
recalculateCentroids();
// printCurrentConfigration();
currentMse = getCurrentMse();
psnr = Utils.calculatePsnr(currentMse, U16.Max);
if (verbose) {
System.out.println(String.format("Initial MSE: %f", currentMse));
}
......@@ -129,8 +176,6 @@ public class LloydMaxU16ScalarQuantization {
recalculateBoundaryPoints();
recalculateCentroids();
// printCurrentConfigration();
prevMse = currentMse;
currentMse = getCurrentMse();
psnr = Utils.calculatePsnr(currentMse, U16.Max);
......@@ -147,25 +192,9 @@ public class LloydMaxU16ScalarQuantization {
if (verbose) {
System.out.println("\nFinished training.");
}
// printCurrentConfigration();
return solutionHistory.toArray(new QTrainIteration[0]);
}
private void printCurrentConfigration() {
StringBuilder sb = new StringBuilder();
sb.append("Centroids: ");
for (int i = 0; i < centroids.length; i++) {
sb.append(String.format("a[%d]=%d;", i, centroids[i]));
}
sb.append("\n");
sb.append("Boundaries: ");
for (int i = 0; i < boundaryPoints.length; i++) {
sb.append(String.format("b[%d]=%d;", i, boundaryPoints[i]));
}
System.out.println(sb);
}
public int[] getCentroids() {
return centroids;
}
......
package azgracompress.quantization.scalar;
public class RunnableLloydMseCalc implements Runnable {
final int[] trainingData;
final int fromIndex;
final int toIndex;
final int[] centroids;
final int[] boundaryPoints;
final int codebookSize;
double mse = 0.0;
public RunnableLloydMseCalc(int[] trainingData, int fromIndex, int toIndex, int[] centroids, int[] boundaryPoints,
final int codebookSize) {
this.trainingData = trainingData;
this.fromIndex = fromIndex;
this.toIndex = toIndex;
this.centroids = centroids;
this.boundaryPoints = boundaryPoints;
this.codebookSize = codebookSize;
}
@Override
public void run() {
mse = 0.0;
for (int i = fromIndex; i < toIndex; i++) {
mse += Math.pow((double) trainingData[i] - (double) quantize(trainingData[i]), 2);
}
}
public double getMse() {
return mse;
}
private int quantize(final int value) {
for (int intervalId = 1; intervalId <= codebookSize; intervalId++) {
if ((value >= boundaryPoints[intervalId - 1]) && (value <= boundaryPoints[intervalId])) {
return centroids[intervalId - 1];
}
}
throw new RuntimeException("Value couldn't be quantized!");
}
}
......@@ -8,7 +8,7 @@ public class TypeConverter {
return ((value & 0xFF00) | (value & 0x00FF));
}
public static int[] shortBytesToIntArray(final byte[] bytes) {
public static int[] unsignedShortBytesToIntArray(final byte[] bytes) {
assert (bytes.length % 2 == 0);
int[] values = new int[bytes.length / 2];
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment