Skip to content
Snippets Groups Projects
  • Vojtech Moravec's avatar
    825d7a03
    Experimental rewrite of LloydMax algorithm. · 825d7a03
    Vojtech Moravec authored
    - Calculate quantization values in range [min, max], where min and max correspond to the minimal and maximal values found in the data.
    
    - Initialize centroids to the value with the highest pdf.
    - Changed the terminal condition of the main loop, from distortion to the minimal MSE improvement.
    825d7a03
    History
    Experimental rewrite of LloydMax algorithm.
    Vojtech Moravec authored
    - Calculate quantization values in range [min, max], where min and max correspond to the minimal and maximal values found in the data.
    
    - Initialize centroids to the value with the highest pdf.
    - Changed the terminal condition of the main loop, from distortion to the minimal MSE improvement.
LloydMaxU16ScalarQuantization.java 9.38 KiB
package azgracompress.quantization.scalar;

import azgracompress.U16;
import azgracompress.quantization.QTrainIteration;
import azgracompress.utilities.MinMaxResult;
import azgracompress.utilities.Stopwatch;
import azgracompress.utilities.Utils;

import java.util.ArrayList;


/**
 * Scalar quantizer trained with an iterative Lloyd-Max style procedure.
 *
 * <p>Training ({@link #train(boolean)}) builds a histogram of the training values, seeds
 * {@code codebookSize} centroids inside the range {@code [dataMin, dataMax]} found in the training
 * data and then alternates between recomputing the decision boundaries (midpoints of neighbouring
 * centroids) and the centroids (histogram-weighted mean of each cell) until the MSE stops
 * improving by at least {@link #MIN_MSE_IMPROVEMENT}.
 *
 * <p>Training values are used directly as indices into an array of {@code U16.Max + 1} elements,
 * so every value must lie in {@code [0, U16.Max]}.
 *
 * <p>{@link #quantize(int)} and {@link #getCentroids()} are only meaningful after
 * {@link #train(boolean)} has been called; the codebook arrays are allocated there.
 */
public class LloydMaxU16ScalarQuantization {
    /** Number of boundary/centroid refinement passes performed in one training iteration. */
    private static final int RECALCULATE_N_TIMES = 10;

    /** Number of low-improvement iterations after which the training loop stops. */
    private static final int PATIENCE = 1;

    /** An iteration whose MSE decrease is below this value is counted as "no improvement". */
    private static final double MIN_MSE_IMPROVEMENT = 1.0;

    private final int[] trainingData;
    private final int codebookSize;

    private int dataMin;
    private int dataMax;
    private int dataSpan;
    private int[] centroids;
    private int[] boundaryPoints;
    /** Occurrence count of every value in the training data (a histogram; it is not normalized). */
    private double[] pdf;

    private final int workerCount;

    private boolean verbose = false;

    /**
     * Creates a quantizer for the given training data.
     *
     * @param trainData    training values; the array is stored by reference, not copied.
     * @param codebookSize number of centroids (quantization levels) to train.
     * @param workerCount  number of threads used to compute the MSE; values greater than one
     *                     enable the multi-threaded path.
     */
    public LloydMaxU16ScalarQuantization(final int[] trainData, final int codebookSize, final int workerCount) {
        trainingData = trainData;
        this.codebookSize = codebookSize;
        this.workerCount = workerCount;
    }

    /**
     * Creates a quantizer which computes the MSE on the calling thread (one worker).
     *
     * @param trainData    training values; the array is stored by reference, not copied.
     * @param codebookSize number of centroids (quantization levels) to train.
     */
    public LloydMaxU16ScalarQuantization(final int[] trainData, final int codebookSize) {
        this(trainData, codebookSize, 1);
    }

    /**
     * Allocates the codebook arrays, finds the range of the training data and spreads the initial
     * centroids uniformly over that range.
     */
    private void initialize() {
        centroids = new int[codebookSize];
        boundaryPoints = new int[codebookSize + 1];

        final MinMaxResult<Integer> minMax = Utils.getMinAndMax(trainingData);
        dataMin = minMax.getMin();
        dataMax = minMax.getMax();
        dataSpan = dataMax - dataMin;

        // The outermost boundaries are fixed to the data range and are never recalculated.
        boundaryPoints[0] = dataMin;
        boundaryPoints[codebookSize] = dataMax;

        final double intervalSize = (double) dataSpan / (double) codebookSize;
        for (int i = 0; i < codebookSize; i++) {
            // Centre of the i-th uniform interval. The dataMin offset keeps every centroid inside
            // [dataMin, dataMax]; without it the inner boundaries derived from the centroids could
            // end up below boundaryPoints[0] and most cells would be empty.
            centroids[i] = dataMin + (int) Math.floor(((double) i + 0.5) * intervalSize);
        }
    }

    /**
     * Builds the histogram of the training data. Each training value is used as an index into
     * {@link #pdf}, so it must be in {@code [0, U16.Max]}.
     */
    private void initializeProbabilityDensityFunction() {
        pdf = new double[U16.Max + 1];
        // Speedup - for now it is fast enough
        final Stopwatch stopwatch = new Stopwatch();
        stopwatch.start();

        for (final int value : trainingData) {
            pdf[value] += 1.0;
        }

        stopwatch.stop();
        if (verbose) {
            System.out.println("Init_PDF: " + stopwatch.getElapsedTimeString());
        }
    }

    /**
     * Sets every inner boundary to the floored midpoint of its two neighbouring centroids,
     * clamped from above by {@code dataMax}. The first and the last boundary are left untouched.
     */
    private void recalculateBoundaryPoints() {
        for (int j = 1; j < codebookSize; j++) {
            final double midpoint = ((double) centroids[j] + (double) centroids[j - 1]) / 2.0;
            boundaryPoints[j] = Math.min(dataMax, (int) Math.floor(midpoint));
        }
    }

    /**
     * Moves every centroid to the most frequent value of its cell. On ties the smallest value
     * wins; a centroid whose cell contains no training values keeps its current position.
     */
    private void initializeCentroids() {
        for (int centroidIndex = 0; centroidIndex < codebookSize; centroidIndex++) {
            final int lowerBound = boundaryPoints[centroidIndex];
            final int upperBound = boundaryPoints[centroidIndex + 1];

            double highestPdf = 0.0;
            for (int rangeValue = lowerBound; rangeValue <= upperBound; rangeValue++) {
                if (pdf[rangeValue] > highestPdf) {
                    highestPdf = pdf[rangeValue];
                    centroids[centroidIndex] = rangeValue;
                }
            }
        }
    }

    /**
     * Moves every centroid to the floored histogram-weighted mean of its cell. A centroid whose
     * cell contains no training values keeps its current position.
     *
     * <p>Both ends of a cell are inclusive, so a value equal to an inner boundary contributes to
     * both neighbouring cells.
     */
    private void recalculateCentroids() {
        for (int centroidIndex = 0; centroidIndex < codebookSize; centroidIndex++) {
            final int lowerBound = boundaryPoints[centroidIndex];
            final int upperBound = boundaryPoints[centroidIndex + 1];

            double numerator = 0.0;
            double denominator = 0.0;
            for (int n = lowerBound; n <= upperBound; n++) {
                numerator += (double) n * pdf[n];
                denominator += pdf[n];
            }

            if (denominator > 0) {
                centroids[centroidIndex] = (int) Math.floor(numerator / denominator);
            }
        }
    }

    /**
     * Maps a value to the centroid of the first cell which contains it. Cells are tested in
     * ascending order with both ends inclusive, so a value equal to an inner boundary is assigned
     * to the lower cell.
     *
     * @param value value to quantize.
     * @return centroid of the cell containing {@code value}.
     * @throws RuntimeException when no cell contains {@code value}.
     */
    public int quantize(final int value) {
        for (int intervalId = 1; intervalId <= codebookSize; intervalId++) {
            if ((value >= boundaryPoints[intervalId - 1]) && (value <= boundaryPoints[intervalId])) {
                return centroids[intervalId - 1];
            }
        }
        throw new RuntimeException("Value couldn't be quantized!");
    }

    /**
     * @return mean absolute error between the training data and its quantized form.
     */
    private double calculateMAE() {
        double absoluteErrorSum = 0.0;
        for (final int trainingDatum : trainingData) {
            final int quantizedValue = quantize(trainingDatum);
            absoluteErrorSum += Math.abs((double) trainingDatum - (double) quantizedValue);
        }
        return (absoluteErrorSum / (double) trainingData.length);
    }

