Skip to content
Snippets Groups Projects
Commit 980b7cac authored by Vojtech Moravec's avatar Vojtech Moravec
Browse files

Merge commit from huffman branch

parent 0dc632a6
Branches
No related tags found
No related merge requests found
Showing
with 114 additions and 1177 deletions
...@@ -51,6 +51,4 @@ ...@@ -51,6 +51,4 @@
<version>2.8.6</version> <version>2.8.6</version>
</dependency> </dependency>
</dependencies> </dependencies>
</project> </project>
\ No newline at end of file
...@@ -2,9 +2,8 @@ package azgracompress.benchmark; ...@@ -2,9 +2,8 @@ package azgracompress.benchmark;
import azgracompress.U16; import azgracompress.U16;
import azgracompress.cli.ParsedCliOptions; import azgracompress.cli.ParsedCliOptions;
import azgracompress.data.V3i; import azgracompress.io.IPlaneLoader;
import azgracompress.de.DeException; import azgracompress.io.PlaneLoaderFactory;
import azgracompress.de.shade.ILShadeSolver;
import azgracompress.quantization.QTrainIteration; import azgracompress.quantization.QTrainIteration;
import azgracompress.quantization.QuantizationValueCache; import azgracompress.quantization.QuantizationValueCache;
import azgracompress.quantization.scalar.LloydMaxU16ScalarQuantization; import azgracompress.quantization.scalar.LloydMaxU16ScalarQuantization;
...@@ -43,7 +42,8 @@ public class ScalarQuantizationBenchmark extends BenchmarkBase { ...@@ -43,7 +42,8 @@ public class ScalarQuantizationBenchmark extends BenchmarkBase {
QuantizationValueCache cache = new QuantizationValueCache(cacheFolder); QuantizationValueCache cache = new QuantizationValueCache(cacheFolder);
try { try {
final int[] quantizationValues = cache.readCachedValues(inputFile, codebookSize); final int[] quantizationValues = cache.readCachedValues(inputFile, codebookSize);
quantizer = new ScalarQuantizer(U16.Min, U16.Max, quantizationValues); // TODO(Moravec): FIXME!
quantizer = null;//new ScalarQuantizer(U16.Min, U16.Max, quantizationValues);
} catch (IOException e) { } catch (IOException e) {
System.err.println("Failed to read quantization values from cache file."); System.err.println("Failed to read quantization values from cache file.");
e.printStackTrace(); e.printStackTrace();
...@@ -57,7 +57,8 @@ public class ScalarQuantizationBenchmark extends BenchmarkBase { ...@@ -57,7 +57,8 @@ public class ScalarQuantizationBenchmark extends BenchmarkBase {
return; return;
} }
if (useDiffEvolution) { if (useDiffEvolution) {
quantizer = trainDifferentialEvolution(refPlaneData, codebookSize); assert (false) : "DE is depracated";
quantizer = null;
} else { } else {
quantizer = trainLloydMaxQuantizer(refPlaneData, codebookSize); quantizer = trainLloydMaxQuantizer(refPlaneData, codebookSize);
} }
...@@ -75,8 +76,10 @@ public class ScalarQuantizationBenchmark extends BenchmarkBase { ...@@ -75,8 +76,10 @@ public class ScalarQuantizationBenchmark extends BenchmarkBase {
if (!hasGeneralQuantizer) { if (!hasGeneralQuantizer) {
if (useDiffEvolution) { if (useDiffEvolution) {
quantizer = trainDifferentialEvolution(planeData, codebookSize); assert (false) : "DE is depracated";
quantizer = null;
} else { } else {
quantizer = trainLloydMaxQuantizer(planeData, codebookSize); quantizer = trainLloydMaxQuantizer(planeData, codebookSize);
} }
...@@ -136,22 +139,23 @@ public class ScalarQuantizationBenchmark extends BenchmarkBase { ...@@ -136,22 +139,23 @@ public class ScalarQuantizationBenchmark extends BenchmarkBase {
//saveQTrainLog(String.format("p%d_cb_%d_lloyd.csv", planeIndex, codebookSize), trainingReport); //saveQTrainLog(String.format("p%d_cb_%d_lloyd.csv", planeIndex, codebookSize), trainingReport);
return new ScalarQuantizer(U16.Min, U16.Max, lloydMax.getCentroids()); // TODO(Moravec): FIXME
} return new ScalarQuantizer(U16.Min, U16.Max, null);//lloydMax.getCentroids());
}
private ScalarQuantizer trainDifferentialEvolution(final int[] data,
final int codebookSize) { // private ScalarQuantizer trainDifferentialEvolution(final int[] data,
ILShadeSolver ilshade = new ILShadeSolver(codebookSize, 100, 2000, 15); // final int codebookSize) {
ilshade.setTrainingData(data); // ILShadeSolver ilshade = new ILShadeSolver(codebookSize, 100, 2000, 15);
// ilshade.setTrainingData(data);
try { //
ilshade.train(); // try {
} catch (DeException deEx) { // ilshade.train();
deEx.printStackTrace(); // } catch (DeException deEx) {
return null; // deEx.printStackTrace();
} // return null;
return new ScalarQuantizer(U16.Min, U16.Max, ilshade.getBestSolution().getAttributes()); // }
} // return new ScalarQuantizer(U16.Min, U16.Max, ilshade.getBestSolution().getAttributes());
// }
public boolean isUseDiffEvolution() { public boolean isUseDiffEvolution() {
......
...@@ -31,7 +31,7 @@ public class SQImageCompressor extends CompressorDecompressorBase implements IIm ...@@ -31,7 +31,7 @@ public class SQImageCompressor extends CompressorDecompressorBase implements IIm
codebookSize, codebookSize,
options.getWorkerCount()); options.getWorkerCount());
lloydMax.train(false); lloydMax.train(false);
return new ScalarQuantizer(U16.Min, U16.Max, lloydMax.getCentroids()); return new ScalarQuantizer(U16.Min, U16.Max, lloydMax.getCodebook());
} }
/** /**
...@@ -65,8 +65,10 @@ public class SQImageCompressor extends CompressorDecompressorBase implements IIm ...@@ -65,8 +65,10 @@ public class SQImageCompressor extends CompressorDecompressorBase implements IIm
private ScalarQuantizer loadQuantizerFromCache() throws ImageCompressionException { private ScalarQuantizer loadQuantizerFromCache() throws ImageCompressionException {
QuantizationValueCache cache = new QuantizationValueCache(options.getCodebookCacheFolder()); QuantizationValueCache cache = new QuantizationValueCache(options.getCodebookCacheFolder());
try { try {
final int[] quantizationValues = cache.readCachedValues(options.getInputFile(), codebookSize); final int[] quantizationValues = cache.readCachedValues(options.getInputFileInfo().getFilePath(),
return new ScalarQuantizer(U16.Min, U16.Max, quantizationValues); codebookSize);
// TODO(Moravec): FIXME the null value.
return new ScalarQuantizer(U16.Min, U16.Max, null);
} catch (IOException e) { } catch (IOException e) {
throw new ImageCompressionException("Failed to read quantization values from cache file.", e); throw new ImageCompressionException("Failed to read quantization values from cache file.", e);
} }
......
package azgracompress.de;
import org.apache.commons.math3.distribution.UniformRealDistribution;
import java.util.Arrays;
public class DEIndividual implements IDEIndividual, Comparable<DEIndividual> {
protected Double m_fitness = null;
protected double m_mutationFactor;
protected double m_crossoverProbability;
protected int[] m_attributes;
protected DEIndividual(final int dimensionCount) {
m_attributes = new int[dimensionCount];
}
protected DEIndividual(final int[] attributes) {
m_attributes = attributes;
}
@Override
public double getFitness() {
assert (m_fitness != null);
return m_fitness;
}
@Override
public void setFitness(double fitness) {
m_fitness = fitness;
}
@Override
public boolean isFitnessCached() {
return (m_fitness != null);
}
@Override
public int compareTo(final DEIndividual other) {
return Double.compare(m_fitness, other.m_fitness);
}
@Override
public double getCrossoverProbability() {
return m_crossoverProbability;
}
@Override
public void setCrossoverProbability(final double cr) {
m_crossoverProbability = cr;
}
@Override
public double getMutationFactor() {
return m_mutationFactor;
}
@Override
public void setMutationFactor(final double f) {
m_mutationFactor = f;
}
@Override
public int[] getAttributes() {
return m_attributes;
}
@Override
public int getAttribute(final int dimension) {
return m_attributes[dimension];
}
@Override
public int hashCode() {
return Arrays.hashCode(m_attributes);
}
@Override
public DEIndividual createOffspringBinominalCrossover(final int[] mutationVector, final int jRand,
UniformRealDistribution rndCrDist) {
assert (m_attributes.length == mutationVector.length);
DEIndividual offspring = new DEIndividual(m_attributes.length);
for (int j = 0; j < m_attributes.length; j++) {
double crRnd = rndCrDist.sample();
if ((j == jRand) || (crRnd < m_crossoverProbability)) {
offspring.m_attributes[j] = mutationVector[j];
} else {
offspring.m_attributes[j] = m_attributes[j];
}
}
return offspring;
}
@Override
public boolean equals(Object obj) {
if (obj instanceof DEIndividual) {
return (this.hashCode() == ((DEIndividual) obj).hashCode());
} else {
return super.equals(obj);
}
}
}
package azgracompress.de;
import azgracompress.U16;
import azgracompress.de.jade.RunnablePopulationFitness;
import azgracompress.utilities.Utils;
import org.apache.commons.math3.distribution.CauchyDistribution;
import org.apache.commons.math3.distribution.NormalDistribution;
import org.apache.commons.math3.distribution.UniformIntegerDistribution;
import org.apache.commons.math3.random.MersenneTwister;
import java.util.Arrays;
public abstract class DESolver implements IDESolver {
public static final int MINIMAL_POPULATION_SIZE = 5;
protected int minConstraint = U16.Min;
protected int maxConstraint = U16.Max;
protected int populationSize;
protected int generationCount;
protected int dimensionCount;
protected int[] trainingData;
protected int threadCount;
protected int currentPopulationSize;
protected DEIndividual bestSolution = null;
protected DEIndividual[] currentPopulation;
protected DEIndividual[] currentPopulationSorted;
public DESolver(final int dimension, final int populationSize, final int generationCount) {
//threadCount = Runtime.getRuntime().availableProcessors() - 1;
// NOTE(Moravec): Let's go with 4 threads for now.
threadCount = 4;
//assert (threadCount > 1);
this.dimensionCount = dimension;
this.populationSize = populationSize;
this.generationCount = generationCount;
this.currentPopulationSize = populationSize;
}
/**
* Generate individual attributes, so that all are unique.
*
* @param distribution Uniform integer distribution with constraints.
* @return Array of unique attributes.
*/
private int[] generateIndividualAttribues(UniformIntegerDistribution distribution) {
int[] attributes = new int[dimensionCount];
// NOTE(Moravec): We are cheting here, when we set the first attribute to be zero, because we know that this is the best value there is.
attributes[0] = minConstraint;
for (int dim = 1; dim < dimensionCount; dim++) {
int rndValue = distribution.sample();
while (Utils.arrayContainsToIndex(attributes, dim, rndValue)) {
rndValue = distribution.sample();
}
attributes[dim] = rndValue;
}
Arrays.sort(attributes);
return attributes;
}
/**
* Generate initial population based on population size and constraints. Also allocates archive.
*
* @throws DeException Throws exception on wrong population size or wrong dimension.
*/
protected void generateInitialPopulation() throws DeException {
assertPopulationSize();
assertDimension();
currentPopulation = new DEIndividual[populationSize];
UniformIntegerDistribution uniformIntDistribution = new UniformIntegerDistribution(new MersenneTwister(), minConstraint, maxConstraint);
for (int individualIndex = 0; individualIndex < populationSize; individualIndex++) {
int[] attributes = generateIndividualAttribues(uniformIntDistribution);
currentPopulation[individualIndex] = new DEIndividual(attributes);
}
}
protected double getMseFromCalculatedFitness(final DEIndividual[] population) {
double mse = 0.0;
for (final DEIndividual individual : population) {
assert (individual.isFitnessCached());
mse += individual.getFitness();
}
return (mse / (double) population.length);
}
/**
* Parallelized calculation of fitness values for individuals in current population.
*/
protected double calculateFitnessForPopulationParallel(DEIndividual[] population) {
double avg = 0.0;
RunnablePopulationFitness[] workerInfos = new RunnablePopulationFitness[threadCount];
Thread[] workers = new Thread[threadCount];
int threadWorkSize = population.length / threadCount;
for (int workerId = 0; workerId < threadCount; workerId++) {
int workerFrom = workerId * threadWorkSize;
int workerTo = (workerId == (threadCount - 1)) ? population.length : (workerId * threadWorkSize) + threadWorkSize;
workerInfos[workerId] = new RunnablePopulationFitness(trainingData, population, workerFrom, workerTo);
workers[workerId] = new Thread(workerInfos[workerId]);
workers[workerId].start();
}
try {
for (int workerId = 0; workerId < threadCount; workerId++) {
workers[workerId].join();
avg += workerInfos[workerId].getTotalMse();
}
} catch (InterruptedException ignored) {
}
avg /= (double) population.length;
return avg;
}
/**
* Select random individual from p*100% top individuals.
*
* @param pBestDistribution Distribution for p*100% random.
* @param others Other individuals.
* @return Random individual from p*100%.
*/
protected DEIndividual getRandomFromPBest(UniformIntegerDistribution pBestDistribution,
final DEIndividual... others) {
assert (currentPopulationSorted != null);
int rndIndex = pBestDistribution.sample();
while (Utils.arrayContains(others, currentPopulationSorted[rndIndex])) {
rndIndex = pBestDistribution.sample();
}
return currentPopulationSorted[rndIndex];
}
/**
* Get random individual from current population, distinct from the other.
*
* @param rndIndDist Distribution of current population random.
* @param others Other individuals.
* @return Distinct random individual from the another one.
*/
protected DEIndividual getRandomFromCurrentPopulation(UniformIntegerDistribution rndIndDist,
final DEIndividual... others) {
DEIndividual rndIndiv = currentPopulation[rndIndDist.sample()];
while (Utils.arrayContains(others, rndIndiv)) {
rndIndiv = currentPopulation[rndIndDist.sample()];
}
return rndIndiv;
}
protected int[] createMutationVectorCurrentToPBest(final DEIndividual current,
final DEIndividual x_p_Best,
final DEIndividual x_r1,
final DEIndividual x_r2) {
int[] mutationVector = new int[dimensionCount];
double mutationFactor = current.getMutationFactor();
for (int j = 0; j < dimensionCount; j++) {
mutationVector[j] = (int) Math.floor(current.getAttribute(j) +
(mutationFactor * ((double) x_p_Best.getAttribute(j) - current.getAttribute(j))) +
(mutationFactor * ((double) x_r1.getAttribute(j) - x_r2.getAttribute(j))));
if (mutationVector[j] < minConstraint) {
mutationVector[j] = (minConstraint + current.getAttribute(j)) / 2;
} else if (mutationVector[j] > maxConstraint) {
mutationVector[j] = (maxConstraint + current.getAttribute(j)) / 2;
}
}
return mutationVector;
}
protected DEIndividual[] createSortedCopyOfCurrentPopulation() {
currentPopulationSorted = Arrays.copyOf(currentPopulation, currentPopulation.length);
Arrays.sort(currentPopulationSorted);
return currentPopulationSorted;
}
protected double generateMutationFactor(CauchyDistribution dist) {
double factor = dist.sample();
while (factor <= 0.0) { // || Double.isNaN(factor)) {
factor = dist.sample();
}
if (factor > 1.0) {
factor = 1.0;
}
assert (factor > 0.0 && factor <= 1.0);
return factor;
}
protected double generateCrossoverProbability(NormalDistribution dist) {
double prob = dist.sample();
// NOTE(Moravec): Sometimes dist.sample() returns NaN...
while (Double.isNaN(prob)) {
prob = dist.sample();
}
if (prob < 0.0) {
prob = 0.0;
} else if (prob > 1.0) {
prob = 1.0;
}
assert (prob >= 0.0 && prob <= 1.0);
return prob;
}
protected void assertPopulationSize() throws DeException {
if (populationSize < MINIMAL_POPULATION_SIZE) {
throw new DeException("Population size is too low. Required population size >= 5.");
}
}
protected void assertDimension() throws DeException {
if (dimensionCount < 1) {
throw new DeException("Dimension is too low. Required dimension >= 1.");
}
}
@Override
public void setMinimalValueConstraint(int min) {
minConstraint = min;
}
@Override
public void setMaximalValueConstraint(int max) {
maxConstraint = max;
}
@Override
public void setPopulationSize(int populationSize) throws DeException {
assertPopulationSize();
this.populationSize = populationSize;
currentPopulationSize = populationSize;
}
@Override
public void setGenerationCount(int generationCount) {
this.generationCount = generationCount;
}
@Override
public void setDimensionCount(int dimensionCount) {
this.dimensionCount = dimensionCount;
}
@Override
public void setTrainingData(int[] data) {
trainingData = data;
}
@Override
public IDEIndividual getBestSolution() {
return bestSolution;
}
}
package azgracompress.de;
import azgracompress.utilities.Utils;
import org.apache.commons.math3.distribution.UniformIntegerDistribution;
import java.util.ArrayList;
import java.util.Random;
public abstract class DESolverWithArchive extends DESolver {
protected int maxArchiveSize;
protected ArrayList<DEIndividual> archive;
protected DESolverWithArchive(int dimension, int currentPopulationSize, int generationCount, int maxArchiveSize) {
super(dimension, currentPopulationSize, generationCount);
this.maxArchiveSize = maxArchiveSize;
archive = new ArrayList<DEIndividual>(maxArchiveSize);
}
protected void truncateArchive() {
int deleteCount = archive.size() - maxArchiveSize;
if (deleteCount > 0) {
Random random = new Random();
for (int i = 0; i < deleteCount; i++) {
archive.remove(random.nextInt(archive.size()));
}
}
assert (archive.size() <= maxArchiveSize);
}
/**
* Get random individual (different from others) from union of current population and archive.
*
* @param rndUnionDist Random distribution for the union of current population and archive.
* @param others Other individuals.
* @return Random individual from union of current population and archive.
*/
protected DEIndividual getRandomFromPopulationAndArchive(UniformIntegerDistribution rndUnionDist,
final DEIndividual... others) {
int rndIndex = rndUnionDist.sample();
DEIndividual rndIndiv = (rndIndex >= currentPopulationSize) ? archive.get(rndIndex - currentPopulationSize) : currentPopulation[rndIndex];
while (Utils.arrayContains(others, rndIndiv)) {
rndIndex = rndUnionDist.sample();
rndIndiv = (rndIndex >= currentPopulationSize) ? archive.get(rndIndex - currentPopulationSize) : currentPopulation[rndIndex];
}
return rndIndiv;
}
public int getMaxArchiveSize() {
return maxArchiveSize;
}
public void setMaxArchiveSize(int maxArchiveSize) {
this.maxArchiveSize = maxArchiveSize;
}
}
package azgracompress.de;
public class DeException extends Exception {
public DeException(final String message) {
super(message);
}
}
package azgracompress.de;
import org.apache.commons.math3.distribution.UniformRealDistribution;
public interface IDEIndividual {
double getFitness();
void setFitness(final double fitness);
boolean isFitnessCached();
double getCrossoverProbability();
void setCrossoverProbability(final double cr);
double getMutationFactor();
void setMutationFactor(final double f);
int[] getAttributes();
int getAttribute(final int dimension);
IDEIndividual createOffspringBinominalCrossover(final int[] mutationVector, final int jRand, UniformRealDistribution rndCrDist);
}
package azgracompress.de;
import azgracompress.quantization.QTrainIteration;
public interface IDESolver {
void setMinimalValueConstraint(final int min);
void setMaximalValueConstraint(final int max);
void setPopulationSize(final int populationSize) throws DeException;
void setGenerationCount(final int generationCount);
void setDimensionCount(final int dimensionCount);
QTrainIteration[] train() throws DeException;
void setTrainingData(final int[] data);
IDEIndividual getBestSolution();
}
package azgracompress.de.jade;
import azgracompress.U16;
import azgracompress.de.DEIndividual;
import azgracompress.de.DESolverWithArchive;
import azgracompress.de.DeException;
import azgracompress.quantization.QTrainIteration;
import azgracompress.utilities.Means;
import azgracompress.utilities.Stopwatch;
import azgracompress.utilities.Utils;
import org.apache.commons.math3.distribution.CauchyDistribution;
import org.apache.commons.math3.distribution.NormalDistribution;
import org.apache.commons.math3.distribution.UniformIntegerDistribution;
import org.apache.commons.math3.distribution.UniformRealDistribution;
import org.apache.commons.math3.random.MersenneTwister;
import org.apache.commons.math3.random.RandomGenerator;
import java.util.ArrayList;
import java.util.Arrays;
public class JadeSolver extends DESolverWithArchive {
private double muCr = 0.5;
private double muF = 0.5;
private double parameterAdaptationRate = 0.05;
private double mutationGreediness = 0.1;
private DEIndividual[] currentPopulationSorted = null;
public JadeSolver(final int dimension, final int populationSize, final int generationCount) {
super(dimension, populationSize, generationCount, populationSize);
}
public JadeSolver(final int dimension, final int populationSize, final int generationCount,
final double parameterAdaptationRate, final double mutationGreediness) {
this(dimension, populationSize, generationCount);
this.parameterAdaptationRate = parameterAdaptationRate;
this.mutationGreediness = mutationGreediness;
}
@Override
public QTrainIteration[] train() throws DeException {
final String delimiter = "-------------------------------------------";
QTrainIteration[] solutionHistory = new QTrainIteration[generationCount];
if (trainingData == null || trainingData.length <= 0) {
throw new DeException("Training data weren't set.");
}
muCr = 0.5;
muF = 0.5;
generateInitialPopulation();
double avgFitness = calculateFitnessForPopulationParallel(currentPopulation);
System.out.println(String.format("Generation %d average fitness(COST): %.5f", 0, avgFitness));
ArrayList<Double> successfulCr = new ArrayList<Double>();
ArrayList<Double> successfulF = new ArrayList<Double>();
Stopwatch stopwatch = new Stopwatch();
RandomGenerator rg = new MersenneTwister();
int pBestUpperLimit = (int) Math.floor(populationSize * mutationGreediness);
UniformIntegerDistribution rndPBestDist = new UniformIntegerDistribution(rg, 0, (pBestUpperLimit - 1));
UniformIntegerDistribution rndIndDist = new UniformIntegerDistribution(rg, 0, (populationSize - 1));
UniformIntegerDistribution rndJRandDist = new UniformIntegerDistribution(rg, 0, (dimensionCount - 1));
UniformRealDistribution rndCrDist = new UniformRealDistribution(rg, 0.0, 1.0);
DEIndividual[] offsprings = new DEIndividual[populationSize];
for (int generation = 0; generation < generationCount; generation++) {
stopwatch.restart();
StringBuilder generationLog = new StringBuilder(String.format("%s\nGeneration: %d\n", delimiter, (generation + 1)));
currentPopulationSorted = createSortedCopyOfCurrentPopulation();
successfulCr.clear();
successfulF.clear();
UniformIntegerDistribution rndPopArchiveDist =
new UniformIntegerDistribution(rg, 0, ((populationSize - 1) + archive.size()));
NormalDistribution crNormalDistribution = new NormalDistribution(rg, muCr, 0.1);
CauchyDistribution fCauchyDistribution = new CauchyDistribution(rg, muF, 0.1);
for (int i = 0; i < populationSize; i++) {
DEIndividual current = currentPopulation[i];
current.setCrossoverProbability(generateCrossoverProbability(crNormalDistribution));
current.setMutationFactor(generateMutationFactor(fCauchyDistribution));
DEIndividual x_p_Best = getRandomFromPBest(rndPBestDist, current);
DEIndividual x_r1 = getRandomFromCurrentPopulation(rndIndDist, current, x_p_Best);
DEIndividual x_r2 = getRandomFromPopulationAndArchive(rndPopArchiveDist, current, x_p_Best, x_r1);
int[] mutationVector = createMutationVectorCurrentToPBest(current, x_p_Best, x_r1, x_r2);
int jRand = rndJRandDist.sample();
offsprings[i] = current.createOffspringBinominalCrossover(mutationVector, jRand, rndCrDist);
}
calculateFitnessForPopulationParallel(offsprings);
// NOTE(Moravec): We are minimalizing!
for (int i = 0; i < populationSize; i++) {
if (offsprings[i].getFitness() <= currentPopulation[i].getFitness()) {
final DEIndividual old = currentPopulation[i];
currentPopulation[i] = offsprings[i];
successfulCr.add(old.getCrossoverProbability());
successfulF.add(old.getMutationFactor());
archive.add(old);
}
}
double oldMuCr = muCr, oldMuF = muF;
muCr = ((1.0 - parameterAdaptationRate) * muCr) + (parameterAdaptationRate * Means.arithmeticMean(successfulCr));
muF = ((1.0 - parameterAdaptationRate) * muF) + (parameterAdaptationRate * Means.lehmerMean(successfulF));
generationLog.append(String.format("|S_Cr| = %d |S_F| = %d\n", successfulCr.size(), successfulF.size()));
generationLog.append(String.format("Old μCR: %.5f New μCR: %.5f\nOld μF: %.5f New μF: %.5f\n",
oldMuCr, muCr, oldMuF, muF));
truncateArchive();
generationLog.append(String.format("Archive size after truncate: %d\n", archive.size()));
avgFitness = getMseFromCalculatedFitness(currentPopulation);
stopwatch.stop();
final double currentBestFitness = currentPopulationSorted[0].getFitness();
final double avgPsnr = Utils.calculatePsnr(avgFitness, U16.Max);
generationLog.append("Current best fitness: ").append(currentBestFitness);
generationLog.append(String.format("\nAverage fitness(cost): %.6f\nIteration finished in: %d ms", avgFitness, stopwatch.totalElapsedMilliseconds()));
System.out.println(generationLog.toString());
solutionHistory[generation] = new QTrainIteration(generation, avgFitness, currentBestFitness, Utils.calculatePsnr(currentBestFitness, U16.Max), avgPsnr);
}
Arrays.sort(currentPopulationSorted);
bestSolution = currentPopulationSorted[0];
return solutionHistory;
}
public double getParameterAdaptationRate() {
return parameterAdaptationRate;
}
public void setParameterAdaptationRate(double parameterAdaptationRate) {
this.parameterAdaptationRate = parameterAdaptationRate;
}
public double getMutationGreediness() {
return mutationGreediness;
}
public void setMutationGreediness(double mutationGreediness) {
this.mutationGreediness = mutationGreediness;
}
}
package azgracompress.de.jade;
import azgracompress.quantization.scalar.ScalarQuantizer;
import azgracompress.de.DEIndividual;
public class RunnablePopulationFitness implements Runnable {
private double mse = 0.0;
private final int[] testData;
private final DEIndividual[] population;
private final int fromIndex;
private final int toIndex;
public RunnablePopulationFitness(final int[] testData, final DEIndividual[] population, final int popFrom, final int popTo) {
this.testData = testData;
this.population = population;
this.fromIndex = popFrom;
this.toIndex = popTo;
}
@Override
public void run() {
double mse = 0.0;
for (int i = fromIndex; i < toIndex; i++) {
double indivMse;
if (population[i].isFitnessCached()) {
indivMse = population[i].getFitness();
} else {
ScalarQuantizer quantizer = new ScalarQuantizer(0, 0xffff, population[i].getAttributes());
indivMse = quantizer.getMse(testData);
}
population[i].setFitness(indivMse);
mse += indivMse;
}
this.mse = mse;
}
public double getTotalMse() {
return mse;
}
public double getAvgMse() {
return (mse / (double) (toIndex - fromIndex));
}
}
package azgracompress.de.shade;
import azgracompress.U16;
import azgracompress.de.DeException;
import azgracompress.quantization.QTrainIteration;
import azgracompress.utilities.Means;
import azgracompress.utilities.Utils;
import org.apache.commons.math3.distribution.UniformIntegerDistribution;
import org.apache.commons.math3.distribution.UniformRealDistribution;
import org.apache.commons.math3.random.MersenneTwister;
import org.apache.commons.math3.random.RandomGenerator;
import azgracompress.de.DEIndividual;
import azgracompress.utilities.Stopwatch;
import java.util.ArrayList;
public class ILShadeSolver extends LShadeSolver {
private double currentMutationGreediness;
private double minMutationGreediness = 0.1;
public ILShadeSolver(int dimension, int populationSize, int generationCount, int memorySize) {
super(dimension, populationSize, generationCount, memorySize);
maxMutationGreediness = 0.2;
minMutationGreediness = 0.1;
currentMutationGreediness = maxMutationGreediness;
}
@SuppressWarnings("DuplicatedCode")
@Override
public QTrainIteration[] train() throws DeException {
final String delimiter = "-------------------------------------------";
int maxNfe = (populationSize * generationCount);
int nfe = 0;
QTrainIteration[] solutionHistory = new QTrainIteration[generationCount];
RandomGenerator rg = new MersenneTwister();
initializeMemory(0.8, 0.5);
ArrayList<Double> successfulCr = new ArrayList<Double>();
ArrayList<Double> successfulF = new ArrayList<Double>();
ArrayList<Double> absDelta = new ArrayList<Double>();
generateInitialPopulation();
double averageMSE = calculateFitnessForPopulationParallel(currentPopulation);
UniformIntegerDistribution memoryIndexDist = new UniformIntegerDistribution(rg, 0, (memorySize - 1));
UniformIntegerDistribution jRandDist = new UniformIntegerDistribution(rg, 0, (dimensionCount - 1));
UniformRealDistribution crDist = new UniformRealDistribution(rg, 0.0, 1.0);
Stopwatch stopwatch = new Stopwatch();
for (int generation = 0; generation < generationCount; generation++) {
stopwatch.restart();
successfulCr.clear();
successfulF.clear();
absDelta.clear();
StringBuilder generationLog = new StringBuilder(String.format("%s\niL-SHADE\nGeneration: %d\n", delimiter, (generation + 1)));
currentPopulationSorted = createSortedCopyOfCurrentPopulation();
DEIndividual[] offsprings = new DEIndividual[currentPopulationSize];
UniformIntegerDistribution rndPopArchiveDist =
new UniformIntegerDistribution(rg, 0, ((currentPopulationSize - 1) + archive.size()));
int pBestUpperLimit = (int) Math.floor(currentPopulationSize * currentMutationGreediness);
UniformIntegerDistribution rndPBestDist = new UniformIntegerDistribution(rg, 0, (pBestUpperLimit - 1));
UniformIntegerDistribution rndIndDist = new UniformIntegerDistribution(rg, 0, (currentPopulationSize - 1));
for (int i = 0; i < currentPopulationSize; i++) {
int randomMemIndex = memoryIndexDist.sample();
if (randomMemIndex == (memorySize - 1)) {
memoryCr[randomMemIndex] = 0.9;
memoryF[randomMemIndex] = 0.9;
}
currentPopulation[i].setCrossoverProbability(iLShadeGenerateCrossoverProbability(randomMemIndex, generation));
currentPopulation[i].setMutationFactor(iLShadeGenerateMutationFactor(randomMemIndex, generation));
DEIndividual x_p_Best = getRandomFromPBest(rndPBestDist, currentPopulation[i]);
DEIndividual x_r1 = getRandomFromCurrentPopulation(rndIndDist, currentPopulation[i], x_p_Best);
DEIndividual x_r2 = getRandomFromPopulationAndArchive(rndPopArchiveDist, currentPopulation[i], x_p_Best, x_r1);
final int[] mutationVector = createMutationVectorCurrentToPBest(currentPopulation[i], x_p_Best, x_r1, x_r2);
offsprings[i] = currentPopulation[i].createOffspringBinominalCrossover(mutationVector, jRandDist.sample(), crDist);
}
calculateFitnessForPopulationParallel(offsprings);
nfe += currentPopulationSize;
DEIndividual[] nextPopulation = new DEIndividual[currentPopulationSize];
// NOTE(Moravec): We are minimalizing!
for (int i = 0; i < currentPopulationSize; i++) {
final DEIndividual old = currentPopulation[i];
if (offsprings[i].getFitness() <= old.getFitness()) {
nextPopulation[i] = offsprings[i];
if (offsprings[i].getFitness() < old.getFitness()) {
archive.add(old);
absDelta.add(Math.abs(offsprings[i].getFitness() - currentPopulation[i].getFitness()));
successfulCr.add(old.getCrossoverProbability());
successfulF.add(old.getMutationFactor());
}
} else {
nextPopulation[i] = currentPopulation[i];
}
}
updateMemory(successfulCr, successfulF, absDelta);
currentPopulation = nextPopulation;
applyLinearReductionOfPopulationSize(nfe, maxNfe);
truncateArchive();
updateMutationGreediness(nfe, maxNfe);
averageMSE = getMseFromCalculatedFitness(currentPopulation);
// NOTE(Moravec): After LRPS the population is sorted according.
final double bestMSE = currentPopulation[0].getFitness();
final double bestPSNR = Utils.calculatePsnr(bestMSE, U16.Max);
final double averagePSNR = Utils.calculatePsnr(averageMSE, U16.Max);
solutionHistory[generation] = new QTrainIteration(generation, averageMSE, bestMSE, averagePSNR, bestPSNR);
stopwatch.stop();
generationLog.append(String.format("Current population size: %d\n", currentPopulationSize));
generationLog.append(String.format("Mutation greediness: %.5f\n", currentMutationGreediness));
generationLog.append(String.format("Current best fitness: %.5f Current PSNR: %.5f dB", bestMSE, bestPSNR));
generationLog.append(String.format("\nAvg. cost(after LPSR): %.6f\nAvg. PSNR (after LPSR): %.6f dB\nIteration finished in: %d ms", averageMSE, averagePSNR, stopwatch.totalElapsedMilliseconds()));
System.out.println(generationLog.toString());
}
return solutionHistory;
}
private double iLShadeGenerateCrossoverProbability(final int memIndex, final int currentGeneration) {
double cr = generateCrossoverProbability(memIndex);
if ((double) currentGeneration < (0.25 * (double) generationCount)) {
cr = Math.max(cr, 0.5);
} else if ((double) currentGeneration < (0.5 * (double) generationCount)) {
cr = Math.max(cr, 0.25);
}
return cr;
}
private double iLShadeGenerateMutationFactor(final int memIndex, final int currentGeneration) {
double f = generateMutationFactor(memIndex);
if ((double) currentGeneration < (0.25 * (double) generationCount)) {
f = Math.max(f, 0.7);
} else if ((double) currentGeneration < (0.5 * (double) generationCount)) {
f = Math.max(f, 0.8);
} else if ((double) currentGeneration < (0.75 * (double) generationCount)) {
f = Math.max(f, 0.9);
}
return f;
}
private void updateMutationGreediness(final int nfes, final int maxNfes) {
currentMutationGreediness = (((maxMutationGreediness - minMutationGreediness) / (double) maxNfes) * nfes) + minMutationGreediness;
}
@Override
protected void updateMemory(final ArrayList<Double> successfulCr,
final ArrayList<Double> successfulF,
final ArrayList<Double> absDelta) {
if ((!successfulCr.isEmpty()) && (!successfulF.isEmpty())) {
assert ((absDelta.size() == successfulCr.size()) && (successfulCr.size() == successfulF.size()));
double[] weights = calculateLehmerWeihts(absDelta);
if ((Double.isNaN(memoryCr[memoryIndex])) || (Utils.arrayListMax(successfulCr) == 0)) {
memoryCr[memoryIndex] = Double.NaN;
} else {
memoryCr[memoryIndex] = ((Means.weightedLehmerMean(successfulCr, weights) + memoryCr[memoryIndex]) / 2.0);
}
memoryF[memoryIndex] = ((Means.weightedLehmerMean(successfulF, weights) + memoryF[memoryIndex]) / 2.0);
++memoryIndex;
if (memoryIndex >= memorySize) {
memoryIndex = 0;
}
}
}
public double getMinMutationGreediness() {
return minMutationGreediness;
}
public void setMinMutationGreediness(double minMutationGreediness) {
this.minMutationGreediness = minMutationGreediness;
}
@Override
public void setMaxMutationGreediness(double maxMutationGreediness) {
super.setMaxMutationGreediness(maxMutationGreediness);
currentMutationGreediness = maxMutationGreediness;
}
}
package azgracompress.de.shade;
import azgracompress.U16;
import azgracompress.de.DeException;
import azgracompress.quantization.QTrainIteration;
import azgracompress.utilities.Utils;
import org.apache.commons.math3.distribution.CauchyDistribution;
import org.apache.commons.math3.distribution.NormalDistribution;
import org.apache.commons.math3.distribution.UniformIntegerDistribution;
import org.apache.commons.math3.distribution.UniformRealDistribution;
import org.apache.commons.math3.random.MersenneTwister;
import org.apache.commons.math3.random.RandomGenerator;
import azgracompress.de.DEIndividual;
import azgracompress.de.DESolverWithArchive;
import azgracompress.utilities.Means;
import azgracompress.utilities.Stopwatch;
import java.util.ArrayList;
public class LShadeSolver extends DESolverWithArchive {
protected int memorySize;
protected int memoryIndex = 0;
protected double[] memoryCr;
protected double[] memoryF;
protected double maxMutationGreediness = 0.1;
protected int minimalPopulationSize = MINIMAL_POPULATION_SIZE;
public LShadeSolver(int dimension, int populationSize, int generationCount, int memorySize) {
super(dimension, populationSize, generationCount, populationSize);
this.memorySize = memorySize;
}
protected void initializeMemory(final double initialCrValue, final double initialFValue) {
memoryIndex = 0;
memoryCr = new double[memorySize];
memoryF = new double[memorySize];
for (int memIndex = 0; memIndex < memorySize; memIndex++) {
memoryCr[memIndex] = initialCrValue;
memoryF[memIndex] = initialFValue;
}
}
protected double generateCrossoverProbability(final int memIndex) {
double memCr = memoryCr[memIndex];
if (Double.isNaN(memCr)) {
return 0.0;
} else {
return generateCrossoverProbability(new NormalDistribution(memCr, 0.1));
}
}
protected double generateMutationFactor(final int memIndex) {
return generateMutationFactor(new CauchyDistribution(memoryF[memIndex], 0.1));
}
@Override
public QTrainIteration[] train() throws DeException {
final String delimiter = "-------------------------------------------";
int maxNfe = (populationSize * generationCount);
int nfe = 0;
QTrainIteration[] solutionHistory = new QTrainIteration[generationCount];
RandomGenerator rg = new MersenneTwister();
initializeMemory(0.5, 0.5);
ArrayList<Double> successfulCr = new ArrayList<Double>();
ArrayList<Double> successfulF = new ArrayList<Double>();
ArrayList<Double> absDelta = new ArrayList<Double>();
generateInitialPopulation();
double averageMSE = calculateFitnessForPopulationParallel(currentPopulation);
UniformIntegerDistribution memoryIndexDist = new UniformIntegerDistribution(rg, 0, (memorySize - 1));
UniformIntegerDistribution jRandDist = new UniformIntegerDistribution(rg, 0, (dimensionCount - 1));
UniformRealDistribution crDist = new UniformRealDistribution(rg, 0.0, 1.0);
Stopwatch stopwatch = new Stopwatch();
for (int generation = 0; generation < generationCount; generation++) {
stopwatch.restart();
successfulCr.clear();
successfulF.clear();
absDelta.clear();
StringBuilder generationLog = new StringBuilder(String.format("%s\nGeneration: %d\n", delimiter, (generation + 1)));
currentPopulationSorted = createSortedCopyOfCurrentPopulation();
DEIndividual[] offsprings = new DEIndividual[currentPopulationSize];
UniformIntegerDistribution rndPopArchiveDist =
new UniformIntegerDistribution(rg, 0, ((currentPopulationSize - 1) + archive.size()));
int pBestUpperLimit = (int) Math.floor(currentPopulationSize * maxMutationGreediness);
UniformIntegerDistribution rndPBestDist = new UniformIntegerDistribution(rg, 0, (pBestUpperLimit - 1));
UniformIntegerDistribution rndIndDist = new UniformIntegerDistribution(rg, 0, (currentPopulationSize - 1));
for (int i = 0; i < currentPopulationSize; i++) {
int randomMemIndex = memoryIndexDist.sample();
currentPopulation[i].setCrossoverProbability(generateCrossoverProbability(randomMemIndex));
currentPopulation[i].setMutationFactor(generateMutationFactor(randomMemIndex));
DEIndividual x_p_Best = getRandomFromPBest(rndPBestDist, currentPopulation[i]);
DEIndividual x_r1 = getRandomFromCurrentPopulation(rndIndDist, currentPopulation[i], x_p_Best);
DEIndividual x_r2 = getRandomFromPopulationAndArchive(rndPopArchiveDist, currentPopulation[i], x_p_Best, x_r1);
final int[] mutationVector = createMutationVectorCurrentToPBest(currentPopulation[i], x_p_Best, x_r1, x_r2);
offsprings[i] = currentPopulation[i].createOffspringBinominalCrossover(mutationVector, jRandDist.sample(), crDist);
}
calculateFitnessForPopulationParallel(offsprings);
nfe += currentPopulationSize;
DEIndividual[] nextPopulation = new DEIndividual[currentPopulationSize];
// NOTE(Moravec): We are minimalizing!
for (int i = 0; i < currentPopulationSize; i++) {
final DEIndividual old = currentPopulation[i];
if (offsprings[i].getFitness() <= old.getFitness()) {
nextPopulation[i] = offsprings[i];
if (offsprings[i].getFitness() < old.getFitness()) {
archive.add(old);
absDelta.add(Math.abs(offsprings[i].getFitness() - currentPopulation[i].getFitness()));
successfulCr.add(old.getCrossoverProbability());
successfulF.add(old.getMutationFactor());
}
} else {
nextPopulation[i] = currentPopulation[i];
}
}
updateMemory(successfulCr, successfulF, absDelta);
currentPopulation = nextPopulation;
applyLinearReductionOfPopulationSize(nfe, maxNfe);
truncateArchive();
averageMSE = getMseFromCalculatedFitness(currentPopulation);
// NOTE(Moravec): After LRPS the population is sorted according.
final double bestMSE = currentPopulation[0].getFitness();
final double bestPSNR = Utils.calculatePsnr(bestMSE, U16.Max);
final double averagePSNR = Utils.calculatePsnr(averageMSE, U16.Max);
solutionHistory[generation] = new QTrainIteration(generation, averageMSE, bestMSE, averagePSNR, bestPSNR);
stopwatch.stop();
generationLog.append(String.format("Archive size after truncate: %d\n", archive.size()));
generationLog.append(String.format("Current population size: %d\n", currentPopulationSize));
generationLog.append(String.format("Current best fitness: %.5f Current PSNR: %.5f dB", bestMSE, bestPSNR));
generationLog.append(String.format("\nAvg. cost(after LPSR): %.6f\nAvg. PSNR (after LPSR): %.6f dB\nIteration finished in: %d ms", averageMSE, averagePSNR, stopwatch.totalElapsedMilliseconds()));
System.out.println(generationLog.toString());
}
return solutionHistory;
}
protected void applyLinearReductionOfPopulationSize(final int nfe, final int maxNfe) {
final int oldPopulationSize = currentPopulationSize;
currentPopulationSize = getNewPopulationSize(nfe, maxNfe);
maxArchiveSize = currentPopulationSize;
if (currentPopulationSize < oldPopulationSize) {
DEIndividual[] reducedPopulation = new DEIndividual[currentPopulationSize];
System.arraycopy(currentPopulationSorted, 0, reducedPopulation, 0, currentPopulationSize);
currentPopulation = reducedPopulation;
}
}
private int getNewPopulationSize(final int nfe, final int maxNfe) {
int newPopulationSize = (int) Math.round(((((double) minimalPopulationSize - (double) populationSize) / (double) maxNfe) * (double) nfe) + (double) populationSize);
return newPopulationSize;
}
protected double[] calculateLehmerWeihts(final ArrayList<Double> absDelta) {
int kCount = absDelta.size();
double[] weights = new double[kCount];
for (int k = 0; k < kCount; k++) {
final double numerator = absDelta.get(k);
final double denominator = Utils.arrayListSum(absDelta);
weights[k] = (numerator / denominator);
}
return weights;
}
protected void updateMemory(final ArrayList<Double> successfulCr,
final ArrayList<Double> successfulF,
final ArrayList<Double> absDelta) {
if ((!successfulCr.isEmpty()) && (!successfulF.isEmpty())) {
assert ((absDelta.size() == successfulCr.size()) && (successfulCr.size() == successfulF.size()));
double[] weights = calculateLehmerWeihts(absDelta);
if ((Double.isNaN(memoryCr[memoryIndex])) || (Utils.arrayListMax(successfulCr) == 0)) {
memoryCr[memoryIndex] = Double.NaN;
} else {
memoryCr[memoryIndex] = Means.weightedLehmerMean(successfulCr, weights);
}
memoryF[memoryIndex] = Means.weightedLehmerMean(successfulF, weights);
++memoryIndex;
if (memoryIndex >= memorySize) {
memoryIndex = 0;
}
}
// StringBuilder sb = new StringBuilder();
// sb.append("MEMORY F: ");
// for (int i = 0; i < memoryF.length; i++) {
// sb.append(String.format("%.5f ", memoryF[i]));
// }
// sb.append("\n");
// sb.append("MEMORY Cr: ");
// for (int i = 0; i < memoryCr.length; i++) {
// sb.append(String.format("%.5f ", memoryCr[i]));
// }
// System.out.println(sb.toString());
}
public int getMemorySize() {
return memorySize;
}
public void setMemorySize(final int memorySize) {
this.memorySize = memorySize;
}
public void setMaxMutationGreediness(final double maxMutationGreediness) {
this.maxMutationGreediness = maxMutationGreediness;
}
public double getMaxMutationGreediness() {
return maxMutationGreediness;
}
public int getMinimalPopulationSize() {
return minimalPopulationSize;
}
public void setMinimalPopulationSize(int minimalPopulationSize) {
this.minimalPopulationSize = minimalPopulationSize;
}
}
...@@ -5,6 +5,9 @@ import azgracompress.quantization.vector.CodebookEntry; ...@@ -5,6 +5,9 @@ import azgracompress.quantization.vector.CodebookEntry;
import java.io.*; import java.io.*;
// TODO(Moravec): If we want to use Huffman codes we have to save additional information with the codebook.
// This information can be probability or the absolute frequencies of codebook indices.
public class QuantizationValueCache { public class QuantizationValueCache {
private final String cacheFolder; private final String cacheFolder;
......
...@@ -7,6 +7,7 @@ import azgracompress.utilities.Stopwatch; ...@@ -7,6 +7,7 @@ import azgracompress.utilities.Stopwatch;
import azgracompress.utilities.Utils; import azgracompress.utilities.Utils;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays;
public class LloydMaxU16ScalarQuantization { public class LloydMaxU16ScalarQuantization {
...@@ -15,9 +16,11 @@ public class LloydMaxU16ScalarQuantization { ...@@ -15,9 +16,11 @@ public class LloydMaxU16ScalarQuantization {
private int dataMin; private int dataMin;
private int dataMax; private int dataMax;
private int dataSpan;
private long[] frequencies;
private int[] centroids; private int[] centroids;
private int[] boundaryPoints; private int[] boundaryPoints;
private double[] pdf; private double[] pdf;
private final int workerCount; private final int workerCount;
...@@ -35,6 +38,7 @@ public class LloydMaxU16ScalarQuantization { ...@@ -35,6 +38,7 @@ public class LloydMaxU16ScalarQuantization {
} }
private void initialize() { private void initialize() {
frequencies = new long[codebookSize];
centroids = new int[codebookSize]; centroids = new int[codebookSize];
boundaryPoints = new int[codebookSize + 1]; boundaryPoints = new int[codebookSize + 1];
...@@ -42,7 +46,7 @@ public class LloydMaxU16ScalarQuantization { ...@@ -42,7 +46,7 @@ public class LloydMaxU16ScalarQuantization {
MinMaxResult<Integer> minMax = Utils.getMinAndMax(trainingData); MinMaxResult<Integer> minMax = Utils.getMinAndMax(trainingData);
dataMin = minMax.getMin(); dataMin = minMax.getMin();
dataMax = minMax.getMax(); dataMax = minMax.getMax();
dataSpan = dataMax - dataMin; final int dataSpan = dataMax - dataMin;
centroids[0] = dataMin; centroids[0] = dataMin;
boundaryPoints[0] = dataMin; boundaryPoints[0] = dataMin;
...@@ -59,8 +63,8 @@ public class LloydMaxU16ScalarQuantization { ...@@ -59,8 +63,8 @@ public class LloydMaxU16ScalarQuantization {
Stopwatch s = new Stopwatch(); Stopwatch s = new Stopwatch();
s.start(); s.start();
for (int i = 0; i < trainingData.length; i++) { for (final int trainingDatum : trainingData) {
pdf[trainingData[i]] += 1.0; pdf[trainingDatum] += 1.0;
} }
s.stop(); s.stop();
...@@ -120,24 +124,23 @@ public class LloydMaxU16ScalarQuantization { ...@@ -120,24 +124,23 @@ public class LloydMaxU16ScalarQuantization {
public int quantize(final int value) { public int quantize(final int value) {
for (int intervalId = 1; intervalId <= codebookSize; intervalId++) { for (int intervalId = 1; intervalId <= codebookSize; intervalId++) {
if ((value >= boundaryPoints[intervalId - 1]) && (value <= boundaryPoints[intervalId])) { if ((value >= boundaryPoints[intervalId - 1]) && (value <= boundaryPoints[intervalId])) {
++frequencies[intervalId - 1];
return centroids[intervalId - 1]; return centroids[intervalId - 1];
} }
} }
throw new RuntimeException("Value couldn't be quantized!"); throw new RuntimeException("Value couldn't be quantized!");
} }
private double calculateMAE() { /**
double mae = 0.0; * Reset the frequencies array to zeros.
for (final int trainingDatum : trainingData) { */
int quantizedValue = quantize(trainingDatum); private void resetFrequencies() {
mae += Math.abs((double) trainingDatum - (double) quantizedValue); Arrays.fill(frequencies, 0);
}
return (mae / (double) trainingData.length);
} }
private double getCurrentMse() { private double getCurrentMse() {
double mse = 0.0; double mse = 0.0;
resetFrequencies();
Stopwatch s = new Stopwatch(); Stopwatch s = new Stopwatch();
s.start(); s.start();
...@@ -163,6 +166,7 @@ public class LloydMaxU16ScalarQuantization { ...@@ -163,6 +166,7 @@ public class LloydMaxU16ScalarQuantization {
try { try {
for (int wId = 0; wId < workerCount; wId++) { for (int wId = 0; wId < workerCount; wId++) {
workers[wId].join(); workers[wId].join();
addWorkerFrequencies(runnables[wId].getFrequencies());
mse += runnables[wId].getMse(); mse += runnables[wId].getMse();
} }
} catch (InterruptedException e) { } catch (InterruptedException e) {
...@@ -184,6 +188,13 @@ public class LloydMaxU16ScalarQuantization { ...@@ -184,6 +188,13 @@ public class LloydMaxU16ScalarQuantization {
return mse; return mse;
} }
private void addWorkerFrequencies(final long[] workerFrequencies) {
assert (frequencies.length == workerFrequencies.length) : "Frequency array length mismatch.";
for (int i = 0; i < frequencies.length; i++) {
frequencies[i] += workerFrequencies[i];
}
}
public QTrainIteration[] train(final boolean shouldBeVerbose) { public QTrainIteration[] train(final boolean shouldBeVerbose) {
this.verbose = shouldBeVerbose; this.verbose = shouldBeVerbose;
final int RECALCULATE_N_TIMES = 10; final int RECALCULATE_N_TIMES = 10;
...@@ -224,21 +235,15 @@ public class LloydMaxU16ScalarQuantization { ...@@ -224,21 +235,15 @@ public class LloydMaxU16ScalarQuantization {
recalculateCentroids(); recalculateCentroids();
} }
currMAE = calculateMAE();
prevMse = currentMse; prevMse = currentMse;
currentMse = getCurrentMse(); currentMse = getCurrentMse();
mseImprovement = prevMse - currentMse; mseImprovement = prevMse - currentMse;
// System.out.println(String.format("Improvement: %.4f", mseImprovement));
psnr = Utils.calculatePsnr(currentMse, U16.Max); psnr = Utils.calculatePsnr(currentMse, U16.Max);
solutionHistory.add(new QTrainIteration(++iteration, currentMse, currentMse, psnr, psnr)); solutionHistory.add(new QTrainIteration(++iteration, currentMse, currentMse, psnr, psnr));
// dist = (prevMse - currentMse) / currentMse;
if (verbose) { if (verbose) {
System.out.println(String.format("Current MAE: %.4f MSE: %.4f PSNR: %.4f dB", System.out.println(String.format("Current MSE: %.4f PSNR: %.4f dB",
currMAE,
currentMse, currentMse,
psnr)); psnr));
} }
...@@ -260,5 +265,9 @@ public class LloydMaxU16ScalarQuantization { ...@@ -260,5 +265,9 @@ public class LloydMaxU16ScalarQuantization {
public int[] getCentroids() { public int[] getCentroids() {
return centroids; return centroids;
} }
public ScalarQuantizationCodebook getCodebook() {
return new ScalarQuantizationCodebook(centroids, frequencies);
}
} }
...@@ -8,6 +8,7 @@ public class RunnableLloydMseCalc implements Runnable { ...@@ -8,6 +8,7 @@ public class RunnableLloydMseCalc implements Runnable {
final int[] boundaryPoints; final int[] boundaryPoints;
final int codebookSize; final int codebookSize;
double mse = 0.0; double mse = 0.0;
final long[] frequencies;
public RunnableLloydMseCalc(int[] trainingData, int fromIndex, int toIndex, int[] centroids, int[] boundaryPoints, public RunnableLloydMseCalc(int[] trainingData, int fromIndex, int toIndex, int[] centroids, int[] boundaryPoints,
final int codebookSize) { final int codebookSize) {
...@@ -17,13 +18,16 @@ public class RunnableLloydMseCalc implements Runnable { ...@@ -17,13 +18,16 @@ public class RunnableLloydMseCalc implements Runnable {
this.centroids = centroids; this.centroids = centroids;
this.boundaryPoints = boundaryPoints; this.boundaryPoints = boundaryPoints;
this.codebookSize = codebookSize; this.codebookSize = codebookSize;
this.frequencies = new long[centroids.length];
} }
public long[] getFrequencies() {
return frequencies;
}
@Override @Override
public void run() { public void run() {
mse = 0.0; mse = 0.0;
for (int i = fromIndex; i < toIndex; i++) { for (int i = fromIndex; i < toIndex; i++) {
mse += Math.pow((double) trainingData[i] - (double) quantize(trainingData[i]), 2); mse += Math.pow((double) trainingData[i] - (double) quantize(trainingData[i]), 2);
} }
...@@ -36,6 +40,7 @@ public class RunnableLloydMseCalc implements Runnable { ...@@ -36,6 +40,7 @@ public class RunnableLloydMseCalc implements Runnable {
private int quantize(final int value) { private int quantize(final int value) {
for (int intervalId = 1; intervalId <= codebookSize; intervalId++) { for (int intervalId = 1; intervalId <= codebookSize; intervalId++) {
if ((value >= boundaryPoints[intervalId - 1]) && (value <= boundaryPoints[intervalId])) { if ((value >= boundaryPoints[intervalId - 1]) && (value <= boundaryPoints[intervalId])) {
++frequencies[intervalId - 1];
return centroids[intervalId - 1]; return centroids[intervalId - 1];
} }
} }
......
package azgracompress.quantization.scalar;
public class ScalarQuantizationCodebook {
/**
* Quantization values.
*/
final int[] centroids;
/**
* Absolute frequencies of quantization values.
*/
final long[] indexFrequencies;
final int codebookSize;
/**
* @param centroids Quantization values.
* @param indexFrequencies Absolute frequencies of quantization values.
*/
public ScalarQuantizationCodebook(final int[] centroids, final long[] indexFrequencies) {
this.centroids = centroids;
this.indexFrequencies = indexFrequencies;
this.codebookSize = this.centroids.length;
}
public int[] getCentroids() {
return centroids;
}
public long[] getIndicesFrequency() {
return indexFrequencies;
}
public int getCodebookSize() {
return codebookSize;
}
}
...@@ -3,12 +3,12 @@ package azgracompress.quantization.scalar; ...@@ -3,12 +3,12 @@ package azgracompress.quantization.scalar;
public class ScalarQuantizer { public class ScalarQuantizer {
private final int min; private final int min;
private final int max; private final int max;
private int[] centroids; private final ScalarQuantizationCodebook codebook;
private int[] boundaryPoints; private int[] boundaryPoints;
public ScalarQuantizer(final int min, final int max, final int[] centroids) { public ScalarQuantizer(final int min, final int max, final ScalarQuantizationCodebook codebook) {
this.centroids = centroids; this.codebook = codebook;
boundaryPoints = new int[centroids.length + 1]; boundaryPoints = new int[codebook.getCodebookSize() + 1];
this.min = min; this.min = min;
this.max = max; this.max = max;
...@@ -63,14 +63,15 @@ public class ScalarQuantizer { ...@@ -63,14 +63,15 @@ public class ScalarQuantizer {
private void calculateBoundaryPoints() { private void calculateBoundaryPoints() {
boundaryPoints[0] = min; boundaryPoints[0] = min;
boundaryPoints[centroids.length] = max; boundaryPoints[codebook.getCodebookSize()] = max;
final int[] centroids = codebook.getCentroids();
for (int j = 1; j < centroids.length; j++) { for (int j = 1; j < centroids.length; j++) {
boundaryPoints[j] = (this.centroids[j] + this.centroids[j - 1]) / 2; boundaryPoints[j] = (centroids[j] + centroids[j - 1]) / 2;
} }
} }
public int quantizeIndex(final int value) { public int quantizeIndex(final int value) {
for (int intervalId = 1; intervalId <= centroids.length; intervalId++) { for (int intervalId = 1; intervalId <= codebook.getCodebookSize(); intervalId++) {
if ((value >= boundaryPoints[intervalId - 1]) && (value <= boundaryPoints[intervalId])) { if ((value >= boundaryPoints[intervalId - 1]) && (value <= boundaryPoints[intervalId])) {
return (intervalId - 1); return (intervalId - 1);
} }
...@@ -79,7 +80,7 @@ public class ScalarQuantizer { ...@@ -79,7 +80,7 @@ public class ScalarQuantizer {
} }
public int quantize(final int value) { public int quantize(final int value) {
return centroids[quantizeIndex(value)]; return codebook.getCentroids()[quantizeIndex(value)];
} }
public double getMse(final int[] data) { public double getMse(final int[] data) {
...@@ -93,6 +94,6 @@ public class ScalarQuantizer { ...@@ -93,6 +94,6 @@ public class ScalarQuantizer {
} }
public int[] getCentroids() { public int[] getCentroids() {
return centroids; return codebook.getCentroids();
} }
} }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment