Skip to content
Snippets Groups Projects
Commit d762d606 authored by Vojtech Moravec's avatar Vojtech Moravec
Browse files

Add API to train and save all codebooks in one call.

By all codebooks we mean all possible codebook sizes (ours), so 4,18,16,32,64,128,256.

The training is done in single LBG trainer, so data are loaded only once.
parent 2c6d9781
Branches
No related tags found
No related merge requests found
......@@ -37,6 +37,14 @@ public interface IImageCompressor extends IListenable {
*/
void trainAndSaveCodebook() throws ImageCompressionException;
/**
* Train all codebook sizes from selected frames and save learned codebooks to cache files.
*
* @throws ImageCompressionException when training or saving of any file fails.
*/
void trainAndSaveAllCodebooks() throws ImageCompressionException;
/**
* Preload compressor codebook and Huffman tree for stream compressor from provided cache file.
*
......
......@@ -88,6 +88,21 @@ public class ImageCompressor extends CompressorDecompressorBase {
return true;
}
public boolean trainAndSaveAllCodebooks() {
reportStatusToListeners("=== Training all codebooks ===");
if (imageCompressor == null) {
return false;
}
try {
imageCompressor.trainAndSaveAllCodebooks();
} catch (final ImageCompressionException e) {
System.err.println(e.getMessage());
e.printStackTrace();
return false;
}
return true;
}
public int streamCompressChunk(final OutputStream outputStream, final InputData inputData) {
assert (imageCompressor != null);
......
......@@ -246,4 +246,9 @@ public class SQImageCompressor extends CompressorDecompressorBase implements IIm
public long[] compressStreamChunk(final DataOutputStream compressStream, final InputData inputData) throws ImageCompressionException {
throw new ImageCompressionException("Not implemented yet");
}
@Override
public void trainAndSaveAllCodebooks() throws ImageCompressionException {
throw new ImageCompressionException("Not implemented yet");
}
}
......@@ -219,53 +219,6 @@ public class VQImageCompressor extends CompressorDecompressorBase implements IIm
return planeDataSizes;
}
@Override
public void trainAndSaveCodebook() throws ImageCompressionException {
reportStatusToListeners("Loading image data...");
final IPlaneLoader planeLoader;
try {
planeLoader = PlaneLoaderFactory.getPlaneLoaderForInputFile(options.getInputDataInfo());
} catch (final Exception e) {
throw new ImageCompressionException("Unable to create plane reader. " + e.getMessage());
}
final int[][] trainingData;
if (options.getInputDataInfo().isPlaneIndexSet()) {
reportStatusToListeners("VQ: Loading single plane data.");
final int planeIndex = options.getInputDataInfo().getPlaneIndex();
trainingData = planeLoader.loadVectorsFromPlaneRange(options, new Range<>(planeIndex, planeIndex + 1));
} else if (options.getInputDataInfo().isPlaneRangeSet()) {
reportStatusToListeners("VQ: Loading plane range data.");
trainingData = planeLoader.loadVectorsFromPlaneRange(options, options.getInputDataInfo().getPlaneRange());
} else {
reportStatusToListeners("VQ: Loading all planes data.");
trainingData = planeLoader.loadVectorsFromPlaneRange(options,
new Range<>(0, options.getInputDataInfo().getDimensions().getZ()));
}
final LBGVectorQuantizer vqInitializer = new LBGVectorQuantizer(trainingData,
getCodebookSize(),
options.getWorkerCount(),
options.getQuantizationVector());
reportStatusToListeners("Starting LBG optimization.");
vqInitializer.setStatusListener(this::reportStatusToListeners);
final LBGResult lbgResult = vqInitializer.findOptimalCodebook();
reportStatusToListeners("Learned the optimal codebook.");
final QuantizationCacheManager cacheManager = new QuantizationCacheManager(options.getCodebookCacheFolder());
try {
final String cacheFilePath = cacheManager.saveCodebook(options.getInputDataInfo().getCacheFileName(), lbgResult.getCodebook());
reportStatusToListeners("Saved cache file to %s", cacheFilePath);
} catch (final IOException e) {
throw new ImageCompressionException("Unable to write VQ cache.", e);
}
reportStatusToListeners("Operation completed.");
}
/**
* Calculate the number of voxel layers needed for dataset of plane count.
*
......@@ -348,4 +301,92 @@ public class VQImageCompressor extends CompressorDecompressorBase implements IIm
return voxelLayersSizes;
}
@Override
public void trainAndSaveCodebook() throws ImageCompressionException {
reportStatusToListeners("Loading image data...");
final IPlaneLoader planeLoader;
try {
planeLoader = PlaneLoaderFactory.getPlaneLoaderForInputFile(options.getInputDataInfo());
} catch (final Exception e) {
throw new ImageCompressionException("Unable to create plane reader. " + e.getMessage());
}
final int[][] trainingData = loadDataForCodebookTraining(planeLoader);
final LBGVectorQuantizer vqInitializer = new LBGVectorQuantizer(trainingData,
getCodebookSize(),
options.getWorkerCount(),
options.getQuantizationVector());
reportStatusToListeners("Starting LBG optimization.");
vqInitializer.setStatusListener(this::reportStatusToListeners);
final LBGResult lbgResult = vqInitializer.findOptimalCodebook();
reportStatusToListeners("Learned the optimal codebook.");
final QuantizationCacheManager cacheManager = new QuantizationCacheManager(options.getCodebookCacheFolder());
try {
final String cacheFilePath = cacheManager.saveCodebook(options.getInputDataInfo().getCacheFileName(), lbgResult.getCodebook());
reportStatusToListeners("Saved cache file to %s", cacheFilePath);
} catch (final IOException e) {
throw new ImageCompressionException("Unable to write VQ cache.", e);
}
reportStatusToListeners("Operation completed.");
}
@Override
public void trainAndSaveAllCodebooks() throws ImageCompressionException {
reportStatusToListeners("trainAndSaveAllCodebooks is starting with %d workers.", options.getWorkerCount());
reportStatusToListeners("Loading image data...");
final IPlaneLoader planeLoader;
try {
planeLoader = PlaneLoaderFactory.getPlaneLoaderForInputFile(options.getInputDataInfo());
} catch (final Exception e) {
throw new ImageCompressionException("Unable to create plane reader. " + e.getMessage());
}
final int[][] trainingData = loadDataForCodebookTraining(planeLoader);
reportStatusToListeners("Data loading is finished.");
final QuantizationCacheManager qcm = new QuantizationCacheManager(options.getCodebookCacheFolder());
final LBGVectorQuantizer codebookTrainer = new LBGVectorQuantizer(trainingData,
256,
options.getWorkerCount(),
options.getQuantizationVector());
codebookTrainer.findOptimalCodebook(vqCodebook -> {
try {
assert ((vqCodebook.getCodebookSize() == vqCodebook.getVectors().length) &&
(vqCodebook.getCodebookSize() == vqCodebook.getVectorFrequencies().length))
: "Codebook size, Vector count, Frequencies count mismatch";
qcm.saveCodebook(options.getInputDataInfo().getCacheFileName(), vqCodebook);
} catch (final IOException e) {
System.err.println("Failed to save trained codebook.");
e.printStackTrace();
}
reportStatusToListeners("Optimal codebook of size %d was found.", vqCodebook.getCodebookSize());
});
reportStatusToListeners("Trained all codebooks.");
}
int[][] loadDataForCodebookTraining(final IPlaneLoader planeLoader) throws ImageCompressionException {
final int[][] trainingData;
if (options.getInputDataInfo().isPlaneIndexSet()) {
reportStatusToListeners("VQ: Loading single plane data.");
final int planeIndex = options.getInputDataInfo().getPlaneIndex();
trainingData = planeLoader.loadVectorsFromPlaneRange(options, new Range<>(planeIndex, planeIndex + 1));
} else if (options.getInputDataInfo().isPlaneRangeSet()) {
reportStatusToListeners("VQ: Loading plane range data.");
trainingData = planeLoader.loadVectorsFromPlaneRange(options, options.getInputDataInfo().getPlaneRange());
} else {
reportStatusToListeners("VQ: Loading all planes data.");
trainingData = planeLoader.loadVectorsFromPlaneRange(options,
new Range<>(0, options.getInputDataInfo().getDimensions().getZ()));
}
return trainingData;
}
}
......@@ -28,6 +28,10 @@ public class LBGVectorQuantizer {
private IStatusListener statusListener = null;
private double _mse = 0.0;
public interface CodebookFoundCallback {
void process(final VQCodebook trainedCodebook);
}
public LBGVectorQuantizer(final int[][] vectors,
final int codebookSize,
final int workerCount,
......@@ -108,19 +112,23 @@ public class LBGVectorQuantizer {
return new LBGResult(vectorDimensions, codebook, frequencies, mse, psnr);
}
public LBGResult findOptimalCodebook() {
return findOptimalCodebook(null);
}
/**
* Find the optimal codebook of vectors, used for vector quantization.
*
* @return Result of the search.
*/
public LBGResult findOptimalCodebook() {
public LBGResult findOptimalCodebook(final CodebookFoundCallback codebookCallback) {
final Stopwatch stopwatch = Stopwatch.startNew("LBG::findOptimalCodebook()");
if (uniqueVectorCount < codebookSize) {
return createCodebookFromUniqueVectors();
}
final LearningCodebookEntry[] codebook = initializeCodebook();
final LearningCodebookEntry[] codebook = initializeCodebook(codebookCallback);
reportStatus("LBG::findOptimalCodebook() - Got initial codebook. Improving it...");
LBG(codebook, EPSILON * 0.1);
......@@ -131,7 +139,14 @@ public class LBGVectorQuantizer {
psnr);
stopwatch.stop();
reportStatus(stopwatch.toString());
return new LBGResult(vectorDimensions, learningCodebookToCodebook(codebook), frequencies, finalMse, psnr);
final LBGResult result = new LBGResult(vectorDimensions, learningCodebookToCodebook(codebook), frequencies, finalMse, psnr);
if (codebookCallback != null) {
codebookCallback.process(result.getCodebook());
}
return result;
}
/**
......@@ -304,7 +319,7 @@ public class LBGVectorQuantizer {
*
* @return The initial codebook to be improved by LBG.
*/
private LearningCodebookEntry[] initializeCodebook() {
private LearningCodebookEntry[] initializeCodebook(final CodebookFoundCallback codebookFoundCallback) {
int currentCodebookSize = 1;
LearningCodebookEntry[] codebook = new LearningCodebookEntry[]{createInitialEntry()};
......@@ -384,11 +399,22 @@ public class LBGVectorQuantizer {
currentCodebookSize *= 2;
// Execute LBG Algorithm on current codebook to improve it.
LBG(codebook);
final double eps = codebookFoundCallback == null ? EPSILON : EPSILON * 0.1;
LBG(codebook, eps);
final double avgMse = averageMse(codebook);
reportStatus("MSE of improved divided codebook: %f", avgMse);
if (codebookFoundCallback != null) {
final long[] codebookFrequencies = new long[codebook.length];
System.arraycopy(frequencies, 0, codebookFrequencies, 0, codebook.length);
codebookFoundCallback.process(new VQCodebook(vectorDimensions,
learningCodebookToCodebook(codebook),
codebookFrequencies));
}
}
return codebook;
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment