Newer
Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
package cz.vsb.mro0010.neuralnetworks;
import java.util.ArrayList;
import java.util.Arrays;
/**
 * A single-layer perceptron: {@code nrOfInputs} input pseudo-neurons, each connected to one
 * output {@link Neuron}, trained with the classic perceptron learning rule.
 *
 * <p>The output neuron's threshold plays the role of the bias: it is treated as a weight
 * {@code w_0 = -threshold} on a constant input of 1.
 *
 * <p>Inputs and training samples are whitespace-separated decimal numbers, for example:
 *
 * <pre>{@code
 * SinglePerceptronNeuralNet net = new SinglePerceptronNeuralNet(new BinaryNeuron(), 2, 0.7f);
 * net.learn("0.7 0.3 0\n0.2 0.6 1\n0.3 0.4 1\n0.9 0.8 0\n0.1 0.2 1\n0.5 0.6 1");
 * net.run("0.7 0.9");
 * System.out.println(net.getOutput());
 * }</pre>
 */
public class SinglePerceptronNeuralNet extends NeuralNet {

    /** The single output neuron; its threshold is the negated bias weight {@code w_0}. */
    private final Neuron neuron;
    /** Number of input values expected by {@link #run(String)}. */
    private final int nrOfInputs;
    /** One weighted connection per input pseudo-neuron to {@link #neuron}, in input order. */
    private final ArrayList<Connection> connections;
    /** Input-layer pseudo-neurons, in input order. */
    private final ArrayList<InputLayerPseudoNeuron> input;
    /** Weight log recorded by the most recent training call. */
    private String trainingOutput;
    /** Learning rate used by the perceptron update rule. */
    private final float learnCoef;
    /** Whether the most recent training call ended with every sample classified correctly. */
    private boolean converged;

    /**
     * Creates a perceptron whose connection weights are initialised from {@link Math#random()}.
     * The output neuron's current threshold is used unchanged as the initial bias.
     *
     * @param neuron     the output neuron
     * @param nrOfInputs number of inputs; must not be negative
     * @param learnCoef  learning rate used by {@link #learn(String)}
     * @throws NullPointerException     if {@code neuron} is null
     * @throws IllegalArgumentException if {@code nrOfInputs} is negative
     */
    public SinglePerceptronNeuralNet(Neuron neuron, int nrOfInputs, float learnCoef) {
        super();
        if (neuron == null) {
            throw new NullPointerException("neuron");
        }
        if (nrOfInputs < 0) {
            throw new IllegalArgumentException("nrOfInputs must not be negative: " + nrOfInputs);
        }
        this.neuron = neuron;
        this.nrOfInputs = nrOfInputs;
        this.learnCoef = learnCoef;
        this.input = new ArrayList<InputLayerPseudoNeuron>(nrOfInputs);
        this.connections = new ArrayList<Connection>(nrOfInputs);
        for (int i = 0; i < nrOfInputs; i++) {
            InputLayerPseudoNeuron inputNeuron = new InputLayerPseudoNeuron();
            this.input.add(inputNeuron);
            this.connections.add(new Connection(inputNeuron, neuron, (float) Math.random()));
        }
        this.setTrainingOutput(" ");
    }

    @Override
    public String getNeuronType() {
        return neuron.getClass().getSimpleName();
    }

    /**
     * Trains the perceptron until every sample is classified correctly.
     *
     * <p>Equivalent to {@code learn(trainingSet, Integer.MAX_VALUE)}. If the samples are not
     * linearly separable this effectively never returns; use {@link #learn(String, int)} to
     * bound the training time.
     *
     * @param trainingSet newline-separated samples, in the format of {@link #learn(String, int)}
     * @return the number of epochs run, including the final error-free epoch
     */
    @Override
    public int learn(String trainingSet) {
        return learn(trainingSet, Integer.MAX_VALUE);
    }

