diff --git a/java/src/BPNet.java b/java/src/BPNet.java new file mode 100644 index 0000000000000000000000000000000000000000..81f1c98881a89eb95be9da10ad985e4fe5577bf2 --- /dev/null +++ b/java/src/BPNet.java @@ -0,0 +1,462 @@ +package cz.vsb.mro0010.neuralnetworks; + +import java.util.ArrayList; +import java.util.Arrays; + +public class BPNet extends MultiLayeredNet { + + protected float error; + + + public float getError() { + return error; + } + + + public void setError(float error) { + this.error = error; + } + + + protected float tolerance; + protected String neuronType; + protected float learnCoeff; + + + + public BPNet( float tolerance, int nrOfLayers, int nrOfInputs, ArrayList<Integer> nrOfNeuronsPerLayer, float slope, float learnCoeff ) { + super(nrOfInputs, nrOfLayers, nrOfNeuronsPerLayer); + this.neuronType = "SigmoidalNeuron"; + this.tolerance = tolerance; + this.learnCoeff = learnCoeff; + + for (int i = 0; i < nrOfLayers; i++) { + for (int j = 0; j < nrOfNeuronsPerLayer.get(i); j++) { + this.neuronLayers.get(i).add(new SigmoidalNeuron(slope)); + } + } + for (int i = 0; i < nrOfLayers; i++) { + this.interconnectionsLayers.add(new InterconnectionsBP(this.learnCoeff)); + } + for (Neuron neuronIn : this.inputNeuronLayer) { + for (Neuron neuronFirstLevel : this.neuronLayers.get(0)) { + this.interconnectionsLayers.get(0).addConnection(new Connection(neuronIn, neuronFirstLevel, (float) (Math.random()))); + } + } + for (int i = 1; i < nrOfLayers; i++) { + for (Neuron neuronIn : this.neuronLayers.get(i-1)) { + for (Neuron neuronOut : this.neuronLayers.get(i)) { + this.interconnectionsLayers.get(i).addConnection(new Connection(neuronIn, neuronOut, (float) (Math.random()))); + } + + } + } + + + } + + + public float getTolerance() { + return tolerance; + } + + public void setTolerance(float tolerance) { + this.tolerance = tolerance; + } + + + @Override + public String getNeuronType() { + return this.neuronType; + } + + @Override + public int learn(String trainingSet) { + boolean learned = false; + int iter = 0; + ArrayList<String> trainingElements = new ArrayList<String>(Arrays.asList(trainingSet.split("\n"))); + while(!learned) { + learned = true; + this.error = 0; + for (int i = 0; i < trainingElements.size(); i++) { + learned &= learnStep(trainingElements.get(i)); + } + iter++; +// System.out.println(iter); + } + return iter; +// System.out.println("Learned in " + iter + " whole training set iterations."); + } + + public boolean learnStep(String trainingElement) { + // Run training Element + String[] splitedTrainingElement = trainingElement.split(" "); + StringBuffer inputString = new StringBuffer(); + for (int i = 0; i < this.nrOfInputs; i++) { //Input values + inputString.append(splitedTrainingElement[i]); + inputString.append(" "); + } + ArrayList<Float> expectedValues = new ArrayList<Float>(); + for (int i = this.nrOfInputs; i < splitedTrainingElement.length; i++) { //Expected values + expectedValues.add(Float.parseFloat(splitedTrainingElement[i])); + } + this.run(inputString.substring(0, inputString.length() - 1)); + // Calculate error + float error = 0; + for (int i = 0; i < expectedValues.size(); i++) { + float y = this.neuronLayers.get(this.nrOfLayers-1).get(i).getState(); //output of ith neuron + float o = expectedValues.get(i); + error += (float)( 0.5 * Math.pow((y-o), 2)); + } + if (this.error < error) { + this.error = error; + } + if (error > this.tolerance) { //Error is too high -> modify weights + // Calculate deltas + for (int i = this.nrOfLayers - 1; i >= 0; i -= 1) { + for (Neuron n : this.neuronLayers.get(i)) { + SigmoidalNeuron neuron = (SigmoidalNeuron)n; + if (i == this.nrOfLayers - 1) { //Top layer + float y = neuron.getState(); + float o = expectedValues.get(this.neuronLayers.get(i).indexOf(neuron)); + float delta = y - o; + neuron.setError(delta); + } else { //Other layers + ArrayList<Connection> connectionsToUpperLayerFromNeuron = new ArrayList<Connection>(); + // Find all connections, that have "neuron" as input + for (Connection c : this.interconnectionsLayers.get(i+1).getConnections()) { + if (c.getInputNeuron().equals(neuron)) + connectionsToUpperLayerFromNeuron.add(c); + } + float delta = 0; + for (Connection c : connectionsToUpperLayerFromNeuron) { + float deltaUpper = ((SigmoidalNeuron)c.getOutputNeuron()).getError(); + float lambdaUpper = ((SigmoidalNeuron)c.getOutputNeuron()).getSlope(); + float yUpper = c.getOutputNeuron().getState(); + float w = c.getWeight(); + delta += deltaUpper*lambdaUpper*yUpper*(1-yUpper)*w; + } + neuron.setError(delta); + } + } + } + // Adjust weights + for (Interconnections interconnectionsLayer : this.interconnectionsLayers) { + interconnectionsLayer.adjustWeights(); + } + return false; + } else { + return true; + } + + + } + + public String getOutput() { + StringBuffer output = new StringBuffer(); + ArrayList<Neuron> outputLayer = this.neuronLayers.get(this.nrOfLayers-1); + for (int i = 0; i < outputLayer.size(); i++) { + output.append(String.valueOf(outputLayer.get(i).getState())); + output.append(" "); + } + + return output.toString(); + } + + + + + public void changeSlopeTo(float slope) { + for (ArrayList<Neuron> neuronLayer : this.neuronLayers) { + for (Neuron neuron : neuronLayer) { + ((SigmoidalNeuron)neuron).setSlope(slope); + } + } + } + + + public void changeLearnCoeffTo(float learnCoeff) { + for (Interconnections layer : interconnectionsLayers) { + ((InterconnectionsBP)layer).setLearningRate(learnCoeff); + } + + } + + + public void resetWeights() { + for (Interconnections layer : interconnectionsLayers) { + for (Connection connection : layer.getConnections()) { + connection.setWeight((float)Math.random()); + } + } + + } + + public void addNeuron(int layerIndex, float slope) { + SigmoidalNeuron newNeuron = new SigmoidalNeuron(slope); + neuronLayers.get(layerIndex).add(newNeuron); + if ((layerIndex < nrOfLayers) && (layerIndex >= 0)) { + Interconnections inputConnectionLayer = this.interconnectionsLayers.get(layerIndex); + if (layerIndex == 0) { + ArrayList<InputLayerPseudoNeuron> inputNeurons = this.inputNeuronLayer; + for (Neuron inputNeuron : inputNeurons) { + inputConnectionLayer.addConnection(new Connection(inputNeuron, newNeuron, (float)Math.random())); + } + } else { + ArrayList<Neuron> inputNeurons = this.neuronLayers.get(layerIndex - 1); + for (Neuron inputNeuron : inputNeurons) { + inputConnectionLayer.addConnection(new Connection(inputNeuron, newNeuron, (float)Math.random())); + } + } + + if (layerIndex < nrOfLayers - 1) { + Interconnections outputConnectionLayer = this.interconnectionsLayers.get(layerIndex + 1); + ArrayList<Neuron> outputNeurons = this.neuronLayers.get(layerIndex + 1); + for (Neuron outputNeuron : outputNeurons) { + outputConnectionLayer.addConnection(new Connection(newNeuron, outputNeuron, (float)Math.random())); + } + } + this.nrOfNeuronsPerLayer.set(layerIndex, this.nrOfNeuronsPerLayer.get(layerIndex) + 1 ); + + + } else { + + throw new InvalidLayerNumberException(); + + } + } + + public void removeNeuron(int layerIndex) { + int nrOfNeuronsInThisLayer = this.nrOfNeuronsPerLayer.get(layerIndex); + if ((layerIndex < nrOfLayers) && (layerIndex >= 0)) { + if (nrOfNeuronsInThisLayer == 1) { + + removeNeuronLayer(layerIndex); + + } else { + Neuron removedNeuron = this.neuronLayers.get(layerIndex).get(nrOfNeuronsInThisLayer - 1); + Interconnections inputConnectionLayer = this.interconnectionsLayers.get(layerIndex); + ArrayList<Connection> removedConnections = new ArrayList<Connection>(); + for (Connection connection : inputConnectionLayer.getConnections()) { + if (connection.getOutputNeuron().equals(removedNeuron)) { + removedConnections.add(connection); + } + } + for (Connection connection : removedConnections) { + inputConnectionLayer.getConnections().remove(connection); + } + removedConnections = new ArrayList<Connection>(); + if (layerIndex < nrOfLayers - 1) { + Interconnections outputConnectionLayer = this.interconnectionsLayers.get(layerIndex + 1); + for (Connection connection : outputConnectionLayer.getConnections()) { + if (connection.getInputNeuron().equals(removedNeuron)) { + removedConnections.add(connection); + } + } + for (Connection connection : removedConnections) { + outputConnectionLayer.getConnections().remove(connection); + } + } + + this.neuronLayers.get(layerIndex).remove(removedNeuron); + this.nrOfNeuronsPerLayer.set(layerIndex, this.nrOfNeuronsPerLayer.get(layerIndex) - 1 ); + } + + } else { + throw new InvalidLayerNumberException(); + } + } + + public void addNeuronLayer(int nrOfNeurons, int layerIndex, float slope) { + if ((layerIndex < nrOfLayers + 1) && (layerIndex >= 0) && (nrOfNeurons > 0)) { + + this.nrOfLayers++; + this.nrOfNeuronsPerLayer.add(layerIndex, nrOfNeurons); + // new layer creation + ArrayList<Neuron> newNeuronLayer = new ArrayList<Neuron>(); + for (int i = 0; i < nrOfNeurons; i++) { + newNeuronLayer.add(new SigmoidalNeuron(slope)); + } + // old connections removal + if (layerIndex < nrOfLayers - 1) { // only if inner layer is added + this.interconnectionsLayers.remove(layerIndex); + } + // new layer adding + this.neuronLayers.add(layerIndex, newNeuronLayer); + // new connections creation + // input + Interconnections inputConnLayer = new InterconnectionsBP(learnCoeff); + if (layerIndex == 0) { + ArrayList<InputLayerPseudoNeuron> inputNeurons = this.inputNeuronLayer; + ArrayList<Neuron> outputNeurons = newNeuronLayer; //Layers already shifted + for (Neuron inputNeuron : inputNeurons) { + for (Neuron outputNeuron : outputNeurons) { + inputConnLayer.addConnection(new Connection(inputNeuron, outputNeuron, (float)Math.random())); + } + } + } else { + ArrayList<Neuron> inputNeurons = this.neuronLayers.get(layerIndex - 1); + ArrayList<Neuron> outputNeurons = newNeuronLayer; //Layers already shifted, this is new layer + for (Neuron inputNeuron : inputNeurons) { + for (Neuron outputNeuron : outputNeurons) { + inputConnLayer.addConnection(new Connection(inputNeuron, outputNeuron, (float)Math.random())); + } + } + } + this.interconnectionsLayers.add(layerIndex, inputConnLayer); + // output + Interconnections outputConnLayer = new InterconnectionsBP(learnCoeff); + if (layerIndex < nrOfLayers - 1) { + ArrayList<Neuron> inputNeurons = newNeuronLayer; + ArrayList<Neuron> outputNeurons = this.neuronLayers.get(layerIndex + 1); //Layers already shifted + for (Neuron inputNeuron : inputNeurons) { + for (Neuron outputNeuron : outputNeurons) { + outputConnLayer.addConnection(new Connection(inputNeuron, outputNeuron, (float)Math.random())); + } + } + this.interconnectionsLayers.add(layerIndex + 1, outputConnLayer); + } + + + } else { + throw new InvalidLayerNumberException(); + } + + } + + + + public void removeNeuronLayer(int layerIndex) { + if ((layerIndex < nrOfLayers ) && (layerIndex >= 0) && (nrOfLayers > 1)) { + // delete output connections + if (layerIndex < nrOfLayers - 1) { + this.interconnectionsLayers.remove(layerIndex + 1); + } + // delete input connections + this.interconnectionsLayers.remove(layerIndex); + // delete neurons on layer + this.neuronLayers.remove(layerIndex); + this.nrOfNeuronsPerLayer.remove(layerIndex); + this.nrOfLayers--; + // create new connections + if (layerIndex < nrOfLayers + 1) { + Interconnections connLayer = new InterconnectionsBP(learnCoeff); + if (layerIndex == 0) { + ArrayList<InputLayerPseudoNeuron> inputNeurons = this.inputNeuronLayer; + ArrayList<Neuron> outputNeurons = this.neuronLayers.get(0); + for (Neuron inputNeuron : inputNeurons) { + for (Neuron outputNeuron : outputNeurons) { + connLayer.addConnection(new Connection(inputNeuron, outputNeuron, (float)Math.random())); + } + } + } else { + ArrayList<Neuron> inputNeurons = this.neuronLayers.get(layerIndex - 1); + ArrayList<Neuron> outputNeurons = this.neuronLayers.get(layerIndex); + for (Neuron inputNeuron : inputNeurons) { + for (Neuron outputNeuron : outputNeurons) { + connLayer.addConnection(new Connection(inputNeuron, outputNeuron, (float)Math.random())); + } + } + } + this.interconnectionsLayers.add(layerIndex, connLayer); + + } + + } else { + throw new InvalidLayerNumberException(); + } + } + + @Override + public String toString() { + return getNeuronMap(); + } + + public String getNeuronMap() { + StringBuffer map = new StringBuffer(); + for (int i = 0; i < nrOfLayers; i++) { + map.append(String.valueOf(nrOfNeuronsPerLayer.get(i))); + map.append(" "); + } + map.deleteCharAt(map.length() - 1); + return map.toString(); + } + + public static void main(String[] args) { + ArrayList<Integer> nrOfNeuronsPerLayer = new ArrayList<Integer>(); + nrOfNeuronsPerLayer.add(10); + nrOfNeuronsPerLayer.add(7); + nrOfNeuronsPerLayer.add(2); + BPNet net = new BPNet( (float)0.01, 3, 5, nrOfNeuronsPerLayer, (float)1.8, (float)0.7); // bigger slope = better resolution + + String trainingSet = "0.4 0.5 1 0.5 1 0 1\n0 0 0 0 0 1 1\n0.1 0.2 0.3 0.4 0.5 0 0\n1 0 1 0 1 1 0\n0.2 0.4 0 0 0.9 0 1"; + net.learn(trainingSet); + net.run("0.4 0.5 1 0.5 1"); //expected 0 1 + System.out.println(net.getOutput()); + net.run("0 0 0 0 0"); // 1 1 + System.out.println(net.getOutput()); + net.run("0.1 0.2 0.3 0.4 0.5"); // 0 0 + System.out.println(net.getOutput()); + net.run("1 0 1 0 1"); // 1 0 + System.out.println(net.getOutput()); + net.run("0.2 0.4 0 0 0.9"); // 0 1 + System.out.println(net.getOutput()); + + System.out.println("Not trained elements:"); + net.run("0.9 0.1 0.9 0.1 0.9"); // expected 1 0 + System.out.println(net.getOutput()); + net.run("0.01 0.01 0.01 0.01 0.01"); // expected 1 1 + System.out.println(net.getOutput()); + net.run("0.15 0.15 0.35 0.35 0.5"); // 0 0 + System.out.println(net.getOutput()); + + System.out.println(net.getNeuronMap()); + net.addNeuron(0, 1.8f); + System.out.println(net.getNeuronMap()); + net.addNeuron(1, 1.8f); + System.out.println(net.getNeuronMap()); + net.addNeuron(2, 1.8f); + System.out.println(net.getNeuronMap()); + net.removeNeuron(0); + System.out.println(net.getNeuronMap()); + net.removeNeuron(1); + System.out.println(net.getNeuronMap()); + net.removeNeuron(2); + System.out.println(net.getNeuronMap()); + + net.addNeuronLayer(5, 0, 1.8f); + System.out.println(net.getNeuronMap()); + net.addNeuronLayer(5, 2, 1.8f); + System.out.println(net.getNeuronMap()); + net.addNeuronLayer(5, 5, 1.8f); + System.out.println(net.getNeuronMap()); + + net.removeNeuronLayer(5); + System.out.println(net.getNeuronMap()); + net.removeNeuronLayer(2); + System.out.println(net.getNeuronMap()); + net.removeNeuronLayer(0); + System.out.println(net.getNeuronMap()); + + net.learn(trainingSet); + net.run("0.4 0.5 1 0.5 1"); //expected 0 1 + System.out.println(net.getOutput()); + net.run("0 0 0 0 0"); // 1 1 + System.out.println(net.getOutput()); + net.run("0.1 0.2 0.3 0.4 0.5"); // 0 0 + System.out.println(net.getOutput()); + net.run("1 0 1 0 1"); // 1 0 + System.out.println(net.getOutput()); + net.run("0.2 0.4 0 0 0.9"); // 0 1 + System.out.println(net.getOutput()); + + System.out.println("Not trained elements:"); + net.run("0.9 0.1 0.9 0.1 0.9"); // expected 1 0 + System.out.println(net.getOutput()); + net.run("0.01 0.01 0.01 0.01 0.01"); // expected 1 1 + System.out.println(net.getOutput()); + net.run("0.15 0.15 0.35 0.35 0.5"); // 0 0 + System.out.println(net.getOutput()); + } + +} diff --git a/java/src/BinaryNeuron.java b/java/src/BinaryNeuron.java new file mode 100644 index 0000000000000000000000000000000000000000..25a0b26fdfdf4b36904d110960c55d4af6e284a0 --- /dev/null +++ b/java/src/BinaryNeuron.java @@ -0,0 +1,15 @@ +package cz.vsb.mro0010.neuralnetworks; + +public class BinaryNeuron extends Neuron { + + @Override + public void transfer() { + if (this.getPotential() > this.getThreshold()) { + this.setState(1); + } else { + this.setState(0); + } + + } + +} diff --git a/java/src/CarDriverClient.java b/java/src/CarDriverClient.java new file mode 100644 index 0000000000000000000000000000000000000000..bd5a0b9fa5b5895a8a6055e1c7b96be59c2f3c92 --- /dev/null +++ b/java/src/CarDriverClient.java @@ -0,0 +1,528 @@ +package cz.vsb.mro0010.neuralnetworks; + +import java.io.BufferedReader; +import java.io.BufferedWriter; +import java.io.File; +import java.io.FileReader; +import java.io.InputStreamReader; +import java.io.OutputStreamWriter; +import java.io.IOException; +import java.io.StreamTokenizer; +import java.net.ConnectException; +import java.net.Socket; +import java.util.ArrayList; +import java.util.List; +import java.util.Random; + +import javax.swing.JOptionPane; + + +/** + * Jednoduchy ukazkovy klient. + * Pripoji se k zavodnimu serveru a ridi auto. + * + */ +public class CarDriverClient { + + private Socket socket; // spojeni + private BufferedReader in; // cteni se serveru + private BufferedWriter out; // zapis na server + private BPNet neuralNetwork; + + /** + * Pripoji se k zavodu. + * + * @param host zavodni server + * @param port port serveru + * @param raceName nazev zavodu, do nehoz se chce klient pripojit + * @param driverName jmeno ridice + * @throws java.lang.IOException problem s pripojenim + */ + public CarDriverClient(String host, int port, String raceName, String driverName, String carType, BPNet neuralNetwork) throws IOException { + // add neural net + this.neuralNetwork = neuralNetwork; + + // connect to server + socket = new Socket(host, port); + out = new BufferedWriter(new OutputStreamWriter(socket.getOutputStream(), "UTF-8")); + in = new BufferedReader(new InputStreamReader(socket.getInputStream(), "UTF-8")); + + // connect to race + out.write("driver\n"); // protocol specification + out.write("race:" + raceName + "\n"); // race name + out.write("driver:" + driverName + "\n"); // driver name + out.write("color:0000FF\n"); // car color + if(carType != null){ + out.write("car:" + carType + "\n"); // car type + } + out.write("\n"); + out.flush(); + + // precteni a kontrola dopovedi serveru + String line = in.readLine(); + if (!line.equals("ok")) { + // pokud se pripojeni nepodari, je oznamena chyba a vyvolana vyjimka + System.err.println("Chyba: " + line); + throw new ConnectException(line); + } + in.readLine(); // precteni prazdneho radku + } + + public static List<String> listRaces(String host, int port) throws IOException { + // pripojeni k serveru + Socket socket = new Socket(host, port); + BufferedWriter out = new BufferedWriter(new OutputStreamWriter(socket.getOutputStream(), "UTF-8")); + BufferedReader in = new BufferedReader(new InputStreamReader(socket.getInputStream(), "UTF-8")); + + // pripojeni k zavodu + out.write("racelist\n"); // specifikace protokolu + out.write("\n"); + out.flush(); + + // precteni a kontrola dopovedi serveru + String line = in.readLine(); + if (!line.equals("ok")) { + // pokud se pripojeni nepodari, je oznamena chyba a vyvolana vyjimka + System.err.println("Chyba: " + line); + throw new ConnectException(line); + } + line = in.readLine(); // precteni prazdneho radku + List<String> racelist = new ArrayList<String>(); + line = in.readLine(); + System.out.println("Races:"); + while(line != null && !"".equals(line)){ + racelist.add(line); + System.out.println(line); + line = in.readLine(); + } + return racelist; + } + public static List<String> listCars(String host, int port, String raceName) throws IOException { + // pripojeni k serveru + Socket socket = new Socket(host, port); + BufferedWriter out = new BufferedWriter(new OutputStreamWriter(socket.getOutputStream(), "UTF-8")); + BufferedReader in = new BufferedReader(new InputStreamReader(socket.getInputStream(), "UTF-8")); + + // pripojeni k zavodu + out.write("carlist\n"); // specifikace protokolu + out.write("race:" + raceName + "\n"); + out.write("\n"); + out.flush(); + + // precteni a kontrola dopovedi serveru + String line = in.readLine(); + if (!line.equals("ok")) { + // pokud se pripojeni nepodari, je oznamena chyba a vyvolana vyjimka + System.err.println("Chyba: " + line); + throw new ConnectException(line); + } + line = in.readLine(); // precteni prazdneho radku + List<String> carList = new ArrayList<String>(); + line = in.readLine(); + System.out.println("cars:"); + while(line != null && !"".equals(line)){ + carList.add(line); + System.out.println(line); + line = in.readLine(); + } + return carList; + } + + /** + * Beh zavodu. Cte data ze serveru. Spousti rizeni auta. + * Ukonci se po ukonceni zavodu. + * + * @throws java.io.IOException problem ve spojeni k serveru + */ + public void run() throws IOException { + while (true) { // smycka do konce zavodu + String line = in.readLine(); +// System.out.println(line); + if (line.equals("round")) { // dalsi kolo v zavode + round(); + } else if (line.equals("finish")) { // konec zavodu konci smucku + break; + } else { + System.err.println("Chyba se serveru: " + line); + } + } + } + + /** + * Resi jedno posunuti auta. Precte pozici auta od servru, + * vypocte nastaveni rizeni, ktere na server. + * + * @throws java.io.IOException problem ve spojeni k serveru + */ + public void round() throws IOException { + float angle = 0; // uhel k care <0,1> + float speed = 0; // rychlost auta <0,1> + float distance0 = 0; // vzdalenost od cary <0,1> + float distance4 = 0; // vzdalenost od cary za 4m<0,1> + float distance8 = 0; // vzdalenost od cary za 8m<0,1> + float distance16 = 0; // vzdalenost od cary za 16m<0,1> + float distance32 = 0; // vzdalenost od cary za 32m<0,1> + float friction = 0; + float skid = 0; + float checkpoint = 0; + float sensorFrontLeft = 0; + float sensorFrontMiddleLeft = 0; + float sensorFrontMiddleRight = 0; + float sensorFrontRight = 0; + float sensorFrontRightCorner1 = 0; + float sensorFrontRightCorner2 = 0; + float sensorRight1 = 0; + float sensorRight2 = 0; + float sensorRearRightCorner2 = 0; + float sensorRearRightCorner1 = 0; + float sensorRearRight = 0; + float sensorRearLeft = 0; + float sensorRearLeftCorner1 = 0; + float sensorRearLeftCorner2 = 0; + float sensorLeft1 = 0; + float sensorLeft2 = 0; + float sensorFrontLeftCorner1 = 0; + float sensorFrontLeftCorner2 = 0; + + // cteni dat ze serveru + String line = in.readLine(); +// System.out.println(line); + while (line.length() > 0) { + String[] data = line.split(":", 2); + String key = data[0]; + String value = data[1]; + if (key.equals("angle")) { + angle = Float.parseFloat(value); + } else if (key.equals("speed")) { + speed = Float.parseFloat(value); + } else if (key.equals("distance0")) { + distance0 = Float.parseFloat(value); + } else if (key.equals("distance4")) { + distance4 = Float.parseFloat(value); + } else if (key.equals("distance8")) { + distance8 = Float.parseFloat(value); + } else if (key.equals("distance16")) { + distance16 = Float.parseFloat(value); + } else if (key.equals("distance32")) { + distance32 = Float.parseFloat(value); + } else if (key.equals("friction")) { + friction = Float.parseFloat(value); + } else if (key.equals("skid")) { + skid = Float.parseFloat(value); + } else if (key.equals("checkpoint")) { + checkpoint = Float.parseFloat(value); + } else if (key.equals("sensorFrontLeft")) { + sensorFrontLeft = Float.parseFloat(value); + } else if (key.equals("sensorFrontMiddleLeft")) { + sensorFrontMiddleLeft = Float.parseFloat(value); + } else if (key.equals("sensorFrontMiddleRight")) { + sensorFrontMiddleRight = Float.parseFloat(value); + } else if (key.equals("sensorFrontRight")) { + sensorFrontRight = Float.parseFloat(value); + } else if (key.equals("sensorFrontRightCorner1")) { + sensorFrontRightCorner1 = Float.parseFloat(value); + } else if (key.equals("sensorFrontRightCorner2")) { + sensorFrontRightCorner2 = Float.parseFloat(value); + } else if (key.equals("sensorRight1")) { + sensorRight1 = Float.parseFloat(value); + } else if (key.equals("sensorRight2")) { + sensorRight2 = Float.parseFloat(value); + } else if (key.equals("sensorRearRightCorner2")) { + sensorRearRightCorner2 = Float.parseFloat(value); + } else if (key.equals("sensorRearRightCorner1")) { + sensorRearRightCorner1 = Float.parseFloat(value); + } else if (key.equals("sensorRearRight")) { + sensorRearRight = Float.parseFloat(value); + } else if (key.equals("sensorRearLeft")) { + sensorRearLeft = Float.parseFloat(value); + } else if (key.equals("sensorRearLeftCorner1")) { + sensorRearLeftCorner1 = Float.parseFloat(value); + } else if (key.equals("sensorRearLeftCorner2")) { + sensorRearLeftCorner2 = Float.parseFloat(value); + } else if (key.equals("sensorLeft1")) { + sensorLeft1 = Float.parseFloat(value); + } else if (key.equals("sensorLeft2")) { + sensorLeft2 = Float.parseFloat(value); + } else if (key.equals("sensorFrontLeftCorner1")) { + sensorFrontLeftCorner1 = Float.parseFloat(value); + } else if (key.equals("sensorFrontLeftCorner2")) { + sensorFrontLeftCorner2 = Float.parseFloat(value); + } else { + System.err.println("Chyba se serveru: " + line); + } + line = in.readLine(); +// System.out.println(line); + } + + // vypocet nastaveni rizeni, ktery je mozno zmenit za jiny algoritmus + float acc; // zrychleni auta <0,1> + float wheel; // otoceni volantem (kolama) <0,1> + + StringBuffer neuralNetInput = new StringBuffer(); + + +// float angle = 0; // uhel k care <0,1> +// float speed = 0; // rychlost auta <0,1> +// float distance0 = 0; // vzdalenost od cary <0,1> +// float distance4 = 0; // vzdalenost od cary za 4m<0,1> +// float distance8 = 0; // vzdalenost od cary za 8m<0,1> +// float distance16 = 0; // vzdalenost od cary za 16m<0,1> +// float distance32 = 0; // vzdalenost od cary za 32m<0,1> +// float friction = 0; +// float skid = 0; +// float checkpoint = 0; +// float sensorFrontLeft = 0; +// float sensorFrontMiddleLeft = 0; +// float sensorFrontMiddleRight = 0; +// float sensorFrontRight = 0; +// float sensorFrontRightCorner1 = 0; +// float sensorFrontRightCorner2 = 0; +// float sensorRight1 = 0; +// float sensorRight2 = 0; +// float sensorRearRightCorner2 = 0; +// float sensorRearRightCorner1 = 0; +// float sensorRearRight = 0; +// float sensorRearLeft = 0; +// float sensorRearLeftCorner1 = 0; +// float sensorRearLeftCorner2 = 0; +// float sensorLeft1 = 0; +// float sensorLeft2 = 0; +// float sensorFrontLeftCorner1 = 0; +// float sensorFrontLeftCorner2 = 0; + + + +// neuralNetInput.append(String.valueOf(angle)); +// neuralNetInput.append(" "); +// neuralNetInput.append(String.valueOf(speed)); +// neuralNetInput.append(" "); + neuralNetInput.append("0.5 0.5 "); + neuralNetInput.append(String.valueOf(distance0)); + neuralNetInput.append(" "); + neuralNetInput.append(String.valueOf(distance4)); + neuralNetInput.append(" "); + neuralNetInput.append(String.valueOf(distance8)); + neuralNetInput.append(" "); + neuralNetInput.append(String.valueOf(distance16)); + neuralNetInput.append(" "); + neuralNetInput.append(String.valueOf(distance32)); + neuralNetInput.append(" "); + neuralNetInput.append("1 1 0.5 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0"); +// neuralNetInput.append(String.valueOf(friction)); +// neuralNetInput.append(" "); +// neuralNetInput.append(String.valueOf(skid)); +// neuralNetInput.append(" "); +// neuralNetInput.append(String.valueOf(checkpoint)); +// neuralNetInput.append(" "); +// neuralNetInput.append(String.valueOf(sensorFrontLeft)); +// neuralNetInput.append(" "); +// neuralNetInput.append(String.valueOf(sensorFrontMiddleLeft)); +// neuralNetInput.append(" "); +// neuralNetInput.append(String.valueOf(sensorFrontMiddleRight)); +// neuralNetInput.append(" "); +// neuralNetInput.append(String.valueOf(sensorFrontRight)); +// neuralNetInput.append(" "); +// neuralNetInput.append(String.valueOf(sensorFrontRightCorner1)); +// neuralNetInput.append(" "); +// neuralNetInput.append(String.valueOf(sensorFrontRightCorner2)); +// neuralNetInput.append(" "); +// neuralNetInput.append(String.valueOf(sensorRight1)); +// neuralNetInput.append(" "); +// neuralNetInput.append(String.valueOf(sensorRight2)); +// neuralNetInput.append(" "); +// neuralNetInput.append(String.valueOf(sensorRearRightCorner2)); +// neuralNetInput.append(" "); +// neuralNetInput.append(String.valueOf(sensorRearRightCorner1)); +// neuralNetInput.append(" "); +// neuralNetInput.append(String.valueOf(sensorRearRight)); +// neuralNetInput.append(" "); +// neuralNetInput.append(String.valueOf(sensorRearLeft)); +// neuralNetInput.append(" "); +// neuralNetInput.append(String.valueOf(sensorRearLeftCorner1)); +// neuralNetInput.append(" "); +// neuralNetInput.append(String.valueOf(sensorRearLeftCorner2)); +// neuralNetInput.append(" "); +// neuralNetInput.append(String.valueOf(sensorLeft1)); +// neuralNetInput.append(" "); +// neuralNetInput.append(String.valueOf(sensorLeft2)); +// neuralNetInput.append(" "); +// neuralNetInput.append(String.valueOf(sensorFrontLeftCorner1)); +// neuralNetInput.append(" "); +// neuralNetInput.append(String.valueOf(sensorFrontLeftCorner2)); + + neuralNetwork.run(neuralNetInput.toString()); + + String output = neuralNetwork.getOutput(); + String[] outputArray = output.split(" "); + + wheel = Float.parseFloat(outputArray[0]); + acc = Float.parseFloat(outputArray[1]); + + + + + // odpoved serveru + out.write("ok\n"); + out.write("acc:" + acc + "\n"); + out.write("wheel:" + wheel + "\n"); + out.write("\n"); + out.flush(); + } + + /** + * Funkce, ktera vytvari a spousti klienta. + * + * @param args pole argumentu: server port nazev_zavodu jmeno_ridice + * @throws java.io.IOException problem ve spojeni k serveru, zavodu + */ + public static void main(String[] args) throws IOException { +// String host = "java.cs.vsb.cz"; +// int port = 9460; + String host = "localhost"; +// int port = 9461; // test + int port = 9460; // normal + String raceName = "Zavod"; + String driverName = "basic_client"; + String carType = null; + if (args.length < 4) { + // kontrola argumentu programu + System.err.println("argumenty: server port nazev_zavodu jmeno_ridice [typ_auta]"); + List<String> raceList = CarDriverClient.listRaces(host, port); + raceName = raceList.get(new Random().nextInt(raceList.size())); + List<String> carList = CarDriverClient.listCars(host, port, raceName); + carType = carList.get(0); + driverName += "_" + carType; +// host = JOptionPane.showInputDialog("Host:", host); +// port = Integer.parseInt(JOptionPane.showInputDialog("Port:", Integer.toString(port))); +// raceName = JOptionPane.showInputDialog("Race name:", raceName); +// driverName = JOptionPane.showInputDialog("Driver name:", driverName); + } else { + // nacteni parametu + host = args[0]; + port = Integer.parseInt(args[1]); + raceName = args[2]; + driverName = args[3]; + if(args.length > 4){ + carType = args[4]; + } + } + // vytvoreni neuronove site + ArrayList<Integer> nrOfNeuronsPerLayer = new ArrayList<Integer>(); + //nrOfNeuronsPerLayer.add(20); +// nrOfNeuronsPerLayer.add(15); +// nrOfNeuronsPerLayer.add(10); +// nrOfNeuronsPerLayer.add(2); + nrOfNeuronsPerLayer.add(3); + nrOfNeuronsPerLayer.add(3); + nrOfNeuronsPerLayer.add(2); + BPNet neuralNet = new BPNet(0.1f, 3, 28, nrOfNeuronsPerLayer, 1.4f, 0.4f); + + FileReader fr = new FileReader(new File("C:\\Users\\Martin\\Desktop\\NSProjekty\\testTrainingSet5.txt")); + StreamTokenizer tokenizer = new StreamTokenizer(fr); + /*for (int i = 0; i < 6; i++ ) + tokenizer.nextToken(); + */ + while(tokenizer.nextToken() != StreamTokenizer.TT_NUMBER) {} + + while(tokenizer.nextToken() != StreamTokenizer.TT_NUMBER) {} + int nrOfLayers = (int)tokenizer.nval; + while(tokenizer.nextToken() != StreamTokenizer.TT_NUMBER) {} + int nrOfInputs = (int)tokenizer.nval; + tokenizer.nextToken(); + tokenizer.nextToken(); + tokenizer.nextToken(); + tokenizer.nextToken(); + ArrayList<float[]> inputRanges = new ArrayList<float[]>(); + ArrayList<String> inputNames = new ArrayList<String>(); + for (int i = 0; i < nrOfInputs; i++) { + String inputName = tokenizer.sval; + inputNames.add(inputName); + tokenizer.nextToken(); + float[] dims = new float[2]; + dims[0] = (float)tokenizer.nval; + tokenizer.nextToken(); + dims[1] = (float)tokenizer.nval; + inputRanges.add(dims); + tokenizer.nextToken(); + } + /*for (int i = 0; i < 3; i++ ) + tokenizer.nextToken();*/ + while(tokenizer.nextToken() != StreamTokenizer.TT_NUMBER) {} + nrOfNeuronsPerLayer = new ArrayList<Integer>(); + int nrOfOutputs = 0; + for (int i = 0; i < nrOfLayers; i++) { + nrOfNeuronsPerLayer.add((int)tokenizer.nval); + if (i == nrOfLayers - 1) { + nrOfOutputs = (int)tokenizer.nval; + } + tokenizer.nextToken(); + } + for (int i = 0; i < 3; i++ ) + tokenizer.nextToken(); + ArrayList<String> outputNames = new ArrayList<String>(); + for (int i = 0; i < nrOfOutputs; i++) { + outputNames.add(tokenizer.sval); + tokenizer.nextToken(); + } + while(tokenizer.nextToken() != StreamTokenizer.TT_NUMBER) {} + float learnCoeff = (float)tokenizer.nval; + while(tokenizer.nextToken() != StreamTokenizer.TT_NUMBER) {} + float inertiaCoeff = (float)tokenizer.nval; + /*for (int i = 0; i < 7; i++ ) + tokenizer.nextToken();*/ + while(tokenizer.nextToken() != StreamTokenizer.TT_NUMBER) {} + int nrOfTrainingElements = (int)tokenizer.nval; + /*for (int i = 0; i < 4; i++ ) + tokenizer.nextToken();*/ + while(tokenizer.nextToken() != StreamTokenizer.TT_NUMBER) {} + StringBuffer sb = new StringBuffer(); + for (int i = 0; i < nrOfTrainingElements; i++) { + for (int j = 0; j < nrOfInputs; j++) { + sb.append(String.valueOf(tokenizer.nval/(inputRanges.get(j)[1]-inputRanges.get(j)[0]) - inputRanges.get(j)[0]/(inputRanges.get(j)[1]-inputRanges.get(j)[0]))); + sb.append(" "); + tokenizer.nextToken(); + } + for (int j = 0; j < nrOfOutputs; j++) { + sb.append(String.valueOf(tokenizer.nval)); + sb.append(" "); + tokenizer.nextToken(); + } + sb.deleteCharAt(sb.length() - 1); + sb.append("\n"); + } + String trainingData = sb.toString(); + sb = new StringBuffer(); + /*for (int i = 0; i < 5; i++ ) + tokenizer.nextToken();*/ + while(tokenizer.nextToken() != StreamTokenizer.TT_NUMBER) {} + int nrOfTestElements = (int)tokenizer.nval; + String testData; + /*tokenizer.nextToken();*/ + if (nrOfTestElements > 0) { + while(tokenizer.nextToken() != StreamTokenizer.TT_NUMBER) {} + for (int i = 0; i < nrOfTestElements; i++) { + for (int j = 0; j < nrOfInputs; j++) { + sb.append(String.valueOf(String.valueOf(tokenizer.nval/(inputRanges.get(j)[1]-inputRanges.get(j)[0]) - inputRanges.get(j)[0]/(inputRanges.get(j)[1]-inputRanges.get(j)[0])))); + sb.append(" "); + tokenizer.nextToken(); + } + sb.deleteCharAt(sb.lastIndexOf(" ")); + sb.append("\n"); + } + testData = sb.substring(0,sb.lastIndexOf("\n")); + } else { + testData = ""; + } + fr.close(); + + String trainingSet = trainingData; + System.out.println("Learning started."); + ; + System.out.println("Net learned in " + neuralNet.learn(trainingSet) + " iterations"); + + // vytvoreni klienta + CarDriverClient driver = new CarDriverClient(host, port, raceName, driverName, carType, neuralNet); + // spusteni + driver.run(); + } +} diff --git a/java/src/Connection.java b/java/src/Connection.java new file mode 100644 index 0000000000000000000000000000000000000000..d8c451c787cc81045320f8ddc7fa5ac986fd4314 --- /dev/null +++ b/java/src/Connection.java @@ -0,0 +1,52 @@ +package cz.vsb.mro0010.neuralnetworks; + +public class Connection { + + private Neuron inputNeuron; + private Neuron outputNeuron; + private float weight; + + public Connection(Neuron inputNeuron, Neuron outputNeuron, float weight) { + this.setInputNeuron(inputNeuron); + this.setOutputNeuron(outputNeuron); + this.setWeight(weight); + } + + protected Neuron getInputNeuron() { + return inputNeuron; + } + + protected void setInputNeuron(Neuron inputNeuron) { + this.inputNeuron = inputNeuron; + } + + protected Neuron getOutputNeuron() { + return outputNeuron; + } + + protected void setOutputNeuron(Neuron outputNeuron) { + this.outputNeuron = outputNeuron; + } + + public float getWeight() { + return weight; + } + + public void setWeight(float weight) { + this.weight = weight; + } + + public void adjustWeight(float value) { + this.weight += value; + } + + public void passSignal() { + outputNeuron.adjustPotential(inputNeuron.getState()*this.getWeight()); + } + + @Override + public String toString() { + return "Weight: " + this.getWeight(); + } + +} diff --git a/java/src/InputLayerPseudoNeuron.java b/java/src/InputLayerPseudoNeuron.java new file mode 100644 index 0000000000000000000000000000000000000000..d8489247c360c56a17e826eca0536b63ad7f209e --- /dev/null +++ b/java/src/InputLayerPseudoNeuron.java @@ -0,0 +1,14 @@ +package cz.vsb.mro0010.neuralnetworks; + +public class InputLayerPseudoNeuron extends Neuron { + + public InputLayerPseudoNeuron() { + super(); + } + + @Override + public void transfer() { + this.setState(this.getPotential()); + } + +} diff --git a/java/src/Interconnections.java b/java/src/Interconnections.java new file mode 100644 index 0000000000000000000000000000000000000000..8732fb18c3413be7cc31217c22570fb950b25fe9 --- /dev/null +++ b/java/src/Interconnections.java @@ -0,0 +1,54 @@ +package cz.vsb.mro0010.neuralnetworks; + +import java.util.ArrayList; + +public abstract class Interconnections { + + protected ArrayList<Connection> connections; + + public ArrayList<Connection> getConnections() { + return connections; + } + + public Interconnections() { + this.connections = new ArrayList<Connection>(); + } + + public void addConnection(Connection connection) { + this.connections.add(connection); + } + +// public void passSignal() { +// for (Connection c : this.connections) { +// +// Neuron n = c.getOutputNeuron(); +// n.initialize(); +// for (Connection cn : this.connections) { +// if (cn.getOutputNeuron().equals(n)) { +// cn.passSignal(); +// } +// } +// n.transfer(); +// } +// } + + public void passSignal() { // Faster version + ArrayList<Neuron> processedNeurons = new ArrayList<Neuron>(); + for (Connection c : this.connections) { + + Neuron n = c.getOutputNeuron(); + if (!processedNeurons.contains(n)) { + processedNeurons.add(n); + n.initialize(); + for (Connection cn : this.connections) { + if (cn.getOutputNeuron().equals(n)) { + cn.passSignal(); + } + } + n.transfer(); + } + } + } + + public abstract void adjustWeights(); +} diff --git a/java/src/InterconnectionsBP.java b/java/src/InterconnectionsBP.java new file mode 100644 index 0000000000000000000000000000000000000000..fc29d41da6161dd3c75f8c464014704c9a311c24 --- /dev/null +++ b/java/src/InterconnectionsBP.java @@ -0,0 +1,25 @@ +package cz.vsb.mro0010.neuralnetworks; + +public class InterconnectionsBP extends InterconnectionsMultiLayer { + + public InterconnectionsBP(float learningRate) { + super(learningRate); + } + + public void setLearningRate(float learningRate) { + this.learningRate = learningRate; + } + + @Override + public void adjustWeights() { // backPropagation - set new weights !after! all deltas are calculated + for (Connection connection : this.connections) { + float delta = ((SigmoidalNeuron)connection.getOutputNeuron()).getError(); + float lambda = ((SigmoidalNeuron)connection.getOutputNeuron()).getSlope(); + float y = connection.getOutputNeuron().getState(); + float x = connection.getInputNeuron().getState(); + float errorDerivative = delta*lambda*y*(1-y)*x; + connection.adjustWeight(-learningRate*errorDerivative); + } + } + +} diff --git a/java/src/InterconnectionsMultiLayer.java b/java/src/InterconnectionsMultiLayer.java new file mode 100644 index 0000000000000000000000000000000000000000..5785bdfda73453a4d2e56fdc5f6a9fc265190249 --- /dev/null +++ b/java/src/InterconnectionsMultiLayer.java @@ -0,0 +1,11 @@ +package cz.vsb.mro0010.neuralnetworks; + +public abstract class InterconnectionsMultiLayer extends Interconnections { + + protected float learningRate; //eta + + public InterconnectionsMultiLayer(float learningRate) { + this.learningRate = learningRate; + } + +} diff --git a/java/src/InvalidInputNumberException.java b/java/src/InvalidInputNumberException.java new file mode 100644 index 0000000000000000000000000000000000000000..3582a092522588d48d59d8994ceccef69548bab7 --- /dev/null +++ b/java/src/InvalidInputNumberException.java @@ -0,0 +1,14 @@ +package cz.vsb.mro0010.neuralnetworks; + +public class InvalidInputNumberException extends RuntimeException { + + /** + * + */ + private static final long serialVersionUID = -6282750644609100469L; + + public InvalidInputNumberException() { + super("Number of input values does not correspond with network input size"); + } + +} diff --git a/java/src/InvalidLayerNumberException.java b/java/src/InvalidLayerNumberException.java new file mode 100644 index 0000000000000000000000000000000000000000..5e77f8e09054c7de2f508cd3a026b1653b7757ed --- /dev/null +++ b/java/src/InvalidLayerNumberException.java @@ -0,0 +1,14 @@ +package cz.vsb.mro0010.neuralnetworks; + +public class InvalidLayerNumberException extends RuntimeException { + + + /** + * + */ + private static final long serialVersionUID = 1366940285989358521L; + + public InvalidLayerNumberException() { + super("Number of layer does not correspond with network"); + } +} diff --git a/java/src/InvalidNeuronTypeException.java b/java/src/InvalidNeuronTypeException.java new file mode 100644 index 0000000000000000000000000000000000000000..a983ab477d08c053e1aa455d2c8e48d8c23eb5f3 --- /dev/null +++ b/java/src/InvalidNeuronTypeException.java @@ -0,0 +1,14 @@ +package cz.vsb.mro0010.neuralnetworks; + +public class InvalidNeuronTypeException extends RuntimeException { + + + /** + * + */ + private static final long serialVersionUID = 5354372081840990196L; + + public InvalidNeuronTypeException() { + super("Wrong Neuron type"); + } +} diff --git a/java/src/MultiLayeredNet.java b/java/src/MultiLayeredNet.java new file mode 100644 index 0000000000000000000000000000000000000000..164170bd774ab09b86f911acd5cf445f29e09c7a --- /dev/null +++ b/java/src/MultiLayeredNet.java @@ -0,0 +1,58 @@ +package cz.vsb.mro0010.neuralnetworks; + +import java.util.ArrayList; + +public abstract class MultiLayeredNet extends NeuralNet { + + protected ArrayList<ArrayList<Neuron>> neuronLayers; + protected ArrayList<InputLayerPseudoNeuron> inputNeuronLayer; + protected int nrOfInputs; + protected int nrOfLayers; + protected ArrayList<Integer> nrOfNeuronsPerLayer; + + public MultiLayeredNet(int nrOfInputs, int nrOfLayers, ArrayList<Integer> nrOfNeuronsPerLayer) { + super(); + this.nrOfInputs = nrOfInputs; + this.nrOfLayers = nrOfLayers; + this.nrOfNeuronsPerLayer = nrOfNeuronsPerLayer; + neuronLayers = new ArrayList<ArrayList<Neuron>>(nrOfLayers); + inputNeuronLayer = new ArrayList<InputLayerPseudoNeuron>(nrOfInputs); + for (int i = 0; i < nrOfLayers; i++) { + neuronLayers.add(new ArrayList<Neuron>(nrOfNeuronsPerLayer.get(i))); + } + for (int i = 0; i < nrOfInputs; i++) { + inputNeuronLayer.add(new InputLayerPseudoNeuron()); + } + } + + public MultiLayeredNet() { + this(0,0,null); + } + + public int getNrOfInputs() { + return nrOfInputs; + } + + public int getNrOfLayers() { + return nrOfLayers; + } + + @Override + public void run(String input) { + String[] inputValues = input.split(" "); + if (inputValues.length != nrOfInputs) + throw new InvalidInputNumberException(); + for (int i = 0; i < nrOfInputs; i++) { + InputLayerPseudoNeuron in = this.inputNeuronLayer.get(i); + in.initialize(); + in.adjustPotential(Float.parseFloat(inputValues[i])); + in.transfer(); + } + + for (int i = 0; i < nrOfLayers; i++) { + Interconnections interconnectionsLayer = interconnectionsLayers.get(i); + interconnectionsLayer.passSignal(); + } + } + +} diff --git a/java/src/NeuralNet.java b/java/src/NeuralNet.java new file mode 100644 index 0000000000000000000000000000000000000000..9a666d7f2931f5273f084bb7eee011b26c94853c --- /dev/null +++ b/java/src/NeuralNet.java @@ -0,0 +1,20 @@ +package cz.vsb.mro0010.neuralnetworks; + +import java.util.ArrayList; + +public abstract class NeuralNet { + + protected ArrayList<Interconnections> interconnectionsLayers; + + public NeuralNet(ArrayList<Interconnections> interconnectionsLayers) { + this.interconnectionsLayers = interconnectionsLayers; + } + + public NeuralNet() { + this(new ArrayList<Interconnections>()); + } + + public abstract String getNeuronType(); + public abstract int learn(String trainingSet); + public abstract void run(String input); +} diff --git a/java/src/Neuron.java b/java/src/Neuron.java new file mode 100644 index 0000000000000000000000000000000000000000..69260b322b7d949ad01866f9b0b087175a4dcb04 --- /dev/null +++ b/java/src/Neuron.java @@ -0,0 +1,59 @@ +package cz.vsb.mro0010.neuralnetworks; + +public abstract class Neuron { + + private float potential; // inner potential + private float state; // excitation state + private float threshold; // threshold of excitation + + + public Neuron() { + this(0, 0, 0); + } + + public Neuron(float potential, float state, float threshold) { + this.setPotential(potential); + this.setState(state); + this.setThreshold(threshold); + } + + public void initialize() { + this.setPotential(0); + this.setState(0); + } + + public float getThreshold() { + return threshold; + } + + public void setThreshold(float threshold) { + this.threshold = threshold; + } + + public float getState() { + return state; + } + + protected void setState(float state) { + this.state = state; + } + + protected float getPotential() { + return this.potential; + } + + private void setPotential(float potential) { + this.potential = potential; + } + + public void adjustPotential(float value) { + this.potential += value; + } + + @Override + public String toString() { + return "Pot.: " + this.potential + ", State: " + this.state + ", Thr.: " + this.threshold; + } + + public abstract void transfer(); +} diff --git a/java/src/Projekt1GUI.java b/java/src/Projekt1GUI.java new file mode 100644 index 0000000000000000000000000000000000000000..e5c5945300ca9a713d299a4d9c7453747f1aee9d --- /dev/null +++ b/java/src/Projekt1GUI.java @@ -0,0 +1,581 @@ +package cz.vsb.mro0010.neuralnetworks; + +import java.awt.Color; +import java.awt.EventQueue; +import java.awt.Rectangle; +import javax.swing.JFileChooser; +import javax.swing.JFrame; +import javax.swing.JMenuBar; +import javax.swing.JMenu; +import javax.swing.JMenuItem; +import javax.swing.JOptionPane; +import javax.swing.JTable; + +import java.awt.event.ActionListener; +import java.awt.event.ActionEvent; +import java.awt.event.WindowEvent; +import java.io.File; +import java.io.FileNotFoundException; +import java.io.FileReader; +import java.io.IOException; +import java.io.StreamTokenizer; +import java.util.ArrayList; + +import javax.swing.JButton; +import javax.swing.JScrollPane; +import javax.swing.JLabel; +import javax.swing.event.ListSelectionEvent; +import javax.swing.event.ListSelectionListener; +import javax.swing.filechooser.FileFilter; + +import org.jfree.chart.ChartFactory; +import org.jfree.chart.ChartPanel; +import org.jfree.chart.JFreeChart; +import org.jfree.chart.axis.NumberAxis; +import org.jfree.chart.plot.XYPlot; +import org.jfree.chart.renderer.xy.XYLineAndShapeRenderer; +import org.jfree.chart.renderer.xy.XYSplineRenderer; +import org.jfree.data.xy.XYSeries; +import org.jfree.data.xy.XYSeriesCollection; +import org.jfree.ui.RectangleInsets; +import org.jfree.util.ShapeUtilities; + +public class Projekt1GUI { + + private JFrame frmPerceptron; + private SinglePerceptronNeuralNet neuralNet; + private File dataFile; + private String trainingData; + private String testData; + private int nrOfInputs; + private ArrayList<float[]> inputRanges; + private float learnCoeff; + private int nrOfTrainingElements; + private int nrOfTestElements; + private String trainingOutput; + private int nrOfTrainingIterations; + + //Swing components + private JButton btnLearn; + private JTable tableLearn; + private JTable tableTest; + private JTable tableTrainingProcess; + private JScrollPane scrollPaneLearn; + private JScrollPane scrollPaneTest; + private JScrollPane scrollPaneTrainingProcess; + private JButton buttonBackward; + private JButton buttonForward; + private JButton btnTestData; + + //Chart components + private XYSeriesCollection dataset; + private ChartPanel pnlChart; + private XYLineAndShapeRenderer renderer; + + /** + * Launch the application. + */ + public static void main(String[] args) { + EventQueue.invokeLater(new Runnable() { + public void run() { + try { + Projekt1GUI window = new Projekt1GUI(); + window.frmPerceptron.setVisible(true); + } catch (Exception e) { + e.printStackTrace(); + } + } + }); + } + + /** + * Create the application. + */ + public Projekt1GUI() { + initialize(); + + } + + /** + * Initialize the contents of the frame. + */ + private void initialize() { + frmPerceptron = new JFrame(); + frmPerceptron.setTitle("Perceptron"); + frmPerceptron.setBounds(100, 100, 652, 498); + frmPerceptron.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE); + frmPerceptron.getContentPane().setLayout(null); + + + + btnLearn = new JButton("Learn"); + btnLearn.addActionListener(new ActionListener() { + + + public void actionPerformed(ActionEvent e) { + btnLearn.setEnabled(false); + btnTestData.setEnabled(true); + neuralNet.learn(trainingData); + trainingOutput = neuralNet.getTrainingOutput(); + + //Show training process table + String[] columnNames = new String[nrOfInputs + 1]; + for (int i = 0; i < nrOfInputs; i++) { + columnNames[i] = "w" + String.valueOf(i+1); + } + columnNames[nrOfInputs] = "Threshold"; + String[] rows = trainingOutput.split("\n"); + nrOfTrainingIterations = rows.length; + Float[][] fDataTable = new Float[nrOfTrainingIterations][nrOfInputs + 1]; + for (int i = 0; i < nrOfTrainingIterations; i++) { + String[] cells = rows[i].split(" "); + for (int j = 0; j < nrOfInputs + 1; j++) { + fDataTable[i][j] = Float.valueOf(cells[j]); + } + } + tableTrainingProcess = new JTable( fDataTable, columnNames); + tableTrainingProcess.getSelectionModel().addListSelectionListener(new ListSelectionListener(){ + public void valueChanged(ListSelectionEvent event) { + + if (tableTrainingProcess.getSelectedRow() == 0) { + buttonForward.setEnabled(true); + buttonBackward.setEnabled(false); + } + else if (tableTrainingProcess.getSelectedRow() == tableTrainingProcess.getRowCount()-1) { + buttonBackward.setEnabled(true); + buttonForward.setEnabled(false); + } + else { + buttonBackward.setEnabled(true); + buttonForward.setEnabled(true); + } + + //redraw chart in 2D + if ((nrOfInputs == 2) && (dataset != null)) { + float xMin = inputRanges.get(0)[0]; + float xMax = inputRanges.get(0)[1]; + float yMin = inputRanges.get(1)[0]; + float yMax = inputRanges.get(1)[1]; + + int selectedRow = tableTrainingProcess.getSelectedRow(); + float w0 = -(float)tableTrainingProcess.getModel().getValueAt(selectedRow, 2); + float w1 = (float)tableTrainingProcess.getModel().getValueAt(selectedRow, 0); + float w2 = (float)tableTrainingProcess.getModel().getValueAt(selectedRow, 1); + float step = (float)0.01; + + boolean containSeries = false; + String key = "Line"; + for (Object obj : dataset.getSeries()) { + if (obj instanceof XYSeries) { + XYSeries xys = (XYSeries)obj; + if (xys.getKey().equals(key)) { + containSeries = true; + } + } + } + if (!containSeries) { + XYSeries series = new XYSeries(key); + dataset.addSeries(series); + } + for (Object obj : dataset.getSeries()) { + if (obj instanceof XYSeries) { + XYSeries xys = (XYSeries)obj; + if (xys.getKey().equals(key)) { + int index = dataset.getSeries().indexOf(xys); + xys.clear(); + for (float x = xMin; x < xMax; x += step) { + float y = -w1/w2 * x - w0/w2; + if ( (yMin <= y) && (y <= yMax)) { + xys.add(x, y); + } + } + renderer.setSeriesPaint(index, Color.RED); + } + } + } + } + } + }); + scrollPaneTrainingProcess.setViewportView(tableTrainingProcess); + tableTrainingProcess.setRowSelectionInterval(0, 0); + if (nrOfTrainingIterations > 1) + buttonForward.setEnabled(true); + + // in 2D case draw graph + if (nrOfInputs == 2) { + //Create a chart + XYSeries series = new XYSeries("Line"); + float xMin = 0;//inputRanges.get(0)[0]; + float xMax = 1;//inputRanges.get(0)[1]; + float yMin = 0;//inputRanges.get(1)[0]; + float yMax = 1;//inputRanges.get(1)[1]; + + int selectedRow = tableTrainingProcess.getSelectedRow(); + float w0 = -(float)tableTrainingProcess.getModel().getValueAt(selectedRow, 2); + float w1 = (float)tableTrainingProcess.getModel().getValueAt(selectedRow, 0); + float w2 = (float)tableTrainingProcess.getModel().getValueAt(selectedRow, 1); + float step = (float)0.01; + for (float x = xMin; x < xMax; x += step) { + float y = -w1/w2 * x - w0/w2; + if ( (yMin <= y) && (y <= yMax)) { + series.add(x, y); + } + } + + XYSeries seriesLearnNeg = new XYSeries("LN"); + XYSeries seriesLearnPoz = new XYSeries("LP"); + String[] trainingRows = trainingData.split("\n"); + for (int i = 0; i < nrOfTrainingElements; i++) { + String[] trainingElement = trainingRows[i].split(" "); + if (Float.valueOf(trainingElement[2]) == 1) { + seriesLearnPoz.add(Float.valueOf(trainingElement[0]), Float.valueOf(trainingElement[1])); + } else { + seriesLearnNeg.add(Float.valueOf(trainingElement[0]), Float.valueOf(trainingElement[1])); + } + } + + + dataset = new XYSeriesCollection(); + dataset.addSeries(series); + dataset.addSeries(seriesLearnPoz); + dataset.addSeries(seriesLearnNeg); + + //Create chart with name , axis names and dataset + JFreeChart chart = ChartFactory.createXYLineChart("", "x1", "x2", dataset); + if ((pnlChart != null) && (pnlChart.getParent() == frmPerceptron.getContentPane())) + frmPerceptron.getContentPane().remove(pnlChart); + + //Change plot properties + + XYPlot plot = (XYPlot) chart.getPlot(); + plot.setBackgroundPaint(Color.white); + plot.setAxisOffset(new RectangleInsets(0, 0, 0, 0)); + plot.setDomainGridlinesVisible(false); + plot.setDomainGridlinePaint(Color.lightGray); + plot.setRangeGridlinePaint(Color.white); + //Set axes range + //x + NumberAxis domain = (NumberAxis) plot.getDomainAxis(); + domain.setRange(xMin, xMax); + //y + NumberAxis yRange = (NumberAxis) plot.getRangeAxis(); + yRange.setRange(yMin, yMax); + + //Set renderer + + renderer = new XYSplineRenderer(); + renderer.setSeriesShapesVisible(0, false); + renderer.setSeriesShapesVisible(1, true); + renderer.setSeriesShape(1, ShapeUtilities.createUpTriangle(4)); + renderer.setSeriesShapesVisible(2, true); + renderer.setSeriesShape(2, ShapeUtilities.createDownTriangle(4)); + renderer.setSeriesPaint(0, Color.RED); + renderer.setSeriesPaint(1, Color.BLUE); + renderer.setSeriesPaint(2, Color.BLUE); + renderer.setSeriesLinesVisible(0, true); + renderer.setSeriesLinesVisible(1, false); + renderer.setSeriesLinesVisible(2, false); + plot.setRenderer(renderer); + pnlChart = new ChartPanel(chart); + pnlChart.setBounds(309, 267, 273, 150); + pnlChart.setDomainZoomable(false); + pnlChart.setRangeZoomable(false); + pnlChart.getChart().removeLegend(); + frmPerceptron.getContentPane().add(pnlChart); + frmPerceptron.repaint(); + } else { + if (pnlChart != null) { + frmPerceptron.getContentPane().remove(pnlChart); + frmPerceptron.repaint(); + } + } + } + }); + btnLearn.setEnabled(false); + btnLearn.setBounds(10, 188, 89, 23); + frmPerceptron.getContentPane().add(btnLearn); + + btnTestData = new JButton("Test data"); + btnTestData.addActionListener(new ActionListener() { + public void actionPerformed(ActionEvent e) { + btnTestData.setEnabled(false); + String[] columnNames = new String[nrOfInputs + 1]; + for (int i = 0; i < nrOfInputs; i++) { + columnNames[i] = "x" + String.valueOf(i+1); + } + columnNames[nrOfInputs] = "y"; + Float[][] fDataTable = new Float[nrOfTestElements][nrOfInputs + 1]; + String[] rows = testData.split("\n"); + for (int i = 0; i < nrOfTestElements; i++) { + String[] cells = rows[i].split(" "); + for (int j = 0; j < nrOfInputs; j++) { + fDataTable[i][j] = Float.valueOf(cells[j]); + } + neuralNet.run(rows[i]); + String y = neuralNet.getOutput(); + fDataTable[i][nrOfInputs] = Float.valueOf(y); + } + tableTest = new JTable( fDataTable, columnNames); + scrollPaneTest.setViewportView(tableTest); + // in 2D case redraw graph + if (nrOfInputs == 2) { + XYSeries seriesTestNeg = new XYSeries("TN"); + XYSeries seriesTestPoz = new XYSeries("TP"); + String[] testRows = testData.split("\n"); + for (int i = 0; i < nrOfTestElements; i++) { + String[] testElement = testRows[i].split(" "); + neuralNet.run(testRows[i]); + String y = neuralNet.getOutput(); + if (Float.valueOf(y) == 1) { + seriesTestPoz.add(Float.valueOf(testElement[0]), Float.valueOf(testElement[1])); + } else { + seriesTestNeg.add(Float.valueOf(testElement[0]), Float.valueOf(testElement[1])); + } + } + dataset.addSeries(seriesTestPoz); + dataset.addSeries(seriesTestNeg); + + renderer.setSeriesShapesVisible(3, true); + renderer.setSeriesShape(3, ShapeUtilities.createUpTriangle(6)); + renderer.setSeriesShapesVisible(4, true); + renderer.setSeriesShape(4, ShapeUtilities.createDownTriangle(6)); + renderer.setSeriesPaint(3, Color.GREEN); + renderer.setSeriesPaint(4, Color.GREEN); + renderer.setSeriesLinesVisible(3, false); + renderer.setSeriesLinesVisible(4, false); + + } + } + }); + btnTestData.setEnabled(false); + btnTestData.setBounds(10, 222, 89, 23); + frmPerceptron.getContentPane().add(btnTestData); + + scrollPaneLearn = new JScrollPane(); + scrollPaneLearn.setBounds(10, 25, 283, 156); + frmPerceptron.getContentPane().add(scrollPaneLearn); + + scrollPaneTest = new JScrollPane(); + scrollPaneTest.setBounds(10, 267, 283, 160); + frmPerceptron.getContentPane().add(scrollPaneTest); + + JLabel lblNewLabel = new JLabel("Training data"); + lblNewLabel.setBounds(10, 11, 116, 14); + frmPerceptron.getContentPane().add(lblNewLabel); + + JLabel lblTestData = new JLabel("Test data"); + lblTestData.setBounds(10, 252, 103, 14); + frmPerceptron.getContentPane().add(lblTestData); + + scrollPaneTrainingProcess = new JScrollPane(); + scrollPaneTrainingProcess.setBounds(303, 25, 283, 156); + frmPerceptron.getContentPane().add(scrollPaneTrainingProcess); + + JLabel lblTrainingProcess = new JLabel("Training process"); + lblTrainingProcess.setBounds(303, 11, 97, 14); + frmPerceptron.getContentPane().add(lblTrainingProcess); + + buttonBackward = new JButton("<<"); + buttonBackward.setEnabled(false); + buttonBackward.addActionListener(new ActionListener() { + public void actionPerformed(ActionEvent arg0) { + int row = tableTrainingProcess.getSelectedRow(); + int tableRows = tableTrainingProcess.getRowCount(); + if (row == tableRows - 1) { + buttonForward.setEnabled(true); + } + if (row == 1) { + buttonBackward.setEnabled(false); + } + tableTrainingProcess.setRowSelectionInterval(row-1, row-1); + Rectangle rect = tableTrainingProcess.getCellRect(row-1, 0, true); + tableTrainingProcess.scrollRectToVisible(rect); + } + }); + buttonBackward.setBounds(348, 188, 89, 23); + frmPerceptron.getContentPane().add(buttonBackward); + + buttonForward = new JButton(">>"); + buttonForward.addActionListener(new ActionListener() { + public void actionPerformed(ActionEvent e) { + int row = tableTrainingProcess.getSelectedRow(); + int tableRows = tableTrainingProcess.getRowCount(); + if (row == 0) { + buttonBackward.setEnabled(true); + } + if (row == tableRows - 2) { + buttonForward.setEnabled(false); + } + tableTrainingProcess.setRowSelectionInterval(row+1, row+1); + Rectangle rect = tableTrainingProcess.getCellRect(row+1, 0, true); + tableTrainingProcess.scrollRectToVisible(rect); + + + } + }); + buttonForward.setEnabled(false); + buttonForward.setBounds(447, 188, 89, 23); + frmPerceptron.getContentPane().add(buttonForward); + + JLabel lbldView = new JLabel("2D View"); + lbldView.setBounds(312, 226, 46, 14); + frmPerceptron.getContentPane().add(lbldView); + + JMenuBar menuBar = new JMenuBar(); + frmPerceptron.setJMenuBar(menuBar); + + JMenu mnFile = new JMenu("File"); + menuBar.add(mnFile); + + JMenuItem mntmLoadData = new JMenuItem("Load data"); + mntmLoadData.addActionListener(new ActionListener() { + + + public void actionPerformed(ActionEvent e) { + JFileChooser fc = new JFileChooser(); + fc.setDialogType(JFileChooser.OPEN_DIALOG); + FileFilter filter = new FileFilter() { + + @Override + public String getDescription() { + // TODO Auto-generated method stub + return "Txt files"; + } + + @Override + public boolean accept(File f) { + // TODO Auto-generated method stub + return (f.getName().endsWith(".txt") || f.isDirectory()); + } + }; + fc.setFileFilter(filter); + + + + if (fc.showOpenDialog(frmPerceptron) == JFileChooser.APPROVE_OPTION) { + dataFile = fc.getSelectedFile(); + FileReader fr; + try { + //Parse data file + fr = new FileReader(dataFile); + StreamTokenizer tokenizer = new StreamTokenizer(fr); + /*for (int i = 0; i < 6; i++ ) + tokenizer.nextToken(); + */ + while(tokenizer.nextToken() != StreamTokenizer.TT_NUMBER) {} + nrOfInputs = (int)tokenizer.nval; + /*tokenizer.nextToken(); + tokenizer.nextToken(); + tokenizer.nextToken(); + tokenizer.nextToken();*/ + while(tokenizer.nextToken() != StreamTokenizer.TT_NUMBER) {} + inputRanges = new ArrayList<float[]>(); + for (int i = 0; i < nrOfInputs; i++) { + float[] dims = new float[2]; + dims[0] = (float)tokenizer.nval; + tokenizer.nextToken(); + dims[1] = (float)tokenizer.nval; + inputRanges.add(dims); + tokenizer.nextToken(); + tokenizer.nextToken(); + } + /*for (int i = 0; i < 3; i++ ) + tokenizer.nextToken();*/ + while(tokenizer.nextToken() != StreamTokenizer.TT_NUMBER) {} + learnCoeff = (float)tokenizer.nval; + /*for (int i = 0; i < 7; i++ ) + tokenizer.nextToken();*/ + while(tokenizer.nextToken() != StreamTokenizer.TT_NUMBER) {} + nrOfTrainingElements = (int)tokenizer.nval; + /*for (int i = 0; i < 4; i++ ) + tokenizer.nextToken();*/ + while(tokenizer.nextToken() != StreamTokenizer.TT_NUMBER) {} + StringBuffer sb = new StringBuffer(); + for (int i = 0; i < nrOfTrainingElements; i++) { + for (int j = 0; j < nrOfInputs; j++) { + sb.append(String.valueOf(tokenizer.nval/(inputRanges.get(j)[1]-inputRanges.get(j)[0]) - inputRanges.get(j)[0]/(inputRanges.get(j)[1]-inputRanges.get(j)[0]))); + sb.append(" "); + tokenizer.nextToken(); + } + sb.append(String.valueOf(tokenizer.nval)); + sb.append("\n"); + tokenizer.nextToken(); + } + trainingData = sb.toString(); + sb = new StringBuffer(); + /*for (int i = 0; i < 5; i++ ) + tokenizer.nextToken();*/ + while(tokenizer.nextToken() != StreamTokenizer.TT_NUMBER) {} + nrOfTestElements = (int)tokenizer.nval; + /*tokenizer.nextToken();*/ + while(tokenizer.nextToken() != StreamTokenizer.TT_NUMBER) {} + for (int i = 0; i < nrOfTestElements; i++) { + for (int j = 0; j < nrOfInputs; j++) { + sb.append(String.valueOf(String.valueOf(tokenizer.nval/(inputRanges.get(j)[1]-inputRanges.get(j)[0]) - inputRanges.get(j)[0]/(inputRanges.get(j)[1]-inputRanges.get(j)[0])))); + sb.append(" "); + tokenizer.nextToken(); + } + sb.deleteCharAt(sb.lastIndexOf(" ")); + sb.append("\n"); + } + + testData = sb.substring(0,sb.lastIndexOf("\n")); + fr.close(); + neuralNet = new SinglePerceptronNeuralNet(new BinaryNeuron(), nrOfInputs, learnCoeff); + btnLearn.setEnabled(true); + //Show learn table + String[] columnNames = new String[nrOfInputs + 1]; + for (int i = 0; i < nrOfInputs; i++) { + columnNames[i] = "x" + String.valueOf(i+1); + } + columnNames[nrOfInputs] = "y"; + Float[][] fDataTable = new Float[nrOfTrainingElements][nrOfInputs + 1]; + String[] rows = trainingData.split("\n"); + for (int i = 0; i < nrOfTrainingElements; i++) { + String[] cells = rows[i].split(" "); + for (int j = 0; j < nrOfInputs + 1; j++) { + fDataTable[i][j] = Float.valueOf(cells[j]); + } + } + tableLearn = new JTable( fDataTable, columnNames); + scrollPaneLearn.setViewportView(tableLearn); + //Show test table + columnNames = new String[nrOfInputs]; + for (int i = 0; i < nrOfInputs; i++) { + columnNames[i] = "x" + String.valueOf(i+1); + } + fDataTable = new Float[nrOfTestElements][nrOfInputs]; + rows = testData.split("\n"); + for (int i = 0; i < nrOfTestElements; i++) { + String[] cells = rows[i].split(" "); + for (int j = 0; j < nrOfInputs; j++) { + fDataTable[i][j] = Float.valueOf(cells[j]); + } + } + tableTest = new JTable( fDataTable, columnNames); + scrollPaneTest.setViewportView(tableTest); + + } catch (FileNotFoundException e1) { + e1.printStackTrace(); + JOptionPane.showMessageDialog(null, "Error: File not found"); + } catch (IOException e1) { + // TODO Auto-generated catch block + e1.printStackTrace(); + } + + + } + + } + }); + mnFile.add(mntmLoadData); + + JMenuItem mntmExit = new JMenuItem("Exit"); + mntmExit.addActionListener(new ActionListener() { + public void actionPerformed(ActionEvent arg0) { + frmPerceptron.dispatchEvent(new WindowEvent(frmPerceptron, WindowEvent.WINDOW_CLOSING)); + } + }); + mnFile.add(mntmExit); + } +} diff --git a/java/src/Projekt2GUI.java b/java/src/Projekt2GUI.java new file mode 100644 index 0000000000000000000000000000000000000000..1b4406f177ec50e2148f271a690aa8c53cb2d5e5 --- /dev/null +++ b/java/src/Projekt2GUI.java @@ -0,0 +1,784 @@ +package cz.vsb.mro0010.neuralnetworks; + +import java.awt.Color; +import java.awt.EventQueue; + +import javax.swing.JFileChooser; +import javax.swing.JFrame; +import javax.swing.JMenuBar; +import javax.swing.JMenu; +import javax.swing.JMenuItem; +import javax.swing.JOptionPane; +import javax.swing.JTable; +import javax.swing.ScrollPaneConstants; + +import java.awt.event.ActionListener; +import java.awt.event.ActionEvent; +import java.awt.event.WindowEvent; +import java.io.BufferedWriter; +import java.io.File; +import java.io.FileNotFoundException; +import java.io.FileReader; +import java.io.FileWriter; +import java.io.IOException; +import java.io.StreamTokenizer; +import java.text.DecimalFormat; +import java.util.ArrayList; +import java.util.Arrays; + +import javax.swing.JButton; +import javax.swing.JScrollPane; +import javax.swing.JLabel; +import javax.swing.filechooser.FileFilter; +import javax.swing.JSpinner; +import javax.swing.JTextField; +import javax.swing.SpinnerNumberModel; +import javax.swing.event.ChangeListener; +import javax.swing.event.ChangeEvent; +import javax.swing.JPanel; + +import com.thoughtworks.xstream.XStream; + +import java.awt.GridLayout; + +public class Projekt2GUI { + + private JFrame frmBPnet; + private BPNet neuralNet; + private File dataFile; + private String trainingData; + private String testData; + + private int nrOfInputs; + private int nrOfOutputs; + private int nrOfLayers; + private float maxError; + private float slope; + private float inertiaCoeff; + private ArrayList<Integer> nrOfNeuronsPerLayer; + private ArrayList<String> inputNames; + private ArrayList<String> outputNames; + + + private ArrayList<float[]> inputRanges; + private float learnCoeff; + + + private int nrOfTrainingElements; + private int nrOfTestElements; + + //Swing components + private JButton btnLearn; + private JTable tableLearn; + private JTable tableTest; + private JScrollPane scrollPaneLearn; + private JScrollPane scrollPaneTest; + private JButton btnTestData; + private JButton btnDoSpecifiedLearn; + private JSpinner spinnerLearnSteps; + private JTextField textFieldIterations; + private JLabel lblLearned; + private JLabel lblLearnCoeff; + private JLabel lblSlopeLambda; + private JSpinner spinnerLearnCoeff; + private JSpinner spinnerSlope; + private JLabel lblMaxError; + private JSpinner spinnerError; + private JLabel lblCurentError; + private JTextField textFieldCurrentError; + private JTextField textFieldTestElement; + private JTextField textFieldTestOutput; + private JButton btnTestElement; + private JButton btnResetWeights; + private JPanel panelTopology; + private JButton btnAddLayer; + private JSpinner spinnerLayer; + private JSpinner spinnerLayerNeurons; + private JMenuItem mntmSaveNeuralNet; + + + /** + * Launch the application. + */ + public static void main(String[] args) { + EventQueue.invokeLater(new Runnable() { + public void run() { + try { + Projekt2GUI window = new Projekt2GUI(); + window.frmBPnet.setVisible(true); + } catch (Exception e) { + e.printStackTrace(); + } + } + }); + } + + /** + * Create the application. + */ + public Projekt2GUI() { + initialize(); + } + + + private void changeAfterLearn() { + btnLearn.setEnabled(false); + btnTestData.setEnabled(true); + btnDoSpecifiedLearn.setEnabled(false); + lblLearned.setText("Learned"); + lblLearned.setForeground(Color.GREEN); + spinnerLearnSteps.setEnabled(false); + spinnerError.setEnabled(false); + spinnerLearnCoeff.setEnabled(false); + spinnerSlope.setEnabled(false); + textFieldCurrentError.setText(String.valueOf(neuralNet.getError())); + btnTestElement.setEnabled(true); + textFieldTestElement.setEnabled(true); + textFieldTestOutput.setEnabled(true); + btnResetWeights.setEnabled(false); + btnAddLayer.setEnabled(false); + spinnerLayer.setEnabled(false); + spinnerLayerNeurons.setEnabled(false); + frmBPnet.getContentPane().remove(panelTopology); + frmBPnet.revalidate(); + frmBPnet.repaint(); + mntmSaveNeuralNet.setEnabled(true); + } + + /** + * Initialize the contents of the frame. + */ + private void initialize() { + //Default values + slope = (float)1.1; + maxError = (float)0.1; + + + frmBPnet = new JFrame(); + frmBPnet.setTitle("Backpropagation network"); + frmBPnet.setBounds(100, 100, 778, 562); + frmBPnet.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE); + frmBPnet.getContentPane().setLayout(null); + + + + btnLearn = new JButton("Quick Learn"); + btnLearn.addActionListener(new ActionListener() { + + + public void actionPerformed(ActionEvent e) { + + int iterations = neuralNet.learn(trainingData); + changeAfterLearn(); + textFieldIterations.setText(String.valueOf(iterations)); + //JOptionPane.showMessageDialog(null, "Neural Net learned in " + iterations + " iterations."); + + } + }); + btnLearn.setEnabled(false); + btnLearn.setBounds(10, 188, 174, 23); + frmBPnet.getContentPane().add(btnLearn); + + btnTestData = new JButton("Test data"); + btnTestData.addActionListener(new ActionListener() { + public void actionPerformed(ActionEvent e) { + btnTestData.setEnabled(false); + + + String[] columnNames = new String[nrOfInputs + nrOfOutputs]; + for (int i = 0; i < nrOfInputs; i++) { + columnNames[i] = inputNames.get(i); + + } + for (int i = 0; i < nrOfOutputs; i++) { + columnNames[nrOfInputs + i] = outputNames.get(i); + } + Float[][] fDataTable = new Float[nrOfTestElements][nrOfInputs + nrOfOutputs]; + String[] rows = testData.split("\n"); + for (int i = 0; i < nrOfTestElements; i++) { + neuralNet.run(rows[i]); + String output = neuralNet.getOutput(); + String[] cells = (rows[i] + " " + output).split(" "); + for (int j = 0; j < nrOfInputs + nrOfOutputs; j++) { + fDataTable[i][j] = Float.valueOf(cells[j]); + } + } + tableTest = new JTable( fDataTable, columnNames); + tableTest.setAutoResizeMode(JTable.AUTO_RESIZE_OFF); + scrollPaneTest.setHorizontalScrollBarPolicy(ScrollPaneConstants.HORIZONTAL_SCROLLBAR_ALWAYS); + scrollPaneTest.setViewportView(tableTest); + + + } + }); + btnTestData.setEnabled(false); + btnTestData.setBounds(10, 222, 174, 23); + frmBPnet.getContentPane().add(btnTestData); + + scrollPaneLearn = new JScrollPane(); + scrollPaneLearn.setBounds(10, 25, 368, 156); + frmBPnet.getContentPane().add(scrollPaneLearn); + + scrollPaneTest = new JScrollPane(); + scrollPaneTest.setBounds(10, 267, 368, 160); + frmBPnet.getContentPane().add(scrollPaneTest); + + JLabel lblNewLabel = new JLabel("Training data"); + lblNewLabel.setBounds(10, 11, 116, 14); + frmBPnet.getContentPane().add(lblNewLabel); + + JLabel lblTestData = new JLabel("Test data"); + lblTestData.setBounds(10, 252, 103, 14); + frmBPnet.getContentPane().add(lblTestData); + + JLabel lblTrainingProcess = new JLabel("Training modification"); + lblTrainingProcess.setBounds(386, 11, 132, 14); + frmBPnet.getContentPane().add(lblTrainingProcess); + + btnDoSpecifiedLearn = new JButton("Do specified learn steps"); + btnDoSpecifiedLearn.addActionListener(new ActionListener() { + public void actionPerformed(ActionEvent e) { + boolean learned = false; + int iter = 0; + int maxIterations = (int)spinnerLearnSteps.getModel().getValue(); + ArrayList<String> trainingElements = new ArrayList<String>(Arrays.asList(trainingData.split("\n"))); +// float maxError = 0; + while(!learned && (iter < maxIterations)) { + neuralNet.setError(0); + learned = true; + for (int i = 0; i < trainingElements.size(); i++) { + learned &= neuralNet.learnStep(trainingElements.get(i)); +// if (neuralNet.getError() > maxError) { +// maxError = neuralNet.getError(); +// } + } + iter++; + textFieldCurrentError.setText(String.valueOf(neuralNet.getError())); +// System.out.println(iter); + } + if (learned) { + changeAfterLearn(); + } + int currentIter = Integer.parseInt(textFieldIterations.getText()); + textFieldIterations.setText(String.valueOf(currentIter + iter)); + + + } + }); + btnDoSpecifiedLearn.setEnabled(false); + btnDoSpecifiedLearn.setBounds(194, 188, 184, 23); + frmBPnet.getContentPane().add(btnDoSpecifiedLearn); + + spinnerLearnSteps = new JSpinner(); + spinnerLearnSteps.setModel(new SpinnerNumberModel(1, 1, 100000, 1)); + spinnerLearnSteps.setEnabled(false); + spinnerLearnSteps.setBounds(194, 223, 88, 20); + frmBPnet.getContentPane().add(spinnerLearnSteps); + + textFieldIterations = new JTextField(); + textFieldIterations.setEnabled(false); + textFieldIterations.setText("0"); + textFieldIterations.setBounds(292, 223, 86, 20); + frmBPnet.getContentPane().add(textFieldIterations); + textFieldIterations.setColumns(10); + + lblLearned = new JLabel("Not Learned"); + lblLearned.setForeground(Color.RED); + lblLearned.setBackground(Color.LIGHT_GRAY); + lblLearned.setBounds(514, 143, 74, 14); + frmBPnet.getContentPane().add(lblLearned); + + lblLearnCoeff = new JLabel("Learn coeff"); + lblLearnCoeff.setBounds(388, 39, 79, 14); + frmBPnet.getContentPane().add(lblLearnCoeff); + + lblSlopeLambda = new JLabel("Slope - lambda"); + lblSlopeLambda.setBounds(388, 64, 89, 14); + frmBPnet.getContentPane().add(lblSlopeLambda); + + + + spinnerLearnCoeff = new JSpinner(); + spinnerLearnCoeff.addChangeListener(new ChangeListener() { + public void stateChanged(ChangeEvent e) { + learnCoeff = (float)spinnerLearnCoeff.getModel().getValue(); + if (neuralNet != null) + neuralNet.changeLearnCoeffTo(learnCoeff); + } + }); + spinnerLearnCoeff.setEnabled(false); + spinnerLearnCoeff.setModel(new SpinnerNumberModel(new Float(0.5), new Float(0.05), new Float(1), new Float(0.05))); + JSpinner.NumberEditor editor = (JSpinner.NumberEditor)spinnerLearnCoeff.getEditor(); + DecimalFormat format = editor.getFormat(); + format.setMinimumFractionDigits(5); + spinnerLearnCoeff.setBounds(499, 36, 74, 20); + frmBPnet.getContentPane().add(spinnerLearnCoeff); + + slope = (float)1.1; + spinnerSlope = new JSpinner(); + spinnerSlope.addChangeListener(new ChangeListener() { + public void stateChanged(ChangeEvent e) { + slope = (float)spinnerSlope.getModel().getValue(); + if (neuralNet != null) + neuralNet.changeSlopeTo(slope); + } + }); + spinnerSlope.setEnabled(false); + spinnerSlope.setBounds(499, 61, 74, 20); + spinnerSlope.setModel(new SpinnerNumberModel(new Float(slope), new Float(0.05), null, new Float(0.05))); + editor = (JSpinner.NumberEditor)spinnerSlope.getEditor(); + format = editor.getFormat(); + format.setMinimumFractionDigits(5); + frmBPnet.getContentPane().add(spinnerSlope); + + lblMaxError = new JLabel("Max error"); + lblMaxError.setBounds(388, 89, 67, 14); + frmBPnet.getContentPane().add(lblMaxError); + + + maxError = (float)0.1; + spinnerError = new JSpinner(); + spinnerError.addChangeListener(new ChangeListener() { + public void stateChanged(ChangeEvent e) { + maxError = (float)spinnerError.getModel().getValue(); + if (neuralNet != null) + neuralNet.setTolerance(maxError); + } + }); + spinnerError.setEnabled(false); + spinnerError.setBounds(499, 86, 74, 20); + spinnerError.setModel(new SpinnerNumberModel(new Float(maxError), new Float(0.00001), new Float(100), new Float(0.00001))); + editor = (JSpinner.NumberEditor)spinnerError.getEditor(); + format = editor.getFormat(); + format.setMinimumFractionDigits(5); + frmBPnet.getContentPane().add(spinnerError); + + lblCurentError = new JLabel("Current Error"); + lblCurentError.setBounds(388, 115, 79, 14); + frmBPnet.getContentPane().add(lblCurentError); + + textFieldCurrentError = new JTextField(); + textFieldCurrentError.setEnabled(false); + textFieldCurrentError.setBounds(487, 112, 86, 20); + frmBPnet.getContentPane().add(textFieldCurrentError); + textFieldCurrentError.setColumns(10); + + JLabel lblChangeNetworkTopology = new JLabel("Change network topology"); + lblChangeNetworkTopology.setBounds(389, 168, 184, 14); + frmBPnet.getContentPane().add(lblChangeNetworkTopology); + + textFieldTestElement = new JTextField(); + textFieldTestElement.setEnabled(false); + textFieldTestElement.setBounds(10, 438, 272, 20); + frmBPnet.getContentPane().add(textFieldTestElement); + textFieldTestElement.setColumns(10); + + btnTestElement = new JButton("Run input"); + btnTestElement.addActionListener(new ActionListener() { + public void actionPerformed(ActionEvent e) { + String input = textFieldTestElement.getText(); + try { + neuralNet.run(input); + } + catch(InvalidInputNumberException exception) { + JOptionPane.showMessageDialog(null, "Invalid Input"); + } + finally { + String output = neuralNet.getOutput(); + textFieldTestOutput.setText(output); + } + } + }); + btnTestElement.setEnabled(false); + btnTestElement.setBounds(289, 437, 89, 23); + frmBPnet.getContentPane().add(btnTestElement); + + textFieldTestOutput = new JTextField(); + textFieldTestOutput.setEnabled(false); + textFieldTestOutput.setBounds(49, 471, 329, 20); + frmBPnet.getContentPane().add(textFieldTestOutput); + textFieldTestOutput.setColumns(10); + + JLabel lblOutput = new JLabel("Output"); + lblOutput.setBounds(10, 474, 46, 14); + frmBPnet.getContentPane().add(lblOutput); + + btnResetWeights = new JButton("Reset weights"); + btnResetWeights.setEnabled(false); + btnResetWeights.addActionListener(new ActionListener() { + public void actionPerformed(ActionEvent e) { + neuralNet.resetWeights(); + } + }); + btnResetWeights.setBounds(388, 138, 116, 23); + frmBPnet.getContentPane().add(btnResetWeights); + + panelTopology = new JPanel(); + panelTopology.setBounds(386, 213, 366, 214); + frmBPnet.getContentPane().add(panelTopology); + + btnAddLayer = new JButton("Add layer"); + btnAddLayer.setEnabled(false); + btnAddLayer.setBounds(386, 188, 89, 23); + frmBPnet.getContentPane().add(btnAddLayer); + + spinnerLayer = new JSpinner(); + spinnerLayer.setEnabled(false); + spinnerLayer.setBounds(499, 189, 40, 20); + frmBPnet.getContentPane().add(spinnerLayer); + + spinnerLayerNeurons = new JSpinner(); + spinnerLayerNeurons.setEnabled(false); + spinnerLayerNeurons.setBounds(571, 189, 40, 20); + frmBPnet.getContentPane().add(spinnerLayerNeurons); + + JLabel lblTo = new JLabel("to"); + lblTo.setBounds(483, 192, 46, 14); + frmBPnet.getContentPane().add(lblTo); + + JLabel lblWith = new JLabel("with"); + lblWith.setBounds(542, 192, 46, 14); + frmBPnet.getContentPane().add(lblWith); + + JLabel lblNeurons = new JLabel("neurons"); + lblNeurons.setBounds(621, 192, 74, 14); + frmBPnet.getContentPane().add(lblNeurons); + + + + JMenuBar menuBar = new JMenuBar(); + frmBPnet.setJMenuBar(menuBar); + + + JMenu mnFile = new JMenu("File"); + menuBar.add(mnFile); + + JMenuItem mntmLoadData = new JMenuItem("Load data"); + mntmLoadData.addActionListener(new ActionListener() { + + + public void actionPerformed(ActionEvent e) { + JFileChooser fc = new JFileChooser(); + fc.setDialogType(JFileChooser.OPEN_DIALOG); + FileFilter filter = new FileFilter() { + + @Override + public String getDescription() { + return "Txt files"; + } + + @Override + public boolean accept(File f) { + return (f.getName().endsWith(".txt") || f.isDirectory()); + } + }; + fc.setFileFilter(filter); + + + + if (fc.showOpenDialog(frmBPnet) == JFileChooser.APPROVE_OPTION) { + dataFile = fc.getSelectedFile(); + FileReader fr; + try { + + spinnerLearnSteps.setEnabled(true); + spinnerError.setEnabled(true); + spinnerLearnCoeff.setEnabled(true); + spinnerSlope.setEnabled(true); + spinnerLearnCoeff.setValue(new Float((float)spinnerLearnCoeff.getValue())); + spinnerSlope.setValue(spinnerSlope.getValue()); + spinnerError.setValue(spinnerError.getValue()); + + //Parse data file + + + + fr = new FileReader(dataFile); + StreamTokenizer tokenizer = new StreamTokenizer(fr); + /*for (int i = 0; i < 6; i++ ) + tokenizer.nextToken(); + */ + while(true) { + tokenizer.nextToken(); + if ((tokenizer.nextToken() == StreamTokenizer.TT_WORD) && tokenizer.sval.equals("vrstev")) { + tokenizer.nextToken(); + break; + } + } + nrOfLayers = (int)tokenizer.nval; + while(tokenizer.nextToken() != StreamTokenizer.TT_NUMBER) {} + nrOfInputs = (int)tokenizer.nval; + tokenizer.nextToken(); + tokenizer.nextToken(); + tokenizer.nextToken(); + tokenizer.nextToken(); + inputRanges = new ArrayList<float[]>(); + inputNames = new ArrayList<String>(); + for (int i = 0; i < nrOfInputs; i++) { + String inputName = tokenizer.sval; + inputNames.add(inputName); + tokenizer.nextToken(); + float[] dims = new float[2]; + dims[0] = (float)tokenizer.nval; + tokenizer.nextToken(); + dims[1] = (float)tokenizer.nval; + inputRanges.add(dims); + tokenizer.nextToken(); + } + /*for (int i = 0; i < 3; i++ ) + tokenizer.nextToken();*/ + while(tokenizer.nextToken() != StreamTokenizer.TT_NUMBER) {} + nrOfNeuronsPerLayer = new ArrayList<Integer>(); + for (int i = 0; i < nrOfLayers; i++) { + nrOfNeuronsPerLayer.add((int)tokenizer.nval); + if (i == nrOfLayers - 1) { + nrOfOutputs = (int)tokenizer.nval; + } + tokenizer.nextToken(); + } + for (int i = 0; i < 3; i++ ) + tokenizer.nextToken(); + outputNames = new ArrayList<String>(); + for (int i = 0; i < nrOfOutputs; i++) { + outputNames.add(tokenizer.sval); + tokenizer.nextToken(); + } + while(tokenizer.nextToken() != StreamTokenizer.TT_NUMBER) {} + learnCoeff = (float)tokenizer.nval; + spinnerLearnCoeff.getModel().setValue(new Float(learnCoeff)); + while(tokenizer.nextToken() != StreamTokenizer.TT_NUMBER) {} + inertiaCoeff = (float)tokenizer.nval; + /*for (int i = 0; i < 7; i++ ) + tokenizer.nextToken();*/ + while(tokenizer.nextToken() != StreamTokenizer.TT_NUMBER) {} + nrOfTrainingElements = (int)tokenizer.nval; + /*for (int i = 0; i < 4; i++ ) + tokenizer.nextToken();*/ + while(tokenizer.nextToken() != StreamTokenizer.TT_NUMBER) {} + StringBuffer sb = new StringBuffer(); + for (int i = 0; i < nrOfTrainingElements; i++) { + for (int j = 0; j < nrOfInputs; j++) { + sb.append(String.valueOf(tokenizer.nval/(inputRanges.get(j)[1]-inputRanges.get(j)[0]) - inputRanges.get(j)[0]/(inputRanges.get(j)[1]-inputRanges.get(j)[0]))); + sb.append(" "); + tokenizer.nextToken(); + } + for (int j = 0; j < nrOfOutputs; j++) { + sb.append(String.valueOf(tokenizer.nval)); + sb.append(" "); + tokenizer.nextToken(); + } + sb.deleteCharAt(sb.length() - 1); + sb.append("\n"); + } + trainingData = sb.toString(); + sb = new StringBuffer(); + /*for (int i = 0; i < 5; i++ ) + tokenizer.nextToken();*/ + while(tokenizer.nextToken() != StreamTokenizer.TT_NUMBER) {} + nrOfTestElements = (int)tokenizer.nval; + /*tokenizer.nextToken();*/ + if (nrOfTestElements > 0) { + while(tokenizer.nextToken() != StreamTokenizer.TT_NUMBER) {} + for (int i = 0; i < nrOfTestElements; i++) { + for (int j = 0; j < nrOfInputs; j++) { + sb.append(String.valueOf(String.valueOf(tokenizer.nval/(inputRanges.get(j)[1]-inputRanges.get(j)[0]) - inputRanges.get(j)[0]/(inputRanges.get(j)[1]-inputRanges.get(j)[0])))); + sb.append(" "); + tokenizer.nextToken(); + } + sb.deleteCharAt(sb.lastIndexOf(" ")); + sb.append("\n"); + } + testData = sb.substring(0,sb.lastIndexOf("\n")); + } else { + testData = ""; + } + fr.close(); + + + + neuralNet = new BPNet(maxError, nrOfLayers, nrOfInputs, nrOfNeuronsPerLayer, slope, learnCoeff); + spinnerError.getModel().setValue(maxError); + btnLearn.setEnabled(true); + //Show learn table + String[] columnNames = new String[nrOfInputs + nrOfOutputs]; + for (int i = 0; i < nrOfInputs; i++) { + columnNames[i] = inputNames.get(i); + + } + for (int i = 0; i < nrOfOutputs; i++) { + columnNames[nrOfInputs + i] = outputNames.get(i); + } + Float[][] fDataTable = new Float[nrOfTrainingElements][nrOfInputs + nrOfOutputs]; + String[] rows = trainingData.split("\n"); + for (int i = 0; i < nrOfTrainingElements; i++) { + String[] cells = rows[i].split(" "); + for (int j = 0; j < nrOfInputs + nrOfOutputs; j++) { + fDataTable[i][j] = Float.valueOf(cells[j]); + } + } + tableLearn = new JTable( fDataTable, columnNames); + tableLearn.setAutoResizeMode(JTable.AUTO_RESIZE_OFF); + scrollPaneLearn.setHorizontalScrollBarPolicy(ScrollPaneConstants.HORIZONTAL_SCROLLBAR_ALWAYS); + scrollPaneLearn.setViewportView(tableLearn); + //Show test table + columnNames = new String[nrOfInputs]; + for (int i = 0; i < nrOfInputs; i++) { + columnNames[i] = inputNames.get(i); + } + fDataTable = new Float[nrOfTestElements][nrOfInputs]; + rows = testData.split("\n"); + for (int i = 0; i < nrOfTestElements; i++) { + String[] cells = rows[i].split(" "); + for (int j = 0; j < nrOfInputs; j++) { + fDataTable[i][j] = Float.valueOf(cells[j]); + } + } + tableTest = new JTable( fDataTable, columnNames); + tableTest.setAutoResizeMode(JTable.AUTO_RESIZE_OFF); + scrollPaneTest.setViewportView(tableTest); + btnResetWeights.setEnabled(true); + btnDoSpecifiedLearn.setEnabled(true); + btnAddLayer.setEnabled(true); + spinnerLayer.setEnabled(true); + spinnerLayer.setModel(new SpinnerNumberModel(0, 0, neuralNet.getNrOfLayers(), 1)); + spinnerLayerNeurons.setEnabled(true); + spinnerLayerNeurons.setModel(new SpinnerNumberModel(1, 1, null, 1)); + + btnAddLayer.addActionListener(new ActionListener() { + public void actionPerformed(ActionEvent e) { + neuralNet.addNeuronLayer((Integer)spinnerLayerNeurons.getValue(), (Integer)spinnerLayer.getValue(), slope); + refreshPanelTopology(); + } + }); + refreshPanelTopology(); + + + } catch (FileNotFoundException e1) { + e1.printStackTrace(); + JOptionPane.showMessageDialog(null, "Error: File not found"); + } catch (IOException e1) { + e1.printStackTrace(); + JOptionPane.showMessageDialog(null, "IOException"); + } + + + } + + } + + private void refreshPanelTopology() { + panelTopology.setLayout(new GridLayout(neuralNet.getNrOfLayers() + 1, 4)); + panelTopology.removeAll(); + nrOfLayers = neuralNet.getNrOfLayers(); + String map = neuralNet.getNeuronMap(); + String[] layers = map.split(" "); + spinnerLayer.setModel(new SpinnerNumberModel(0, 0, neuralNet.getNrOfLayers() - 1, 1)); + spinnerLayerNeurons.setModel(new SpinnerNumberModel(1, 1, null, 1)); + for (int i = 0; i < nrOfLayers; i++) { + JLabel label = new JLabel(layers[nrOfLayers - 1 - i]); + panelTopology.add(label); + if (i > 0) { + JButton btn1 = new JButton("Rmv neuron"); + JButton btn2 = new JButton("Rmv layer"); + JButton btn3 = new JButton("Add neuron"); + btn1.setName(String.valueOf(nrOfLayers - 1 - i)); + btn2.setName(String.valueOf(nrOfLayers - 1 - i)); + btn3.setName(String.valueOf(nrOfLayers - 1 - i)); + btn1.addActionListener(new ActionListener() { + + @Override + public void actionPerformed(ActionEvent e) { + String name = ((JButton)e.getSource()).getName(); + neuralNet.removeNeuron(Integer.parseInt(name)); + refreshPanelTopology(); + } + }); + btn2.addActionListener(new ActionListener() { + + @Override + public void actionPerformed(ActionEvent e) { + String name = ((JButton)e.getSource()).getName(); + neuralNet.removeNeuronLayer(Integer.parseInt(name)); + refreshPanelTopology(); + } + }); + btn3.addActionListener(new ActionListener() { + + @Override + public void actionPerformed(ActionEvent e) { + String name = ((JButton)e.getSource()).getName(); + neuralNet.addNeuron(Integer.parseInt(name), slope); + refreshPanelTopology(); + + } + }); + panelTopology.add(btn1); + panelTopology.add(btn2); + panelTopology.add(btn3); + } else { + panelTopology.add(new JLabel(" ")); + panelTopology.add(new JLabel(" ")); + panelTopology.add(new JLabel(" ")); + } + } + panelTopology.add(new JLabel("Inputs")); + panelTopology.add(new JLabel(String.valueOf(neuralNet.getNrOfInputs()))); + panelTopology.add(new JLabel(" ")); + panelTopology.add(new JLabel(" ")); + frmBPnet.revalidate(); + } + }); + mnFile.add(mntmLoadData); + + JMenuItem mntmExit = new JMenuItem("Exit"); + mntmExit.addActionListener(new ActionListener() { + public void actionPerformed(ActionEvent arg0) { + frmBPnet.dispatchEvent(new WindowEvent(frmBPnet, WindowEvent.WINDOW_CLOSING)); + } + }); + + mntmSaveNeuralNet = new JMenuItem("Save Neural Net"); + mntmSaveNeuralNet.addActionListener(new ActionListener() { + public void actionPerformed(ActionEvent e) { + try { + File address = null; + JFileChooser fc = new JFileChooser(); + FileFilter filter = new FileFilter() { + + @Override + public String getDescription() { + return "Xml files"; + } + + @Override + public boolean accept(File f) { + return (f.getName().endsWith(".xml") || f.isDirectory()); + } + }; + fc.setFileFilter(filter); + fc.setCurrentDirectory(new java.io.File(".")); + //fc.setFileSelectionMode(JFileChooser.SAVE_DIALOG); + if (fc.showSaveDialog(frmBPnet) == JFileChooser.APPROVE_OPTION) { + address = fc.getSelectedFile(); + XStream xstream = new XStream(); + String xml = xstream.toXML(neuralNet); + BufferedWriter out = new BufferedWriter(new FileWriter(address)); + out.write(xml); + out.close(); +// BPNet testNet = (BPNet)xstream.fromXML(xml); + + JOptionPane.showMessageDialog(null, "Hotovo"); + } + } + catch (Exception ex) { + ex.printStackTrace(); + JOptionPane.showMessageDialog(null, ex.getMessage()); + } + + } + }); + mntmSaveNeuralNet.setEnabled(false); + mnFile.add(mntmSaveNeuralNet); + mnFile.add(mntmExit); + } +} diff --git a/java/src/SigmoidalNeuron.java b/java/src/SigmoidalNeuron.java new file mode 100644 index 0000000000000000000000000000000000000000..706eea0f33efc56898645794d3acb358c2664ca0 --- /dev/null +++ b/java/src/SigmoidalNeuron.java @@ -0,0 +1,53 @@ +package cz.vsb.mro0010.neuralnetworks; + +public class SigmoidalNeuron extends Neuron { + + private float error; //delta + private float slope; //lambda + + + + + public void setSlope(float slope) { + this.slope = slope; + } + + + public SigmoidalNeuron(float slope) { + this.slope = slope; + this.error = 0; + } + + + @Override + public void transfer() { + float z = this.getPotential(); + float y = (float) (1.0/(1.0 + Math.exp(-slope*z))); + this.setState(y); + } + + + public float getSlope() { + return slope; + } + + public float getError() { + return error; + } + + + public void setError(float error) { + this.error = error; + } + + + public static void main(String args[]) { + SigmoidalNeuron neuron = new SigmoidalNeuron((float)0.5); + for (int i = -10; i <= 10; i++) { + neuron.initialize(); + neuron.adjustPotential(i); + neuron.transfer(); + System.out.println(neuron.getState()); + } + } +} diff --git a/java/src/SinglePerceptronNeuralNet.java b/java/src/SinglePerceptronNeuralNet.java new file mode 100644 index 0000000000000000000000000000000000000000..e741be2594a37de90b5b60c0c6a40bd2c1fbfbca --- /dev/null +++ b/java/src/SinglePerceptronNeuralNet.java @@ -0,0 +1,147 @@ +package cz.vsb.mro0010.neuralnetworks; + +import java.util.ArrayList; +import java.util.Arrays; + +public class SinglePerceptronNeuralNet extends NeuralNet { + + private Neuron neuron; + private int nrOfInputs; + private ArrayList<Connection> connections; + private ArrayList<InputLayerPseudoNeuron> input; + private String trainingOutput; + private float learnCoef; + + public SinglePerceptronNeuralNet(Neuron neuron, int nrOfInputs, float learnCoef) { + super(); + this.neuron = neuron; + this.nrOfInputs = nrOfInputs; + this.input = new ArrayList<InputLayerPseudoNeuron>(); + this.connections = new ArrayList<Connection>(); + for (int i = 0; i < this.nrOfInputs; i++) { + InputLayerPseudoNeuron inputNeuron = new InputLayerPseudoNeuron(); + this.input.add(inputNeuron); + this.connections.add(new Connection(inputNeuron, neuron, (float)Math.random())); + } + this.setTrainingOutput(" "); + this.learnCoef = learnCoef; + } + + @Override + public String getNeuronType() { + return neuron.getClass().getSimpleName(); + } + + @Override + public int learn(String trainingSet) { + ArrayList<String> trainingElements = new ArrayList<String>(Arrays.asList(trainingSet.split("\n"))); + boolean learned = false; + int iterations = 0; + StringBuffer trainingProgress = new StringBuffer(); + for (Connection c : connections) { + trainingProgress.append(String.valueOf(c.getWeight())); + trainingProgress.append(" "); + } + trainingProgress.append(String.valueOf(-neuron.getThreshold())); + trainingProgress.append("\n"); + while (!learned) { + iterations++; + learned = true; + for (String element : trainingElements) { + String[] sa = element.split(" "); + String expectedOutput = sa[sa.length - 1]; + StringBuffer sb = new StringBuffer(); + for (int i = 0; i < sa.length - 1; i++) { + sb.append(sa[i]); + sb.append(" "); + } + this.run(sb.toString()); + + if (Float.parseFloat(expectedOutput) != Float.parseFloat(this.getOutput())) { + learned = false; + float eo = Float.parseFloat(expectedOutput); + float ro = Float.parseFloat(this.getOutput()); + neuron.setThreshold(neuron.getThreshold() + learnCoef*-(eo-ro)*1); // w_0 = -threshold + for (Connection c : connections) { + c.adjustWeight(learnCoef*(eo-ro)*c.getInputNeuron().getState()); + } + for (Connection c : connections) { + trainingProgress.append(String.valueOf(c.getWeight())); + trainingProgress.append(" "); + } + trainingProgress.append(String.valueOf(neuron.getThreshold())); + trainingProgress.append("\n"); + } + } + } + //System.out.println("Learned! in " + (iterations-1) + " iterations"); + this.setTrainingOutput(trainingProgress.toString()); + return iterations; + } + + @Override + public void run(String inputString) { + String[] input = inputString.split(" "); + for (int i = 0; i < input.length; i++) { + InputLayerPseudoNeuron in = this.input.get(i); + in.initialize(); + in.adjustPotential(Float.parseFloat(input[i])); + in.transfer(); + } + neuron.initialize(); + for (Connection c : connections) { + c.passSignal(); + } + neuron.transfer(); + + } + + public String getOutput() { + String output = String.valueOf(neuron.getState()); + return output; + } + + public String getTrainingOutput() { + return trainingOutput; + } + + private void setTrainingOutput(String trainingOutput) { + this.trainingOutput = trainingOutput; + } + + /*public static void main(String[] args) { + SinglePerceptronNeuralNet net = new SinglePerceptronNeuralNet(new BinaryNeuron(), 2, (float)0.7); + net.neuron.setThreshold((float) Math.random()); +// String learnSet = "1 0.5 0\n0.4 0.8 1\n0.1 0.1 0\n0.6 0.9 1\n0.8 0.7 0\n0.4 1.0 1"; +// net.learn(learnSet); +// net.run("0.7 0.9"); +// System.out.println(net.getOutput()); +// net.run("0.9 0.7"); +// System.out.println(net.getOutput()); +// net.run("0.2 0.2"); +// System.out.println(net.getOutput()); +// net.run("0.1 1.0"); +// System.out.println(net.getOutput()); +// net.run("1.0 0.1"); +// System.out.println(net.getOutput()); + String learnSet = "0.7 0.3 0\n0.2 0.6 1\n0.3 0.4 1\n0.9 0.8 0\n0.1 0.2 1\n0.5 0.6 1"; + net.learn(learnSet); + net.run("0.7 0.9"); + System.out.println(net.getOutput()); + net.run("0.9 0.7"); + System.out.println(net.getOutput()); + net.run("0.2 0.2"); + System.out.println(net.getOutput()); + net.run("0.1 1.0"); + System.out.println(net.getOutput()); + net.run("1.0 0.1"); + System.out.println(net.getOutput()); + net.run("0.6 0.5"); + System.out.println(net.getOutput()); + net.run("0.5 0.6"); + System.out.println(net.getOutput()); + }*/ + + + +}