/**
 * Example of a neural network with reused edge weights
 */


#include <iostream>
#include <memory>
#include <utility>
#include <vector>

#include <4neuro.h>

/**
 * Trains `net` with lib4neuro's particle swarm optimizer, copies the best
 * parameters the swarm found back into `net`, and prints the resulting error
 * followed by a separator line.
 *
 * @param net network whose parameter space is searched and then overwritten
 * @param ef  error function minimized by the swarm and evaluated afterwards
 */
void optimize_via_particle_swarm(l4n::NeuralNetwork& net,
                                 l4n::ErrorFunction& ef) {

    /* TRAINING METHOD SETUP */

    /* Search box: one [-10, 10] interval per trainable parameter (every weight
     * and every bias), stored as interleaved (lower, upper) pairs. */
    const size_t n_params = net.get_n_weights() + net.get_n_biases();
    std::vector<double> domain_bounds;
    domain_bounds.reserve(2 * n_params);
    for (size_t p = 0; p < n_params; ++p) {
        domain_bounds.push_back(-10); // lower bound of parameter p
        domain_bounds.push_back(10);  // upper bound of parameter p
    }

    /* Swarm dynamics coefficients (presumably the usual PSO cognitive,
     * social and inertia terms) plus swarm size and iteration cap. */
    const double c1          = 1.7;
    const double c2          = 1.7;
    const double w           = 0.7;
    const size_t n_particles = 50;
    const size_t iter_max    = 10;

    /* First stopping rule: satisfied once the maximal velocity of the previous
     * step drops below 'gamma' times the current maximal velocity. */
    const double gamma = 0.5;

    /* Second stopping rule: satisfied once at least 'delta' * n particles
     * (n = swarm size) lie within radius 'epsilon' of the swarm centroid. */
    const double epsilon = 0.02;
    const double delta   = 0.7;

    l4n::ParticleSwarm swarm(&domain_bounds,
                             c1,
                             c2,
                             w,
                             gamma,
                             epsilon,
                             delta,
                             n_particles,
                             iter_max);
    swarm.optimize(ef);

    /* Install the best parameters found by the swarm into the network. */
    net.copy_parameter_space(swarm.get_parameters());

    std::cout << "Run finished! Error of the network[Particle swarm]: " << ef.eval(nullptr) << std::endl;
    std::cout
        << "***********************************************************************************************************************"
        << std::endl;
}

/**
 * Trains `net` with lib4neuro's `GradientDescentBB` optimizer, copies the
 * resulting parameters back into `net`, and prints the resulting error
 * followed by a separator line.
 *
 * @param net network whose parameters are overwritten with the optimizer's result
 * @param ef  error function minimized by the optimizer and evaluated afterwards
 */
void optimize_via_gradient_descent(l4n::NeuralNetwork& net,
                                   l4n::ErrorFunction& ef) {

    /* The two constructor arguments (1e-6 and 1000) are interpreted by
     * l4n::GradientDescentBB; the first looks like a tolerance — TODO confirm
     * the meaning of both against the lib4neuro headers. */
    l4n::GradientDescentBB optimizer(1e-6,
                                     1000);
    optimizer.optimize(ef);

    /* Install the optimizer's parameters into the network. */
    net.copy_parameter_space(optimizer.get_parameters());

    /* ERROR CALCULATION — eval(nullptr) is used here after the parameters were
     * copied into `net`; presumably it evaluates the current network state. */
    std::cout << "Run finished! Error of the network[Gradient descent]: " << ef.eval(nullptr) << std::endl;
    std::cout
        << "***********************************************************************************************************************"
        << std::endl;
}