    /**
     * Trains with the perceptron rule for at most {@code maxIterations} epochs.
     *
     * <p>Each non-blank line of {@code trainingSet} holds {@code nrOfInputs} input values followed
     * by the expected output, all separated by whitespace. Whenever the network's output differs
     * from the expected one, each weight {@code w_i} is moved by
     * {@code learnCoef * (expected - actual) * x_i}. The bias {@code w_0 = -threshold} is moved by
     * {@code learnCoef * (expected - actual)}.
     *
     * <p>The initial weights, and the weights after every update, are recorded one row per line
     * as {@code w_1 ... w_n w_0}. They are available afterwards via {@link #getTrainingOutput()}.
     *
     * @param trainingSet   newline-separated training samples
     * @param maxIterations upper bound on the number of epochs; must be positive
     * @return the number of epochs run; {@link #isConverged()} tells whether training stopped
     *         because every sample was classified correctly
     * @throws IllegalArgumentException if {@code maxIterations} is not positive, or a line does
     *         not contain exactly {@code nrOfInputs + 1} values
     * @throws NumberFormatException    if a value is not a valid number
     */
    public int learn(String trainingSet, int maxIterations) {
        if (maxIterations < 1) {
            throw new IllegalArgumentException("maxIterations must be positive: " + maxIterations);
        }
        ArrayList<float[]> samples = parseTrainingSet(trainingSet);
        StringBuilder trainingProgress = new StringBuilder();
        appendWeights(trainingProgress);

        boolean learned = false;
        int iterations = 0;
        while (!learned && iterations < maxIterations) {
            iterations++;
            learned = true;
            for (float[] sample : samples) {
                propagate(sample);
                float expected = sample[nrOfInputs];
                // Round-trip through getOutput() so the comparison matches what callers observe.
                float actual = Float.parseFloat(this.getOutput());
                if (expected != actual) {
                    learned = false;
                    float error = expected - actual;
                    // The bias is w_0 = -threshold on a constant input of 1, so the update
                    // w_0 += learnCoef * error becomes threshold -= learnCoef * error.
                    neuron.setThreshold(neuron.getThreshold() - learnCoef * error);
                    for (Connection c : connections) {
                        c.adjustWeight(learnCoef * error * c.getInputNeuron().getState());
                    }
                    appendWeights(trainingProgress);
                }
            }
        }
        this.converged = learned;
        this.setTrainingOutput(trainingProgress.toString());
        return iterations;
    }

    /**
     * Reports whether the most recent training call classified every sample correctly.
     *
     * @return {@code true} if the last {@code learn} call converged; {@code false} before any
     *         training or if the epoch limit was reached first
     */
    public boolean isConverged() {
        return converged;
    }

    /**
     * Feeds one input vector through the network. Read the result with {@link #getOutput()}.
     *
     * @param inputString exactly {@code nrOfInputs} whitespace-separated numbers
     * @throws IllegalArgumentException if the number of values differs from {@code nrOfInputs}
     * @throws NumberFormatException    if a value is not a valid number
     */
    @Override
    public void run(String inputString) {
        String[] tokens = tokenize(inputString);
        if (tokens.length != nrOfInputs) {
            throw new IllegalArgumentException("Expected " + nrOfInputs + " input values but got "
                    + tokens.length + ": \"" + inputString + "\"");
        }
        propagate(parseFloats(tokens));
    }

    /**
     * Returns the output neuron's current state as a string.
     *
     * @return the output neuron's state, formatted with {@code String.valueOf}
     */
    public String getOutput() {
        return String.valueOf(neuron.getState());
    }

    /**
     * Returns the weight log written by the most recent training call. The log is {@code " "}
     * before any training.
     *
     * @return one line per recorded weight vector, formatted as {@code w_1 ... w_n w_0}
     */
    public String getTrainingOutput() {
        return trainingOutput;
    }

    private void setTrainingOutput(String trainingOutput) {
        this.trainingOutput = trainingOutput;
    }

    /**
     * Loads the first {@code nrOfInputs} entries of {@code values} into the input layer, then
     * fires the output neuron.
     */
    private void propagate(float[] values) {
        for (int i = 0; i < nrOfInputs; i++) {
            InputLayerPseudoNeuron in = input.get(i);
            in.initialize();
            in.adjustPotential(values[i]);
            in.transfer();
        }
        neuron.initialize();
        for (Connection c : connections) {
            c.passSignal();
        }
        neuron.transfer();
    }

    /**
     * Parses every non-blank line into {@code nrOfInputs} inputs followed by the expected output.
     */
    private ArrayList<float[]> parseTrainingSet(String trainingSet) {
        ArrayList<float[]> samples = new ArrayList<float[]>();
        String[] lines = trainingSet.split("\n");
        for (int lineNo = 0; lineNo < lines.length; lineNo++) {
            String[] tokens = tokenize(lines[lineNo]);
            if (tokens.length == 0) {
                continue; // tolerate blank lines, including a lone "\r" left by CRLF input
            }
            if (tokens.length != nrOfInputs + 1) {
                throw new IllegalArgumentException("Training line " + (lineNo + 1) + " must contain "
                        + (nrOfInputs + 1) + " values but has " + tokens.length + ": \""
                        + lines[lineNo] + "\"");
            }
            samples.add(parseFloats(tokens));
        }
        return samples;
    }

    /**
     * Appends the current weights as one line, {@code w_1 ... w_n w_0}, where
     * {@code w_0 = -threshold}.
     */
    private void appendWeights(StringBuilder out) {
        for (Connection c : connections) {
            out.append(String.valueOf(c.getWeight())).append(' ');
        }
        out.append(String.valueOf(-neuron.getThreshold())).append('\n');
    }

    /** Splits on runs of whitespace. A blank string yields an empty array. */
    private static String[] tokenize(String text) {
        String trimmed = text.trim();
        return trimmed.length() == 0 ? new String[0] : trimmed.split("\\s+");
    }

    /** Parses every token as a {@code float}. */
    private static float[] parseFloats(String[] tokens) {
        float[] values = new float[tokens.length];
        for (int i = 0; i < tokens.length; i++) {
            values[i] = Float.parseFloat(tokens[i]);
        }
        return values;
    }
}