Skip to content
Snippets Groups Projects
net_test_3.cpp 3.96 KiB
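/**
 * Example of a set of neural networks sharing some edge weights.
 * The system of equations associated with the net in this example is not regular
 * (it has no exact solution, hence the non-zero minimum below), so training
 * minimizes the error function
 *   [(2y+0.5)^2 + (2x+y+0.25)^2] / 2 + [(4.5x + 0.37)^2] / 1
 * where x is the shared weight (index 0) and y is weight index 1.
 * Its minimum value 0.010024714 is attained at
 *   (x, y) = (-333/4370, -9593/43700) = (-0.076201373, -0.219519451)
 * */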

//
// Created by martin on 7/16/18.
//

#include <memory>
#include <utility>
#include <vector>

#include "../include/4neuro.h"

int main() {

    /* TRAIN DATA DEFINITION */
    std::vector<std::pair<std::vector<double>, std::vector<double>>> data_vec_01, data_vec_02;
    std::vector<double> inp, out;

    inp = {0, 1};
    out = {0.5};
    data_vec_01.emplace_back(std::make_pair(inp, out));

    inp = {1, 0.5};
    out = {0.75};
    data_vec_01.emplace_back(std::make_pair(inp, out));

    DataSet ds_01(&data_vec_01);


    inp = {1.25};
    out = {0.63};
    data_vec_02.emplace_back(std::make_pair(inp, out));
    DataSet ds_02(&data_vec_02);

    /* NETWORK DEFINITION */
    NeuralNetwork net;

    /* Input neurons */
    NeuronLinear *i1 = new NeuronLinear();  //f(x) = x
    NeuronLinear *i2 = new NeuronLinear();  //f(x) = x

    double b = 1;//bias
    NeuronLinear *i3 = new NeuronLinear( &b ); //f(x) = x + 1

    /* Output neurons */
    NeuronLinear *o1 = new NeuronLinear(&b);  //f(x) = x + 1
    NeuronLinear *o2 = new NeuronLinear(&b);  //f(x) = x + 1



    /* Adding neurons to the nets */
    int idx1 = net.add_neuron(i1);
    int idx2 = net.add_neuron(i2);
    int idx3 = net.add_neuron(o1);
    int idx4 = net.add_neuron(i3);
    int idx5 = net.add_neuron(o2);

    /* Adding connections */
    //net.add_connection_simple(idx1, idx3, -1, 1.0);
    //net.add_connection_simple(idx2, idx3, -1, 1.0);
    net.add_connection_simple(idx1, idx3); // weight index 0
    net.add_connection_simple(idx2, idx3); // weight index 1
    net.add_connection_simple(idx4, idx5, 0); // AGAIN weight index 0 - same weight!

    net.randomize_weights();

    /* specification of the input/output neurons */
    std::vector<size_t> net_input_neurons_indices(3);
    std::vector<size_t> net_output_neurons_indices(2);
    net_input_neurons_indices[0] = idx1;
    net_input_neurons_indices[1] = idx2;
    net_input_neurons_indices[2] = idx4;

    net_output_neurons_indices[0] = idx3;
    net_output_neurons_indices[1] = idx5;

    net.specify_input_neurons(net_input_neurons_indices);
    net.specify_output_neurons(net_output_neurons_indices);


    /* CONSTRUCTION OF SUBNETWORKS */
    //TODO subnetworks retain the number of weights, could be optimized to include only the used weights
    std::vector<size_t> subnet_01_input_neurons, subnet_01_output_neurons;
    std::vector<size_t> subnet_02_input_neurons, subnet_02_output_neurons;

    subnet_01_input_neurons.push_back(idx1);
    subnet_01_input_neurons.push_back(idx2);
    subnet_01_output_neurons.push_back(idx3);
    NeuralNetwork *subnet_01 = net.get_subnet(subnet_01_input_neurons, subnet_01_output_neurons);

    subnet_02_input_neurons.push_back(idx4);
    subnet_02_output_neurons.push_back(idx5);
    NeuralNetwork *subnet_02 = net.get_subnet(subnet_02_input_neurons, subnet_02_output_neurons);

    /* COMPLEX ERROR FUNCTION SPECIFICATION */
    MSE mse_01(subnet_01, &ds_01);
    MSE mse_02(subnet_02, &ds_02);

    ErrorSum mse_sum;
    mse_sum.add_error_function( &mse_01 );
    mse_sum.add_error_function( &mse_02 );

    /* TRAINING METHOD SETUP */
    unsigned int max_iters = 50;


    //must encapsulate each of the partial error functions
    double domain_bounds[4] = {-800.0, 800.0, -800.0, 800.0};

    double c1 = 0.5, c2 = 1.5, w = 0.8;

    unsigned int n_particles = 100;

//    printf("mse2: %d\n", mse_02.get_dimension());
    ParticleSwarm swarm_01(&mse_sum, domain_bounds, c1, c2, w, n_particles, max_iters);

    swarm_01.optimize(0.5, 0.02, 0.9);

//    double weights[2] = {0, -0.25};
//    printf("evaluation of error at (x, y) = (%f, %f): %f\n", weights[0], weights[1], mse_01.eval(weights));

    delete subnet_02;
    delete subnet_01;
    return 0;
}