//
// Created by martin on 20.08.19.
//

/* ARMA_ALLOW_FAKE_GCC lets Armadillo accept compilers that merely masquerade
 * as GCC; it has to be defined before the Armadillo headers pulled in through
 * 4neuro.h are included. */
#define ARMA_ALLOW_FAKE_GCC
#include <4neuro.h>


void optimize_via_particle_swarm(l4n::NeuralNetwork& net,
                                 l4n::ErrorFunction& ef) {

    /* TRAINING METHOD SETUP */
    std::vector<double> domain_bounds(2 * (net.get_n_weights() + net.get_n_biases()));
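    /* Interleaved storage: entry 2*i holds the lower and entry 2*i + 1 the
     * upper bound of the i-th network parameter, so every weight and bias is
     * searched within [-150, 150]. */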

    for (size_t i = 0; i < domain_bounds.size() / 2; ++i) {
        domain_bounds[2 * i]     = -150;
        domain_bounds[2 * i + 1] = 150;
    }

    double c1          = 1.7;   /* cognitive acceleration coefficient */
    double c2          = 1.7;   /* social acceleration coefficient */
    double w           = 0.7;   /* inertia weight */
    size_t n_particles = 300;
    size_t iter_max    = 500;

    /* if the maximal velocity from the previous step is less than 'gamma' times the current maximal velocity,
     * the first termination criterion is met */
    double gamma = 0.5;

    /* if 'delta' times 'n' particles lie within the centroid neighborhood given by the radius 'epsilon',
     * the second termination criterion is met ('n' is the total number of particles) */
    double epsilon = 0.02;
    double delta   = 0.7;
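
    /* Example with the values above: the first criterion compares the maximal
     * particle velocities of two consecutive steps using gamma = 0.5, and the
     * second criterion is met once delta * n_particles = 0.7 * 300 = 210
     * particles lie within the radius epsilon = 0.02 of the swarm centroid. */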

    l4n::ParticleSwarm swarm_01(
        &domain_bounds,
        c1,
        c2,
        w,
        gamma,
        epsilon,
        delta,
        n_particles,
        iter_max
    );
    swarm_01.optimize(ef);

    net.copy_parameter_space(swarm_01.get_parameters());   /* load the best position found by the swarm into the network */

    /* ERROR CALCULATION */
    std::cout << "Run finished! Error of the network[Particle swarm]: " << ef.eval(nullptr) << std::endl;
    std::cout
        << "***********************************************************************************************************************"
        << std::endl;
}

double optimize_via_gradient_descent(l4n::NeuralNetwork& net,
                                     l4n::ErrorFunction& ef) {

    std::cout
        << "***********************************************************************************************************************"
        << std::endl;

    l4n::GradientDescentBB gd(1e-4,   /* tolerance (the meaning of the positional arguments is assumed) */
                              1000,   /* restart period */
                              100000  /* maximal number of iterations */);

    l4n::LazyLearning lazy_wrapper(gd, 1e-4);
    lazy_wrapper.optimize(ef);

//    net.copy_parameter_space(gd.get_parameters());

    /* ERROR CALCULATION */
    double err = ef.eval(nullptr);
    COUT_INFO("Run finished! Error of the network [Gradient descent]: " << err);

    /* Just for validation test purposes - NOT necessary for the example to work! */
    return err;
}

double optimize_via_LBMQ(l4n::NeuralNetwork& net,
                         l4n::ErrorFunction& ef) {

    size_t max_iterations = 1000;
    size_t batch_size = 0;
    double tolerance = 1e-4;
    double tolerance_gradient = tolerance * 1e-6;
    double tolerance_parameters = tolerance * 1e-6;
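
    /* Both derived tolerances therefore equal 1e-10; a batch_size of 0
     * presumably requests the full data set in every iteration. */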

    std::cout
        << "***********************************************************************************************************************"
        << std::endl;
    l4n::LevenbergMarquardt lm(max_iterations,
                               batch_size,
                               tolerance,
                               tolerance_gradient,
                               tolerance_parameters);

    l4n::LazyLearning lazy_wrapper(lm, tolerance);
    lazy_wrapper.optimize(ef);

//    net.copy_parameter_space(lm.get_parameters());

    /* ERROR CALCULATION */
    double err = ef.eval(nullptr);
    std::cout << "Run finished! Error of the network [Levenberg-Marquardt]: " << err << std::endl;

    /* Just for validation test purposes - NOT necessary for the example to work! */
    return err;
}

double optimize_via_NelderMead(l4n::NeuralNetwork& net, l4n::ErrorFunction& ef) {
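    /* The two constructor arguments below are not documented in this example;
     * they presumably bound the number of iterations and the population size,
     * but that reading is an assumption. */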
    l4n::NelderMead nm(500, 200);

    nm.optimize(ef);
    net.copy_parameter_space(nm.get_parameters());

    /* ERROR CALCULATION */
    double err = ef.eval(nullptr);
    std::cout << "Run finished! Error of the network[Nelder-Mead]: " << err << std::endl;

    /* Just for validation test purposes - NOT necessary for the example to work! */
    return err;
}


int main() {
    MPI_INIT
    try{

        /* Specify cutoff functions */
        l4n::CutoffFunction2 cutoff2(8);
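        /* The argument above is the cutoff radius: atomic pairs farther apart
         * contribute nothing to the symmetry functions (the unit follows the
         * input data). */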

        /* Specify symmetry functions */
        l4n::G2 sym_f1(&cutoff2, 0.00, 0);
        l4n::G2 sym_f2(&cutoff2, 0.02, 1);
        l4n::G2 sym_f3(&cutoff2, 0.04, 2);
        l4n::G2 sym_f4(&cutoff2, 0.06, 3);
        l4n::G2 sym_f5(&cutoff2, 0.08, 4);
        l4n::G2 sym_f6(&cutoff2, 0.10, 5);
        l4n::G2 sym_f7(&cutoff2, 0.12, 6);
        l4n::G2 sym_f8(&cutoff2, 0.14, 7);
        l4n::G2 sym_f9(&cutoff2, 0.16, 8);
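
        /* The two numeric G2 arguments are presumably the 'extension' and
         * 'shift' parameters referenced by SYMMETRY_FUNCTION_PARAMETER further
         * below; their exact meaning is an assumption, not documented here. */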

        l4n::G5 sym_f10(&cutoff2, 0, -1, 0);
        l4n::G5 sym_f11(&cutoff2, 0, -1, 3);
        l4n::G5 sym_f12(&cutoff2, 0, -1, 6);
        l4n::G5 sym_f13(&cutoff2, 0, -1, 9);
        l4n::G5 sym_f14(&cutoff2, 0, -1, 12);
        l4n::G5 sym_f15(&cutoff2, 0, -1, 15);
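
        /* Similarly, the three numeric G5 arguments appear to correspond to
         * the extension, shift and angular-resolution parameters (again an
         * assumption based on the parameter enum). */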


        /* NOTE: only the first six G2 functions are registered for helium;
         * sym_f7-sym_f9 and the G5 functions above are constructed but unused. */
        std::vector<l4n::SymmetryFunction*> helium_sym_funcs = {&sym_f1,
                                                                &sym_f2,
                                                                &sym_f3,
                                                                &sym_f4,
                                                                &sym_f5,
                                                                &sym_f6};

        l4n::Element helium = l4n::Element("He",
                                           helium_sym_funcs);
        std::unordered_map<l4n::ELEMENT_SYMBOL, l4n::Element*> elements;
        elements[l4n::ELEMENT_SYMBOL::He] = &helium;
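
        /* The map binds each chemical species to its symmetry functions; this
         * example uses helium only. */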

        /* Read data */
        l4n::XYZReader reader("../../data/HE4+T0_000050.xyz", true);
        reader.read();

        COUT_INFO( "Finished reading data" );

        std::shared_ptr<l4n::DataSet> ds = reader.get_acsf_data_set( elements );
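
        /* get_acsf_data_set() maps the raw Cartesian configurations onto
         * symmetry-function values, so each sample's inputs are ACSF
         * descriptors rather than coordinates. */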
//         ds->print_data();

        /* Create a neural network */
        std::unordered_map<l4n::ELEMENT_SYMBOL, std::vector<unsigned int>> n_hidden_neurons;
        n_hidden_neurons[l4n::ELEMENT_SYMBOL::He] = {10, 6, 1};

        std::unordered_map<l4n::ELEMENT_SYMBOL, std::vector<l4n::NEURON_TYPE>> type_hidden_neurons;
        type_hidden_neurons[l4n::ELEMENT_SYMBOL::He] = {l4n::NEURON_TYPE::LOGISTIC, l4n::NEURON_TYPE::LOGISTIC, l4n::NEURON_TYPE::LINEAR};
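
        /* Each He atom is thus modeled by a subnetwork with 10 and 6 logistic
         * neurons in its hidden layers; the final single linear neuron
         * presumably yields the atomic contribution to the total energy. */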

        l4n::ACSFNeuralNetwork net(elements, *reader.get_element_list(), reader.contains_charge(), n_hidden_neurons, type_hidden_neurons);

        l4n::MSE mse(&net, ds.get(), false);

        net.randomize_parameters();

        /* Print the inputs of the first data sample, two values per line */
        for(size_t i = 0; i < ds->get_data()->at(0).first.size(); i++) {
            std::cout << ds->get_data()->at(0).first.at(i) << " ";
            if(i % 2 == 1) {
                std::cout << std::endl;
            }
        }
        std::cout << "----" << std::endl;

        l4n::ACSFParametersOptimizer param_optim(&mse, &reader);
        std::vector<l4n::SYMMETRY_FUNCTION_PARAMETER> fitted_params = {l4n::SYMMETRY_FUNCTION_PARAMETER::EXTENSION,
                                                                       l4n::SYMMETRY_FUNCTION_PARAMETER::SHIFT_MAX,
                                                                       l4n::SYMMETRY_FUNCTION_PARAMETER::SHIFT,
                                                                       l4n::SYMMETRY_FUNCTION_PARAMETER::ANGULAR_RESOLUTION};
//                                                                       l4n::SYMMETRY_FUNCTION_PARAMETER::PERIOD_LENGTH};

        /* The ACSF parameter fitting step is currently disabled; uncomment to
         * tune the symmetry-function parameters before training the network. */
//        param_optim.fit_ACSF_parameters(fitted_params,
//                                        false,
//                                        50,
//                                        10,
//                                        1e-5,
//                                        0.98,
//                                        0.085,
//                                        1e-6);

        /* Print the first sample's inputs again; with the fitting step enabled
         * this shows how the ACSF descriptors change. */
        for(size_t i = 0; i < mse.get_dataset()->get_data()->at(0).first.size(); i++) {
            std::cout << mse.get_dataset()->get_data()->at(0).first.at(i) << " ";
            if(i % 2 == 1) {
                std::cout << std::endl;
            }
        }
        std::cout << "----" << std::endl;


        /* Alternative training methods, kept for experimentation: */
//        optimize_via_particle_swarm(net, mse);
//        optimize_via_NelderMead(net, mse);

        double err1 = optimize_via_LBMQ(net, mse);
//        double err2 = optimize_via_gradient_descent(net, mse);

//        std::cout << "Weights: " << net.get_min_max_weight().first << " " << net.get_min_max_weight().second << std::endl;

        /* Fit comparison with real data (the loop below is kept disabled) */
        std::vector<double> output;
        output.resize(1);

//        for(auto e : *mse.get_dataset()->get_data()) {
//            std::cout << "OUTS (DS, predict): " << e.second.at(0) << " ";
//            net.eval_single(e.first, output);
//            std::cout << output.at(0) << std::endl;
//        }

    } catch (const std::exception& e) {
        std::cerr << e.what() << std::endl;
        exit(EXIT_FAILURE);
    }

    MPI_FINISH
    return 0;
}