    /**
     * Computes the mean squared error between the training data and its quantized form. With more
     * than one worker the training data is split into contiguous slices, one thread per slice;
     * the last slice also takes the remainder of the division.
     *
     * @return mean squared error over the whole training data.
     * @throws RuntimeException when the calling thread is interrupted while waiting for the
     *                          workers; the interrupt flag is restored before throwing.
     */
    private double getCurrentMse() {
        double mse = 0.0;

        final Stopwatch stopwatch = new Stopwatch();
        stopwatch.start();
        if (workerCount > 1) {
            final int workSize = trainingData.length / workerCount;

            final RunnableLloydMseCalc[] runnables = new RunnableLloydMseCalc[workerCount];
            final Thread[] workers = new Thread[workerCount];
            for (int wId = 0; wId < workerCount; wId++) {
                final int fromIndex = wId * workSize;
                final int toIndex = (wId == workerCount - 1) ? trainingData.length : (fromIndex + workSize);

                runnables[wId] = new RunnableLloydMseCalc(trainingData,
                                                          fromIndex,
                                                          toIndex,
                                                          centroids,
                                                          boundaryPoints,
                                                          codebookSize);
                workers[wId] = new Thread(runnables[wId]);
                workers[wId].start();
            }
            try {
                for (int wId = 0; wId < workerCount; wId++) {
                    workers[wId].join();
                    mse += runnables[wId].getMse();
                }
            } catch (InterruptedException e) {
                // A partially summed MSE would silently corrupt the training decision, so the
                // interruption is propagated instead of being swallowed.
                Thread.currentThread().interrupt();
                throw new RuntimeException("Interrupted while waiting for the MSE workers.", e);
            }
        } else {
            for (final int trainingDatum : trainingData) {
                final double error = (double) trainingDatum - (double) quantize(trainingDatum);
                mse += error * error;
            }
        }
        stopwatch.stop();
        if (verbose) {
            System.out.println("\nLloydMax: getCurrentMse time: " + stopwatch.getElapsedTimeString());
        }

        mse /= (double) trainingData.length;

        return mse;
    }

    /**
     * Trains the codebook on the training data.
     *
     * <p>Every iteration performs {@link #RECALCULATE_N_TIMES} boundary/centroid refinement passes
     * and then measures the MAE and MSE. The loop stops once {@link #PATIENCE} iterations in total
     * improved the MSE by less than {@link #MIN_MSE_IMPROVEMENT}; the counter is not reset by a
     * later improvement.
     *
     * @param shouldBeVerbose when true, progress information is printed to standard output. The
     *                        final MAE summary line is printed regardless of this flag.
     * @return one record for the initial state followed by one record per training iteration.
     */
    public QTrainIteration[] train(final boolean shouldBeVerbose) {
        this.verbose = shouldBeVerbose;
        if (verbose) {
            System.out.println("Training data count: " + trainingData.length);
        }

        initialize();
        initializeProbabilityDensityFunction();

        final ArrayList<QTrainIteration> solutionHistory = new ArrayList<>();

        recalculateBoundaryPoints();
        initializeCentroids();

        double currentMse = getCurrentMse();
        double psnr = Utils.calculatePsnr(currentMse, U16.Max);

        if (verbose) {
            System.out.println(String.format("Initial MSE: %f", currentMse));
        }

        solutionHistory.add(new QTrainIteration(0, currentMse, currentMse, psnr, psnr));

        double currMAE;
        int iteration = 0;
        int noImprovementCounter = 0;
        while (true) {
            for (int i = 0; i < RECALCULATE_N_TIMES; i++) {
                recalculateBoundaryPoints();
                recalculateCentroids();
            }

            // TODO(Moravec):   Check if we are improving MSE.
            //                  Save the best centroids, the lowest MSE.

            currMAE = calculateMAE();

            final double prevMse = currentMse;
            currentMse = getCurrentMse();
            final double mseImprovement = prevMse - currentMse;

            psnr = Utils.calculatePsnr(currentMse, U16.Max);
            solutionHistory.add(new QTrainIteration(++iteration, currentMse, currentMse, psnr, psnr));

            if (verbose) {
                System.out.println(String.format("Current MAE: %.4f MSE: %.4f PSNR: %.4f dB",
                                                 currMAE,
                                                 currentMse,
                                                 psnr));
            }

            // A negative improvement (the MSE got worse) also counts as "no improvement".
            if (mseImprovement < MIN_MSE_IMPROVEMENT) {
                if ((++noImprovementCounter) >= PATIENCE) {
                    break;
                }
            }
        }

        if (verbose) {
            System.out.println("\nFinished training.");
        }
        System.out.println(String.format("Final MAE: %.4f after %d iterations", currMAE, iteration));
        return solutionHistory.toArray(new QTrainIteration[0]);
    }

    /**
     * @return the internal centroid array (not a copy); {@code null} until
     * {@link #train(boolean)} has been called.
     */
    public int[] getCentroids() {
        return centroids;
    }
}