// net_test_2.cpp
/**
 * Example of a neural network with reused (shared) edge weights.
 * The system of equations associated with the net in this example is not regular.
 * The error function minimized here is
 *   ((2y + 0.5)^2 + (2x + 1)^2 + (2x + y + 0.25)^2 + (2x + 1)^2 + 1 + (4.5x + 0.37)^2) / 3
 * where x and y are the two trainable weights (x is shared by two edges).
 * Its minimum [0.705493164] is attained at
 *   (x, y) = (-1133/6290, -11193/62900) = (-0.180127186, -0.177949126).
 */

//
// Created by Michal on 7/17/18.
//

#include <cstdio>
#include <utility>
#include <vector>

#include "../include/4neuro.h"

int main() {

    /* TRAIN DATA DEFINITION */
Michal Kravcenko's avatar
Michal Kravcenko committed
19
    std::vector<std::pair<std::vector<double>, std::vector<double>>> data_vec;
20 21
    std::vector<double> inp, out;

Michal Kravcenko's avatar
Michal Kravcenko committed
22 23 24
    inp = {0, 1, 0};
    out = {0.5, 0};
    data_vec.emplace_back(std::make_pair(inp, out));
25

Michal Kravcenko's avatar
Michal Kravcenko committed
26 27 28
    inp = {1, 0.5, 0};
    out = {0.75, 0};
    data_vec.emplace_back(std::make_pair(inp, out));
29

Michal Kravcenko's avatar
Michal Kravcenko committed
30 31 32 33
    inp = {0, 0, 1.25};
    out = {0, 0.63};
    data_vec.emplace_back(std::make_pair(inp, out));
    DataSet ds(&data_vec);
34 35

    /* NETWORK DEFINITION */
36
    NeuralNetwork net;
37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53

    /* Input neurons */
    NeuronLinear *i1 = new NeuronLinear(0.0, 1.0);  //f(x) = x
    NeuronLinear *i2 = new NeuronLinear(0.0, 1.0);  //f(x) = x

    NeuronLinear *i3 = new NeuronLinear(1, 1); //f(x) = x + 1

    /* Output neurons */
    NeuronLinear *o1 = new NeuronLinear(1.0, 2.0);  //f(x) = 2x + 1
    NeuronLinear *o2 = new NeuronLinear(1, 2);  //f(x) = 2x + 1



    /* Adding neurons to the nets */
    int idx1 = net.add_neuron(i1);
    int idx2 = net.add_neuron(i2);
    int idx3 = net.add_neuron(o1);
54 55
    int idx4 = net.add_neuron(i3);
    int idx5 = net.add_neuron(o2);
56 57 58 59 60 61 62 63

    /* Adding connections */
    //net.add_connection_simple(idx1, idx3, -1, 1.0);
    //net.add_connection_simple(idx2, idx3, -1, 1.0);
    net.add_connection_simple(idx1, idx3); // weight index 0
    net.add_connection_simple(idx2, idx3); // weight index 1
    net.add_connection_simple(idx4, idx5, 0); // AGAIN weight index 0 - same weight!

64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81
    net.randomize_weights();

    /* specification of the input/output neurons */
    std::vector<size_t> net_input_neurons_indices(3);
    std::vector<size_t> net_output_neurons_indices(2);
    net_input_neurons_indices[0] = idx1;
    net_input_neurons_indices[1] = idx2;
    net_input_neurons_indices[2] = idx4;

    net_output_neurons_indices[0] = idx3;
    net_output_neurons_indices[1] = idx5;

    net.specify_input_neurons(net_input_neurons_indices);
    net.specify_output_neurons(net_output_neurons_indices);




82
    /* COMPLEX ERROR FUNCTION SPECIFICATION */
Michal Kravcenko's avatar
Michal Kravcenko committed
83 84 85 86
    MSE mse(&net, &ds);

//    double weights[2] = {-0.18012411, -0.17793740};
//    double weights[2] = {1, 1};
87

Michal Kravcenko's avatar
Michal Kravcenko committed
88
//    printf("evaluation of error at point (%f, %f) => %f\n", weights[0], weights[1], mse.eval(weights));
89 90

    /* TRAINING METHOD SETUP */
91
    unsigned int max_iters = 5000;
92 93


94
    //must encapsulate each of the partial error functions
95 96 97 98
    double domain_bounds[4] = {-800.0, 800.0, -800.0, 800.0};

    double c1 = 0.5, c2 = 1.5, w = 0.8;

Michal Kravcenko's avatar
Michal Kravcenko committed
99
    unsigned int n_particles = 100;
100

Michal Kravcenko's avatar
Michal Kravcenko committed
101
    ParticleSwarm swarm_01(&mse, domain_bounds, c1, c2, w, n_particles, max_iters);
102

Michal Kravcenko's avatar
Michal Kravcenko committed
103
    swarm_01.optimize(0.5, 0.02, 0.9);
104

Michal Kravcenko's avatar
Michal Kravcenko committed
105
    printf("evaluation of error: %f\n", mse.eval());
106

107 108
    return 0;
}