From d762d60605fe4e0ba81fec3e8fc0676a583262c8 Mon Sep 17 00:00:00 2001 From: Vojtech Moravec <vojtech.moravec.st@vsb.cz> Date: Tue, 13 Oct 2020 14:23:33 +0200 Subject: [PATCH] Add API to train and save all codebooks in one call. By all codebooks we mean all possible codebook sizes (ours), so 4,18,16,32,64,128,256. The training is done in single LBG trainer, so data are loaded only once. --- .../qcmp/compression/IImageCompressor.java | 8 ++ .../qcmp/compression/ImageCompressor.java | 15 ++ .../qcmp/compression/SQImageCompressor.java | 5 + .../qcmp/compression/VQImageCompressor.java | 135 ++++++++++++------ .../vector/LBGVectorQuantizer.java | 36 ++++- 5 files changed, 147 insertions(+), 52 deletions(-) diff --git a/src/main/java/cz/it4i/qcmp/compression/IImageCompressor.java b/src/main/java/cz/it4i/qcmp/compression/IImageCompressor.java index 8d8ae98..5b4277f 100644 --- a/src/main/java/cz/it4i/qcmp/compression/IImageCompressor.java +++ b/src/main/java/cz/it4i/qcmp/compression/IImageCompressor.java @@ -37,6 +37,14 @@ public interface IImageCompressor extends IListenable { */ void trainAndSaveCodebook() throws ImageCompressionException; + /** + * Train all codebook sizes from selected frames and save learned codebooks to cache files. + * + * @throws ImageCompressionException when training or saving of any file fails. + */ + void trainAndSaveAllCodebooks() throws ImageCompressionException; + + /** * Preload compressor codebook and Huffman tree for stream compressor from provided cache file. * diff --git a/src/main/java/cz/it4i/qcmp/compression/ImageCompressor.java b/src/main/java/cz/it4i/qcmp/compression/ImageCompressor.java index 23f5792..ac87edb 100644 --- a/src/main/java/cz/it4i/qcmp/compression/ImageCompressor.java +++ b/src/main/java/cz/it4i/qcmp/compression/ImageCompressor.java @@ -88,6 +88,21 @@ public class ImageCompressor extends CompressorDecompressorBase { return true; } + public boolean trainAndSaveAllCodebooks() { + reportStatusToListeners("=== Training all codebooks ==="); + if (imageCompressor == null) { + return false; + } + try { + imageCompressor.trainAndSaveAllCodebooks(); + } catch (final ImageCompressionException e) { + System.err.println(e.getMessage()); + e.printStackTrace(); + return false; + } + return true; + } + public int streamCompressChunk(final OutputStream outputStream, final InputData inputData) { assert (imageCompressor != null); diff --git a/src/main/java/cz/it4i/qcmp/compression/SQImageCompressor.java b/src/main/java/cz/it4i/qcmp/compression/SQImageCompressor.java index 22a8dd9..381c250 100644 --- a/src/main/java/cz/it4i/qcmp/compression/SQImageCompressor.java +++ b/src/main/java/cz/it4i/qcmp/compression/SQImageCompressor.java @@ -246,4 +246,9 @@ public class SQImageCompressor extends CompressorDecompressorBase implements IIm public long[] compressStreamChunk(final DataOutputStream compressStream, final InputData inputData) throws ImageCompressionException { throw new ImageCompressionException("Not implemented yet"); } + + @Override + public void trainAndSaveAllCodebooks() throws ImageCompressionException { + throw new ImageCompressionException("Not implemented yet"); + } } diff --git a/src/main/java/cz/it4i/qcmp/compression/VQImageCompressor.java b/src/main/java/cz/it4i/qcmp/compression/VQImageCompressor.java index 2165311..ae13070 100644 --- a/src/main/java/cz/it4i/qcmp/compression/VQImageCompressor.java +++ b/src/main/java/cz/it4i/qcmp/compression/VQImageCompressor.java @@ -219,53 +219,6 @@ public class VQImageCompressor extends CompressorDecompressorBase implements IIm return planeDataSizes; } - @Override - public void trainAndSaveCodebook() throws ImageCompressionException { - reportStatusToListeners("Loading image data..."); - - final IPlaneLoader planeLoader; - try { - planeLoader = PlaneLoaderFactory.getPlaneLoaderForInputFile(options.getInputDataInfo()); - } catch (final Exception e) { - throw new ImageCompressionException("Unable to create plane reader. " + e.getMessage()); - } - - final int[][] trainingData; - if (options.getInputDataInfo().isPlaneIndexSet()) { - reportStatusToListeners("VQ: Loading single plane data."); - final int planeIndex = options.getInputDataInfo().getPlaneIndex(); - trainingData = planeLoader.loadVectorsFromPlaneRange(options, new Range<>(planeIndex, planeIndex + 1)); - } else if (options.getInputDataInfo().isPlaneRangeSet()) { - reportStatusToListeners("VQ: Loading plane range data."); - trainingData = planeLoader.loadVectorsFromPlaneRange(options, options.getInputDataInfo().getPlaneRange()); - } else { - reportStatusToListeners("VQ: Loading all planes data."); - trainingData = planeLoader.loadVectorsFromPlaneRange(options, - new Range<>(0, options.getInputDataInfo().getDimensions().getZ())); - } - - - final LBGVectorQuantizer vqInitializer = new LBGVectorQuantizer(trainingData, - getCodebookSize(), - options.getWorkerCount(), - options.getQuantizationVector()); - - reportStatusToListeners("Starting LBG optimization."); - vqInitializer.setStatusListener(this::reportStatusToListeners); - final LBGResult lbgResult = vqInitializer.findOptimalCodebook(); - reportStatusToListeners("Learned the optimal codebook."); - - - final QuantizationCacheManager cacheManager = new QuantizationCacheManager(options.getCodebookCacheFolder()); - try { - final String cacheFilePath = cacheManager.saveCodebook(options.getInputDataInfo().getCacheFileName(), lbgResult.getCodebook()); - reportStatusToListeners("Saved cache file to %s", cacheFilePath); - } catch (final IOException e) { - throw new ImageCompressionException("Unable to write VQ cache.", e); - } - reportStatusToListeners("Operation completed."); - } - /** * Calculate the number of voxel layers needed for dataset of plane count. * @@ -348,4 +301,92 @@ public class VQImageCompressor extends CompressorDecompressorBase implements IIm return voxelLayersSizes; } + @Override + public void trainAndSaveCodebook() throws ImageCompressionException { + reportStatusToListeners("Loading image data..."); + + final IPlaneLoader planeLoader; + try { + planeLoader = PlaneLoaderFactory.getPlaneLoaderForInputFile(options.getInputDataInfo()); + } catch (final Exception e) { + throw new ImageCompressionException("Unable to create plane reader. " + e.getMessage()); + } + + final int[][] trainingData = loadDataForCodebookTraining(planeLoader); + + + final LBGVectorQuantizer vqInitializer = new LBGVectorQuantizer(trainingData, + getCodebookSize(), + options.getWorkerCount(), + options.getQuantizationVector()); + + reportStatusToListeners("Starting LBG optimization."); + vqInitializer.setStatusListener(this::reportStatusToListeners); + final LBGResult lbgResult = vqInitializer.findOptimalCodebook(); + reportStatusToListeners("Learned the optimal codebook."); + + + final QuantizationCacheManager cacheManager = new QuantizationCacheManager(options.getCodebookCacheFolder()); + try { + final String cacheFilePath = cacheManager.saveCodebook(options.getInputDataInfo().getCacheFileName(), lbgResult.getCodebook()); + reportStatusToListeners("Saved cache file to %s", cacheFilePath); + } catch (final IOException e) { + throw new ImageCompressionException("Unable to write VQ cache.", e); + } + reportStatusToListeners("Operation completed."); + } + + @Override + public void trainAndSaveAllCodebooks() throws ImageCompressionException { + reportStatusToListeners("trainAndSaveAllCodebooks is starting with %d workers.", options.getWorkerCount()); + + reportStatusToListeners("Loading image data..."); + final IPlaneLoader planeLoader; + try { + planeLoader = PlaneLoaderFactory.getPlaneLoaderForInputFile(options.getInputDataInfo()); + } catch (final Exception e) { + throw new ImageCompressionException("Unable to create plane reader. " + e.getMessage()); + } + final int[][] trainingData = loadDataForCodebookTraining(planeLoader); + reportStatusToListeners("Data loading is finished."); + + final QuantizationCacheManager qcm = new QuantizationCacheManager(options.getCodebookCacheFolder()); + + final LBGVectorQuantizer codebookTrainer = new LBGVectorQuantizer(trainingData, + 256, + options.getWorkerCount(), + options.getQuantizationVector()); + codebookTrainer.findOptimalCodebook(vqCodebook -> { + try { + assert ((vqCodebook.getCodebookSize() == vqCodebook.getVectors().length) && + (vqCodebook.getCodebookSize() == vqCodebook.getVectorFrequencies().length)) + : "Codebook size, Vector count, Frequencies count mismatch"; + qcm.saveCodebook(options.getInputDataInfo().getCacheFileName(), vqCodebook); + } catch (final IOException e) { + System.err.println("Failed to save trained codebook."); + e.printStackTrace(); + } + reportStatusToListeners("Optimal codebook of size %d was found.", vqCodebook.getCodebookSize()); + }); + + reportStatusToListeners("Trained all codebooks."); + } + + int[][] loadDataForCodebookTraining(final IPlaneLoader planeLoader) throws ImageCompressionException { + final int[][] trainingData; + if (options.getInputDataInfo().isPlaneIndexSet()) { + reportStatusToListeners("VQ: Loading single plane data."); + final int planeIndex = options.getInputDataInfo().getPlaneIndex(); + trainingData = planeLoader.loadVectorsFromPlaneRange(options, new Range<>(planeIndex, planeIndex + 1)); + } else if (options.getInputDataInfo().isPlaneRangeSet()) { + reportStatusToListeners("VQ: Loading plane range data."); + trainingData = planeLoader.loadVectorsFromPlaneRange(options, options.getInputDataInfo().getPlaneRange()); + } else { + reportStatusToListeners("VQ: Loading all planes data."); + trainingData = planeLoader.loadVectorsFromPlaneRange(options, + new Range<>(0, options.getInputDataInfo().getDimensions().getZ())); + } + return trainingData; + } + } diff --git a/src/main/java/cz/it4i/qcmp/quantization/vector/LBGVectorQuantizer.java b/src/main/java/cz/it4i/qcmp/quantization/vector/LBGVectorQuantizer.java index 6577004..0245ca6 100644 --- a/src/main/java/cz/it4i/qcmp/quantization/vector/LBGVectorQuantizer.java +++ b/src/main/java/cz/it4i/qcmp/quantization/vector/LBGVectorQuantizer.java @@ -28,6 +28,10 @@ public class LBGVectorQuantizer { private IStatusListener statusListener = null; private double _mse = 0.0; + public interface CodebookFoundCallback { + void process(final VQCodebook trainedCodebook); + } + public LBGVectorQuantizer(final int[][] vectors, final int codebookSize, final int workerCount, @@ -108,19 +112,23 @@ public class LBGVectorQuantizer { return new LBGResult(vectorDimensions, codebook, frequencies, mse, psnr); } + public LBGResult findOptimalCodebook() { + return findOptimalCodebook(null); + } + /** * Find the optimal codebook of vectors, used for vector quantization. * * @return Result of the search. */ - public LBGResult findOptimalCodebook() { + public LBGResult findOptimalCodebook(final CodebookFoundCallback codebookCallback) { final Stopwatch stopwatch = Stopwatch.startNew("LBG::findOptimalCodebook()"); if (uniqueVectorCount < codebookSize) { return createCodebookFromUniqueVectors(); } - final LearningCodebookEntry[] codebook = initializeCodebook(); + final LearningCodebookEntry[] codebook = initializeCodebook(codebookCallback); reportStatus("LBG::findOptimalCodebook() - Got initial codebook. Improving it..."); LBG(codebook, EPSILON * 0.1); @@ -131,7 +139,14 @@ public class LBGVectorQuantizer { psnr); stopwatch.stop(); reportStatus(stopwatch.toString()); - return new LBGResult(vectorDimensions, learningCodebookToCodebook(codebook), frequencies, finalMse, psnr); + + final LBGResult result = new LBGResult(vectorDimensions, learningCodebookToCodebook(codebook), frequencies, finalMse, psnr); + + if (codebookCallback != null) { + codebookCallback.process(result.getCodebook()); + } + + return result; } /** @@ -304,7 +319,7 @@ public class LBGVectorQuantizer { * * @return The initial codebook to be improved by LBG. */ - private LearningCodebookEntry[] initializeCodebook() { + private LearningCodebookEntry[] initializeCodebook(final CodebookFoundCallback codebookFoundCallback) { int currentCodebookSize = 1; LearningCodebookEntry[] codebook = new LearningCodebookEntry[]{createInitialEntry()}; @@ -384,11 +399,22 @@ public class LBGVectorQuantizer { currentCodebookSize *= 2; // Execute LBG Algorithm on current codebook to improve it. - LBG(codebook); + final double eps = codebookFoundCallback == null ? EPSILON : EPSILON * 0.1; + LBG(codebook, eps); final double avgMse = averageMse(codebook); reportStatus("MSE of improved divided codebook: %f", avgMse); + + if (codebookFoundCallback != null) { + + final long[] codebookFrequencies = new long[codebook.length]; + System.arraycopy(frequencies, 0, codebookFrequencies, 0, codebook.length); + + codebookFoundCallback.process(new VQCodebook(vectorDimensions, + learningCodebookToCodebook(codebook), + codebookFrequencies)); + } } return codebook; } -- GitLab