Skip to content
Snippets Groups Projects
Commit e6b90f7a authored by Vojtech Moravec's avatar Vojtech Moravec
Browse files

Return solution history from train method.

parent 8e3bae61
No related branches found
No related tags found
No related merge requests found
...@@ -2,10 +2,14 @@ import org.apache.commons.math3.distribution.CauchyDistribution; ...@@ -2,10 +2,14 @@ import org.apache.commons.math3.distribution.CauchyDistribution;
import quantization.LloydMaxU16ScalarQuantization; import quantization.LloydMaxU16ScalarQuantization;
import quantization.Utils; import quantization.Utils;
import quantization.de.DeException; import quantization.de.DeException;
import quantization.de.DeHistory;
import quantization.de.jade.JadeSolver; import quantization.de.jade.JadeSolver;
import quantization.utilities.Stopwatch; import quantization.utilities.Stopwatch;
import java.io.FileNotFoundException; import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.time.Duration; import java.time.Duration;
import java.time.Instant; import java.time.Instant;
import java.util.concurrent.ExecutorService; import java.util.concurrent.ExecutorService;
...@@ -25,34 +29,7 @@ class RunnableTest implements Runnable { ...@@ -25,34 +29,7 @@ class RunnableTest implements Runnable {
} }
public class DataCompressor { public class DataCompressor {
public static void main(String[] args) throws FileNotFoundException { public static void main(String[] args) throws IOException {
/*
int coreCount = Runtime.getRuntime().availableProcessors() - 1;
// Thread[] threads = new Thread[coreCount];
ExecutorService es = Executors.newFixedThreadPool(coreCount);
RunnableTest[] runnables = new RunnableTest[coreCount];
for (int i = 0; i < coreCount; i++) {
runnables[i] = new RunnableTest();
es.execute(runnables[i]);
}
es.shutdown();
try {
es.awaitTermination(1, TimeUnit.MINUTES);
} catch (InterruptedException e) {
System.out.println("Thread interrupted: " + e.getMessage());
}
for (int i = 0; i < coreCount; i++) {
System.out.println(runnables[i].tid);
}
System.out.println("All threads finished");
*/
final String sourceFile = "D:\\tmp\\server-dump\\small.bin"; final String sourceFile = "D:\\tmp\\server-dump\\small.bin";
final int NumberOfBits = 4; final int NumberOfBits = 4;
...@@ -62,15 +39,29 @@ public class DataCompressor { ...@@ -62,15 +39,29 @@ public class DataCompressor {
// LloydMaxU16ScalarQuantization quantization = new LloydMaxU16ScalarQuantization(values, NumberOfBits); // LloydMaxU16ScalarQuantization quantization = new LloydMaxU16ScalarQuantization(values, NumberOfBits);
// quantization.train(); // quantization.train();
JadeSolver jadeSolver = new JadeSolver(Dimension, 10 * Dimension, 100, 0.05, 0.1); JadeSolver jadeSolver = new JadeSolver(Dimension, 10 * Dimension, 250, 0.05, 0.1);
jadeSolver.setTrainingData(values); jadeSolver.setTrainingData(values);
DeHistory[] solutionHistory = null;
try { try {
jadeSolver.train(); solutionHistory = jadeSolver.train();
} catch (DeException e) { } catch (DeException e) {
e.printStackTrace(); e.printStackTrace();
} }
if (solutionHistory != null) {
FileOutputStream os = new FileOutputStream("JadeSolutionHistory.csv");
OutputStreamWriter writer = new OutputStreamWriter(os);
writer.write("Generation;AvgCost;BestCost\n");
for (final DeHistory hist : solutionHistory) {
writer.write(String.format("%d;%.5f;%.5f\n", hist.getIteration(), hist.getAvgCost(), hist.getBestCost()));
}
writer.flush();
writer.close();
os.flush();
os.close();
}
System.out.println("Finished learning..."); System.out.println("Finished learning...");
} }
} }
package quantization.de;
public class DeHistory {
private int m_iteration = 0;
private double m_avgCost = 0;
private double m_bestCost = 0;
public DeHistory(final int it, final double avgCost, final double bestCost) {
m_iteration = it;
m_avgCost = avgCost;
m_bestCost = bestCost;
}
public double getBestCost() {
return m_bestCost;
}
public void setBestCost(double bestCost) {
this.m_bestCost = bestCost;
}
public double getAvgCost() {
return m_avgCost;
}
public void setAvgCost(double avgCost) {
this.m_avgCost = avgCost;
}
public int getIteration() {
return m_iteration;
}
public void setIteration(int iteration) {
this.m_iteration = iteration;
}
}
...@@ -11,7 +11,7 @@ public interface IDESolver { ...@@ -11,7 +11,7 @@ public interface IDESolver {
void setDimension(final int dimension); void setDimension(final int dimension);
void train() throws DeException; DeHistory[] train() throws DeException;
void setTrainingData(final int[] data); void setTrainingData(final int[] data);
... ...
......
...@@ -9,6 +9,7 @@ import org.apache.commons.math3.random.RandomGenerator; ...@@ -9,6 +9,7 @@ import org.apache.commons.math3.random.RandomGenerator;
import quantization.Quantizer; import quantization.Quantizer;
import quantization.U16; import quantization.U16;
import quantization.Utils; import quantization.Utils;
import quantization.de.DeHistory;
import quantization.de.IDESolver; import quantization.de.IDESolver;
import quantization.de.IIndividual; import quantization.de.IIndividual;
import quantization.de.DeException; import quantization.de.DeException;
...@@ -358,8 +359,9 @@ public class JadeSolver implements IDESolver { ...@@ -358,8 +359,9 @@ public class JadeSolver implements IDESolver {
} }
@Override @Override
public void train() throws DeException { public DeHistory[] train() throws DeException {
final String delimiter = "-------------------------------------------"; final String delimiter = "-------------------------------------------";
DeHistory[] solutionHistory = new DeHistory[m_generationCount];
if (m_trainingData == null || m_trainingData.length <= 0) { if (m_trainingData == null || m_trainingData.length <= 0) {
throw new DeException("Training data weren't set."); throw new DeException("Training data weren't set.");
} }
...@@ -436,7 +438,9 @@ public class JadeSolver implements IDESolver { ...@@ -436,7 +438,9 @@ public class JadeSolver implements IDESolver {
// System.out.println(String.format("Generation %d average fitness(COST): %.5f", (generation + 1), avgFitness)); // System.out.println(String.format("Generation %d average fitness(COST): %.5f", (generation + 1), avgFitness));
System.out.println(generationLog.toString()); System.out.println(generationLog.toString());
solutionHistory[generation] = new DeHistory(generation, avgFitness, m_currentPopulationSorted[0].getFitness());
} }
return solutionHistory;
} }
@Override @Override
... ...
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment