//
// Created by martin on 7/15/18.
//

#include <vector>

#include "ErrorFunctions.h"

namespace lib4neuro {

    size_t ErrorFunction::get_dimension() {
        return this->dimension;
    }

    /**
     * Builds a mean-squared-error functional of network 'net' over data set 'ds'.
     *
     * Both pointers are stored as-is (no copy is made). Nothing in this file frees them;
     * presumably the caller keeps ownership (NOTE(review): confirm against the header).
     *
     * @param net network whose outputs are compared against the expected outputs
     * @param ds  data set of (input, expected output) pairs
     */
    MSE::MSE(NeuralNetwork *net, DataSet *ds) {
        this->net = net;
        this->ds = ds;

        // The functional's dimension equals the number of weights plus the number of biases.
        const auto n_params = net->get_n_weights() + net->get_n_biases();
        this->dimension = n_params;
    }

    /**
     * Mean squared error of the network over the whole data set.
     *
     * For every sample (x, y) in the data set, the network is evaluated on x using the
     * given weights. The squared differences between the network output and y are summed
     * over all output components and all samples. The total is then divided by the number
     * of samples (not by the number of output components).
     *
     * @param weights parameter vector forwarded unchanged to NeuralNetwork::eval_single
     * @return accumulated squared error divided by the number of samples; 0.0 for an
     *         empty data set (instead of the NaN that 0.0 / 0 would produce)
     */
    double MSE::eval(std::vector<double> *weights) {
        const size_t dim_out = this->ds->get_output_dim();
        const size_t n_elements = this->ds->get_n_elements();

        // An empty data set has no error to average. Dividing by zero would return NaN,
        // which silently poisons any optimizer or ErrorSum consuming this value.
        if (n_elements == 0) {
            return 0.0;
        }

        std::vector<std::pair<std::vector<double>, std::vector<double>>> *data = this->ds->get_data();

        // Output buffer reused across samples so no new vector is allocated per element.
        std::vector<double> output(dim_out);
        double error = 0.0;

        for (size_t i = 0; i < n_elements; ++i) {
            // (input, expected output) pair; bound by non-const reference because
            // eval_single is passed the input exactly as in the original call.
            auto &sample = data->at(i);

            // Compute the network output for this sample into 'output'.
            this->net->eval_single(sample.first, output, weights);

            for (size_t j = 0; j < dim_out; ++j) {
                const double diff = output[j] - sample.second[j];
                error += diff * diff;
            }
        }

        return error / static_cast<double>(n_elements);
    }

    /**
     * Creates an empty weighted sum of error functions.
     *
     * The containers for the summands and their coefficients are allocated lazily by
     * add_error_function(). The dimension starts at 0 and grows to the largest dimension
     * among the error functions added later.
     */
    ErrorSum::ErrorSum() {
        dimension = 0;
        summand_coefficient = nullptr;
        summand = nullptr;
    }

    /**
     * Releases the lazily allocated containers.
     *
     * Only the two vectors are freed; the ErrorFunction objects that 'summand' points to
     * are not deleted here.
     */
    ErrorSum::~ErrorSum() {
        // 'delete' on a null pointer is a no-op, so no explicit null checks are needed.
        // NOTE(review): ErrorSum owns these raw pointers. Unless the header deletes or
        // defines the copy constructor/assignment, copying an ErrorSum would lead to a
        // double delete (Rule of Three).
        delete this->summand_coefficient;
        delete this->summand;
    }

    /**
     * Weighted sum of the registered error functions:
     *     sum_i summand[i]->eval(weights) * summand_coefficient[i]
     *
     * @param weights parameter vector forwarded unchanged to every summand's eval()
     * @return the weighted sum; 0.0 when no error function has been added yet
     */
    double ErrorSum::eval(std::vector<double> *weights) {
        // Both containers are allocated lazily by add_error_function(). Before the first
        // call they are nullptr, and dereferencing them would crash. An empty sum is 0.
        if (!this->summand || !this->summand_coefficient) {
            return 0.0;
        }

        double output = 0.0;
        const size_t n_summands = this->summand->size();
        for (size_t i = 0; i < n_summands; ++i) {
            output += this->summand->at(i)->eval(weights) * this->summand_coefficient->at(i);
        }

        return output;
    }

    /**
     * Appends error function F, weighted by alpha, to the sum.
     *
     * F is stored by pointer (not copied) and is not deleted by ~ErrorSum. It is
     * dereferenced here, so it must be non-null. The sum's dimension becomes the maximum
     * of its current value and F's dimension.
     *
     * @param F     error function to add
     * @param alpha coefficient multiplying F's value in eval()
     */
    void ErrorSum::add_error_function(ErrorFunction *F, double alpha) {
        if (this->summand == nullptr) {
            this->summand = new std::vector<ErrorFunction *>();
        }
        this->summand->push_back(F);

        if (this->summand_coefficient == nullptr) {
            this->summand_coefficient = new std::vector<double>();
        }
        this->summand_coefficient->push_back(alpha);

        // Track the largest dimension among all summands.
        const size_t f_dim = F->get_dimension();
        if (this->dimension < f_dim) {
            this->dimension = f_dim;
        }
    }

    size_t ErrorSum::get_dimension() {
//    if(!this->dimension) {
//        size_t max = 0;
//        for(auto e : *this->summand) {
//            if(e->get_dimension() > max) {
//                max = e->get_dimension();
//            }
//        };
//
//        this->dimension = max;
//    }
        return this->dimension;
    }

}