int main() {
    std::cout
        << "Running lib4neuro example   2: Basic use of the particle swarm method to train a network with five linear neurons and repeating edge weights"
        << std::endl;
    std::cout
        << "********************************************************************************************************************************************"
        << std::endl;
    std::cout << "The code attempts to find an approximate solution to the system of equations below:" << std::endl;
    std::cout << " 0 * w1 + 1 * w2 = 0.50 + b1" << std::endl;
    std::cout << " 1 * w1 + 0.5*w2 = 0.75 + b1" << std::endl;
    std::cout << "(1.25 + b2) * w2 = 0.63 + b3" << std::endl;
    std::cout
        << "***********************************************************************************************************************"
        << std::endl;

    /* TRAIN DATA DEFINITION */
    std::vector<std::pair<std::vector<double>, std::vector<double>>> data_vec;
    std::vector<double>                                              inp, out;

    inp = {0, 1, 0};
    out = {0.5, 0};
    data_vec.emplace_back(std::make_pair(inp,
                                         out));

    inp = {1, 0.5, 0};
    out = {0.75, 0};
    data_vec.emplace_back(std::make_pair(inp,
                                         out));

    inp = {0, 0, 1.25};
    out = {0, 0.63};
    data_vec.emplace_back(std::make_pair(inp,
                                         out));
    l4n::DataSet ds(&data_vec);

    /* NETWORK DEFINITION */
    l4n::NeuralNetwork net;

    /* Input neurons */
    std::shared_ptr<l4n::NeuronLinear> i1 = std::make_shared<l4n::NeuronLinear>();
    std::shared_ptr<l4n::NeuronLinear> i2 = std::make_shared<l4n::NeuronLinear>();

    std::shared_ptr<l4n::NeuronLinear> i3 = std::make_shared<l4n::NeuronLinear>();

    /* Output neurons */
    std::shared_ptr<l4n::NeuronLinear> o1 = std::make_shared<l4n::NeuronLinear>();
    std::shared_ptr<l4n::NeuronLinear> o2 = std::make_shared<l4n::NeuronLinear>();

    /* Adding neurons to the nets */
    size_t idx1 = net.add_neuron(i1,
                                 l4n::BIAS_TYPE::NO_BIAS);
    size_t idx2 = net.add_neuron(i2,
                                 l4n::BIAS_TYPE::NO_BIAS);
    size_t idx3 = net.add_neuron(o1,
                                 l4n::BIAS_TYPE::NEXT_BIAS);
    size_t idx4 = net.add_neuron(i3,
                                 l4n::BIAS_TYPE::NEXT_BIAS);
    size_t idx5 = net.add_neuron(o2,
                                 l4n::BIAS_TYPE::NEXT_BIAS);

    /* Adding connections */
    net.add_connection_simple(idx1,
                              idx3,
                              l4n::SIMPLE_CONNECTION_TYPE::NEXT_WEIGHT); // weight index 0
    net.add_connection_simple(idx2,
                              idx3,
                              l4n::SIMPLE_CONNECTION_TYPE::NEXT_WEIGHT); // weight index 1
    net.add_connection_simple(idx4,
                              idx5,
                              l4n::SIMPLE_CONNECTION_TYPE::EXISTING_WEIGHT,
                              0); // AGAIN weight index 0 - same weight!

    /* specification of the input/output neurons */
    std::vector<size_t> net_input_neurons_indices(3);
    std::vector<size_t> net_output_neurons_indices(2);
    net_input_neurons_indices[0] = idx1;
    net_input_neurons_indices[1] = idx2;
    net_input_neurons_indices[2] = idx4;

    net_output_neurons_indices[0] = idx3;
    net_output_neurons_indices[1] = idx5;

    net.specify_input_neurons(net_input_neurons_indices);
    net.specify_output_neurons(net_output_neurons_indices);

    /* COMPLEX ERROR FUNCTION SPECIFICATION */
    l4n::MSE mse(&net,
                 &ds);

    /* PARTICLE SWARM LEARNING */
    net.randomize_weights();
    optimize_via_particle_swarm(net,
                                mse);


    /* GRADIENT DESCENT LEARNING */
    net.randomize_weights();
    optimize_via_gradient_descent(net,
                                  mse);

    return 0;
}