#include <algorithm>
#include <cassert>
#include <cmath>
#include <fstream>
#include <functional>
#include <iostream>
#include <map>
#include <numeric>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include <boost/random/mersenne_twister.hpp>
#include <boost/random/uniform_int_distribution.hpp>

#include "ErrorFunctions.h"
#include "exceptions.h"
#include "message.h"
#include "../mpi_wrapper.h"

namespace lib4neuro {

    /**
     * @brief Size of the parameter space this error function works in.
     * @return The stored `dimension` member (for MSE it is set to weights + biases).
     */
    size_t ErrorFunction::get_dimension() {
        const size_t n_params = this->dimension;
        return n_params;
    }

    std::vector<NeuralNetwork*>& ErrorFunction::get_nets() {
        return nets;
    }

    /**
     * @brief Data set currently used for evaluation/training.
     * @return The `ds` pointer (may point to a training subset after a split).
     */
    DataSet* ErrorFunction::get_dataset() const {
        DataSet* current = this->ds;
        return current;
    }

    /**
     * @brief Replace the data set used for evaluation/training.
     * @param ds New data set; only the pointer is stored (`ds_full` and `ds_test` are left as they are).
     */
    void ErrorFunction::set_dataset(DataSet* ds) {
        // The parameter shadows the member, hence the explicit `this->`.
        this->ds = ds;
    }

    /**
     * @brief Split the current data set into a training part and a test part.
     *
     * A fraction @p percent_test of the entries (rounded up) is drawn at random,
     * without repetition, into a newly allocated `ds_test`. The remaining entries
     * form a newly allocated training DataSet that replaces `ds`. The previous `ds`
     * is remembered in `ds_full`. Both parts keep the original relative order of
     * the entries and use the normalization strategy of the full set.
     *
     * NOTE(review): a `ds_test` / training `ds` allocated by an earlier call is not
     * released here — confirm ownership before calling this repeatedly.
     *
     * @param percent_test Fraction of entries used for testing; clamped to [0, 1].
     */
    void ErrorFunction::divide_data_train_test(double percent_test) {
        const size_t ds_size = this->ds->get_n_elements();

        /* Store the full data set */
        this->ds_full = this->ds;

        /* Clamp so the test set can never be larger than the data set */
        percent_test = std::min(1.0, std::max(0.0, percent_test));
        const size_t test_set_size = std::min(ds_size,
                                              static_cast<size_t>(std::ceil(ds_size * percent_test)));

        /* Draw test_set_size DISTINCT indices with a partial Fisher-Yates shuffle.
         * Drawing with replacement could select an index twice, which duplicated
         * test entries and made index-based erasure remove the wrong rows. */
        std::vector<size_t> indices(ds_size);
        std::iota(indices.begin(), indices.end(), size_t{0});

        boost::random::mt19937 gen;
        for (size_t i = 0; i < test_set_size; ++i) {
            boost::random::uniform_int_distribution<size_t> dist(i, ds_size - 1);
            std::swap(indices[i], indices[dist(gen)]);
        }

        std::vector<char> is_test(ds_size, 0);
        for (size_t i = 0; i < test_set_size; ++i) {
            is_test[indices[i]] = 1;
        }

        /* Partition in a single linear pass (no repeated vector::erase) */
        std::vector<std::pair<std::vector<double>, std::vector<double>>> test_data, train_data;
        test_data.reserve(test_set_size);
        train_data.reserve(ds_size - test_set_size);

        auto* full_data = this->ds_full->get_data();
        for (size_t i = 0; i < ds_size; ++i) {
            if (is_test[i]) {
                test_data.emplace_back(full_data->at(i));
            } else {
                train_data.emplace_back(full_data->at(i));
            }
        }

        /* Re-initialize data set for training */
        this->ds = new DataSet(&train_data,
                               this->ds_full->get_normalization_strategy());

        /* Initialize test data */
        this->ds_test = new DataSet(&test_data,
                                    this->ds_full->get_normalization_strategy());
    }

    /**
     * @brief Build a training subset from the still-active entries plus the worst
     *        approximated inactive ones.
     *
     * Every entry of the full data set is evaluated, and its error is written to
     * @p entry_errors. Entries whose error is below @p tolerance are deactivated.
     * The new subset consists of:
     *   - all entries that are still active (ascending index order), followed by
     *   - up to @p expansion_len previously inactive entries with the largest
     *     errors (ascending error order).
     * The activity counter of each selected entry is incremented, capped at 2.
     * `ds` is replaced by a newly allocated DataSet holding the subset; the
     * previous subset, if any, is deleted.
     *
     * @param subset_indices [out] Indices (into the full set) of the selected entries.
     * @param active_subset  [in,out] Per-entry activity counters; reset to 0 when the size does not match.
     * @param entry_errors   [out] Error of every entry of the full set.
     * @param expansion_len  Maximum number of inactive entries to add.
     * @param tolerance      Error below which an entry is treated as learned.
     * @return Number of entries in the new training subset.
     */
    size_t ErrorFunction::divide_data_worst_subset(
            std::vector<size_t> &subset_indices,
            std::vector<int> &active_subset,
            std::vector<float> &entry_errors,
            size_t expansion_len,
            float tolerance
    ) {
        if (this->ds_full == nullptr) {
            this->ds_full = this->ds;
        }

        /* eval_single_item_by_idx() indexes into `ds`. After a previous call, `ds`
         * holds only the (smaller) subset, so release it and point `ds` back to the
         * full set before evaluating every entry of the full set. */
        if (this->ds != this->ds_full) {
            delete this->ds;
            this->ds = this->ds_full;
        }

        const size_t ds_size = this->ds_full->get_n_elements();

        entry_errors.resize(ds_size);
        if (active_subset.size() != ds_size) {
            active_subset.assign(ds_size, 0);
        }

        std::vector<double> error_vector(this->get_n_outputs());
        for (size_t i = 0; i < ds_size; ++i) {
            entry_errors[i] = static_cast<float>(this->eval_single_item_by_idx(i, nullptr, error_vector));
        }

        /* Keep the expansion_len largest errors among inactive entries. A multimap
         * is used because a map would merge entries that share an error value. */
        std::multimap<double, size_t> expanded_set;
        for (size_t i = 0; i < ds_size; ++i) {
            if (active_subset[i] > 0) {
                continue;
            }
            expanded_set.emplace(entry_errors[i], i);
            if (expanded_set.size() > expansion_len) {
                expanded_set.erase(expanded_set.begin());  // drop the smallest error
            }
        }

        subset_indices.clear();
        for (size_t i = 0; i < ds_size; ++i) {
            if (entry_errors[i] < tolerance) {
                /* entry is very well learned, we leave it aside, for now */
                active_subset[i] = 0;
            }
            if (active_subset[i] > 0) {
                subset_indices.push_back(i);
            }
        }

        for (const auto& el : expanded_set) {
            subset_indices.push_back(el.second);
        }

        std::vector<std::pair<std::vector<double>, std::vector<double>>> train_set;
        train_set.reserve(subset_indices.size());
        for (size_t idx : subset_indices) {
            train_set.emplace_back(this->ds_full->get_data()->at(idx));
            active_subset[idx] = std::min(active_subset[idx] + 1, 2);
        }

        this->ds = new DataSet(&train_set,
                               this->ds_full->get_normalization_strategy());

        return train_set.size();
    }

    /**
     * @brief Make `ds` point to the full data set again after a split.
     *
     * Does nothing when no full set is remembered (`ds_full == nullptr`), i.e. no
     * split was performed or the full set was already restored. Without this guard
     * `ds` would be overwritten with nullptr.
     *
     * NOTE(review): the subset `ds` pointed to is not released here (the delete is
     * disabled) — confirm who owns it.
     */
    void ErrorFunction::return_full_data_set_for_training() {
        if (this->ds_full == nullptr) {
            return;
        }

        if (this->ds_test || this->ds != this->ds_full) {
//            delete this->ds;
            this->ds      = this->ds_full;
            this->ds_full = nullptr;
        }
    }
	
	/**
	 * @brief Add the contribution of one data entry to a Gauss-Newton system.
	 *
	 * For entry @p entry_idx of `ds`, nets[0]->get_jacobian() fills the local
	 * Jacobian J (rows jac_loc) and the vector e (partial_error). Then
	 * rhs += J^T e and H += J^T J.
	 *
	 * @param H         [in,out] dimension x dimension matrix to accumulate into.
	 * @param rhs       [in,out] Vector of length dimension to accumulate into.
	 * @param entry_idx Index of the entry in `ds`.
	 */
	void ErrorFunction::add_to_hessian_and_rhs_single(
		arma::Mat<double>& H,
		arma::Col<double>& rhs,
		size_t entry_idx
	){
		std::vector<double> partial_error(this->get_n_outputs());
		std::vector<std::vector<double>> jac_loc;

		this->nets[0]->get_jacobian( jac_loc, this->ds->get_data()->at( entry_idx ), partial_error );

		const size_t nr = jac_loc.size();
		const size_t nc = this->get_dimension();

		/* rhs += J^T e */
		for (size_t ri = 0; ri < nr; ++ri) {
			const std::vector<double>& row = jac_loc[ri];
			const double alpha = partial_error[ri];
			for (size_t ci = 0; ci < nc; ++ci) {
				rhs.at(ci) += alpha * row[ci];
			}
		}

		/* H += J^T J. The indices are size_t so that they match the size_t bounds
		 * (no signed/unsigned comparison). The innermost loop walks down a column
		 * of H, which is contiguous in Armadillo's column-major storage. */
		for (size_t d = 0; d < nr; ++d) {
			const std::vector<double>& row = jac_loc[d];
			for (size_t ci = 0; ci < nc; ++ci) {
				const double cval = row[ci];
				for (size_t ri = 0; ri < nc; ++ri) {
					H.at(ri, ci) += row[ri] * cval;
				}
			}
		}
	}

    /**
     * @brief Jacobian rows and right-hand side over the whole current data set `ds`.
     * @param jacobian [in,out] Rows for every entry are appended.
     * @param rhs      [out] Overwritten with the accumulated J^T e vector.
     */
    void MSE::get_jacobian_and_rhs(std::vector<std::vector<double>>& jacobian,
                                   std::vector<double>& rhs) {
        auto& all_entries = *this->ds->get_data();
        this->get_jacobian_and_rhs(jacobian, rhs, all_entries);
    }

    void MSE::get_jacobian_and_rhs(std::vector<std::vector<double>>& jacobian,
                                   std::vector<double>& rhs,
                                   std::vector<std::pair<std::vector<double>, std::vector<double>>>& data) {
//        size_t row_idx = 0;
        std::vector<double> partial_error(this->get_n_outputs());
        rhs.resize(this->get_dimension());
        std::fill(rhs.begin(),
                  rhs.end(),
                  0.0);

        std::vector<std::vector<double>> jac_loc;
        for (auto                        item : data) {

            this->nets[0]->get_jacobian(jac_loc,
                                        item,
                                        partial_error);

            for (size_t ri = 0; ri < jac_loc.size(); ++ri) {
                jacobian.push_back(jac_loc[ri]);

                for (size_t ci = 0; ci < this->get_dimension(); ++ci) {
//                    J.at(row_idx,
//                         ci) = jacobian[ri][ci];
                    rhs.at(ci) += partial_error[ri] * jac_loc[ri][ci];
                }
//                row_idx++;
            }
        }
    }

    /**
     * @brief Mean-squared-error function for a single network on a data set.
     * @param net           Network to evaluate (stored in `nets`).
     * @param ds            Data set to evaluate on.
     * @param rescale_error When true, the summed error is divided by the number of entries.
     */
    MSE::MSE(NeuralNetwork* net,
             DataSet* ds,
             bool rescale_error
             ) {
        this->ds            = ds;
        this->rescale_error = rescale_error;
        this->nets.push_back(net);

        /* One optimisation variable per network weight and per bias */
        const size_t n_weights = net->get_n_weights();
        const size_t n_biases  = net->get_n_biases();
        this->dimension        = n_weights + n_biases;
    }

    double MSE::eval_on_single_input(std::vector<double>* input,
                                     std::vector<double>* output,
                                     std::vector<double>* weights) {
        std::vector<double> predicted_output(this->nets[0]->get_n_outputs());
        this->nets[0]->eval_single(*input,
                                   predicted_output,
                                   weights);
        double result = 0;
        double val;

        for (size_t i = 0; i < output->size(); i++) {
            val = output->at(i) - predicted_output.at(i);
            result += val * val;
        }

        return result;
    }

    /**
     * @brief Evaluate the network on @p data_set and optionally log per-entry results.
     *
     * The network output of every entry is computed with @p weights. The squared
     * differences are taken between the denormalised predicted and expected outputs
     * (DataSet::get_denormalized_value) and summed over all entries. The sum is
     * divided by the number of entries when `rescale_error` is set.
     *
     * @param data_set          Data to evaluate on.
     * @param results_file_path Optional stream (may be nullptr) receiving a per-entry table and the final value.
     * @param weights           Parameters forwarded to eval_single.
     * @param verbose           Print the header and result via COUT_DEBUG (per-entry rows only when L4N_DEBUG is defined).
     * @return The (optionally rescaled) sum of squared errors.
     */
    double MSE::eval_on_data_set(lib4neuro::DataSet* data_set,
                                 std::ofstream* results_file_path,
                                 std::vector<double>* weights,
                                 bool verbose
    ) {
        size_t dim_in  = data_set->get_input_dim();
        size_t dim_out = data_set->get_output_dim();
        double error   = 0.0;

        std::vector<std::pair<std::vector<double>, std::vector<double>>>* data = data_set->get_data();
        size_t n_elements = data->size();

        //TODO instead use something smarter
        std::vector<std::vector<double>> outputs(n_elements);
        std::vector<double>              output(dim_out);

        if (verbose) {
            COUT_DEBUG("Evaluation of the error function MSE on the given data-set" << std::endl);
            COUT_DEBUG(R_ALIGN << "[Element index]" << " "
                               << R_ALIGN << "[Input]" << " "
                               << R_ALIGN << "[Real output]" << " "
                               << R_ALIGN << "[Predicted output]" << " "
                               << R_ALIGN << "[Absolute error]" << " "
                               << R_ALIGN << "[Relative error %]"
                               << std::endl);
        }

        if (results_file_path) {
            *results_file_path << R_ALIGN << "[Element index]" << " "
                               << R_ALIGN << "[Input]" << " "
                               << R_ALIGN << "[Real output]" << " "
                               << R_ALIGN << "[Predicted output]" << " "
                               << R_ALIGN << "[Abs. error]" << " "
                               << R_ALIGN << "[Rel. error %]"
                               << std::endl;
        }

        for (size_t i = 0; i < n_elements; i++) {  // Iterate through every element in the test set
            /* Compute the net output and store it into 'output' variable */
            this->nets[0]->eval_single(data->at(i).first,
                                       output,
                                       weights);
            outputs.at(i) = output;
        }

        for (size_t i = 0; i < n_elements; i++) {

            /* Comma-separated denormalised input; the separator restarts for every entry */
            std::stringstream ss_input;
            std::string       separator = "";
            for (size_t j = 0; j < dim_in; j++) {
                double denormalized_real_input = data_set->get_denormalized_value(data->at(i).first.at(j));
                ss_input << separator << denormalized_real_input;
                separator = ",";
            }

            std::stringstream ss_real_output;
            std::stringstream ss_predicted_output;

            double loc_error   = 0;
            double output_norm = 0;
            separator = "";
            /* Compute difference for every element of the output vector */
            for (size_t j = 0; j < dim_out; ++j) {
                double denormalized_real_output = data_set->get_denormalized_value(data->at(i).second.at(j));
                double denormalized_output      = data_set->get_denormalized_value(outputs.at(i).at(j));

                ss_real_output << separator << denormalized_real_output;
                ss_predicted_output << separator << denormalized_output;
                separator = ",";

                double val = denormalized_output - denormalized_real_output;
                loc_error += val * val;

                output_norm += denormalized_output * denormalized_output;
            }
            /* Add the entry's error once, after all of its components are summed
             * (adding inside the loop over-counted the partial sums). */
            error += loc_error;

            std::stringstream ss_ind;
            ss_ind << "[" << i << "]";
#ifdef L4N_DEBUG

            if (verbose) {
                COUT_DEBUG(R_ALIGN << ss_ind.str() << " "
                                   << R_ALIGN << ss_input.str() << " "
                                   << R_ALIGN << ss_real_output.str() << " "
                                   << R_ALIGN << ss_predicted_output.str() << " "
                                   << R_ALIGN << loc_error << " "
                                   << R_ALIGN
                                   << 200.0 * loc_error / (loc_error + output_norm)
                                   << std::endl);
            }

#endif
            if (results_file_path) {
                *results_file_path << R_ALIGN << ss_ind.str() << " "
                                   << R_ALIGN << ss_input.str() << " "
                                   << R_ALIGN << ss_real_output.str() << " "
                                   << R_ALIGN << ss_predicted_output.str() << " "
                                   << R_ALIGN << loc_error << " "
                                   << R_ALIGN
                                   << 200.0 * loc_error / (loc_error + output_norm)
                                   << std::endl;
            }
        }

        double result = error / (this->rescale_error ? n_elements : 1.0);

        if (verbose) {
            COUT_DEBUG("MSE = " << result << std::endl);
        }

        if (results_file_path) {
            *results_file_path << "MSE = " << result << std::endl;
        }

        return result;
    }

    /**
     * @brief Evaluate on @p data_set and write the per-entry table to a file.
     * @param results_file_path Path of the file to create/overwrite.
     * @return The error value computed by the stream overload.
     * @throws Runtime error (THROW_RUNTIME_ERROR) when the file cannot be opened.
     */
    double MSE::eval_on_data_set(DataSet* data_set,
                                 std::string results_file_path,
                                 std::vector<double>* weights,
                                 bool verbose) {
        std::ofstream ofs(results_file_path);
        if (ofs.is_open()) {
            /* Keep the result so the stream is closed before returning
             * (a close() placed after the return was unreachable). */
            const double result = this->eval_on_data_set(data_set,
                                                         &ofs,
                                                         weights,
                                                         verbose);
            ofs.close();
            return result;
        }

        THROW_RUNTIME_ERROR("File " + results_file_path + " couldn't be open!");

        return -1.0;
    }

    /**
     * @brief Evaluate on @p data_set without writing any results file.
     * @return Same value as the stream overload with a null stream.
     */
    double MSE::eval_on_data_set(DataSet* data_set,
                                 std::vector<double>* weights,
                                 bool verbose) {
        std::ofstream* const no_results_file = nullptr;
        return this->eval_on_data_set(data_set, no_results_file, weights, verbose);
    }

    /**
     * @brief Error on `ds`, summed over all ranks of mpi_active_comm.
     *
     * NOTE(review): @p denormalize_data is not used. With `rescale_error`, every
     * rank divides by its own entry count before the sum, so the result is a sum of
     * per-rank means — confirm this is intended.
     */
    double MSE::eval(std::vector<double>* weights,
                     bool denormalize_data,
                     bool verbose) {
        double global_error = this->eval_on_data_set(this->ds,
                                                     nullptr,
                                                     weights,
                                                     verbose);

        MPI_Allreduce(MPI_IN_PLACE,
                      &global_error,
                      1,
                      MPI_DOUBLE,
                      MPI_SUM,
                      lib4neuro::mpi_active_comm);
        return global_error;
    }

    double MSE::eval_on_test_data(std::vector<double>* weights,
                                  bool verbose) {
        return this->eval_on_data_set(this->ds_test,
                                      weights,
                                      verbose);
    }

    double MSE::eval_on_test_data(std::string results_file_path,
                                  std::vector<double>* weights,
                                  bool verbose) {
        return this->eval_on_data_set(this->ds_test,
                                      results_file_path,
                                      weights,
                                      verbose);
    }

    double MSE::eval_on_test_data(std::ofstream* results_file_path,
                                  std::vector<double>* weights,
                                  bool verbose) {
        return this->eval_on_data_set(this->ds_test,
                                      results_file_path,
                                      weights,

                                      verbose);
    }

    /**
     * @brief Gradient of the error with respect to the network parameters.
     *
     * @p grad is zero-filled. Then, for every entry, 2 * (prediction - expected) is
     * back-propagated via add_to_gradient_single with scale @p alpha (divided by the
     * number of processed entries when `rescale_error` is set). Finally @p grad is
     * summed over mpi_active_comm.
     *
     * @param params Parameter vector used for evaluation.
     * @param grad   [out] Gradient; its size must match the parameter space.
     * @param alpha  Scaling factor.
     * @param batch  When > 0, a random batch of this size is used instead of all of `ds`.
     */
    void
    MSE::calculate_error_gradient(std::vector<double>& params,
                                  std::vector<double>& grad,
                                  double alpha,
                                  size_t batch) {

        size_t dim_out    = this->ds->get_output_dim();
        size_t n_elements = this->ds->get_n_elements();
        std::vector<std::pair<std::vector<double>, std::vector<double>>>* data = this->ds->get_data();

        std::fill(grad.begin(), grad.end(), 0.0);

        /* The random batch goes into a local buffer. Assigning it through `*data`
         * overwrote the data set's own storage. */
        std::vector<std::pair<std::vector<double>, std::vector<double>>> batch_data;
        if (batch > 0) {
            batch_data = this->ds->get_random_data_batch(batch);
            data       = &batch_data;
            n_elements = data->size();
        }

        const double scale = alpha / (this->rescale_error ? n_elements : 1.0);
        std::vector<double> error_derivative(dim_out);

        for (auto& el : *data) {  // Iterate through every element in the set
            /* Compute the net output and store it into 'error_derivative' */
            this->nets[0]->eval_single(el.first,
                                       error_derivative,
                                       &params);

            for (size_t j = 0; j < dim_out; ++j) {
                error_derivative.at(j) = 2.0 * (error_derivative.at(j) - el.second.at(j)); //real - expected result
            }

            this->nets[0]->add_to_gradient_single(el.first,
                                                  error_derivative,
                                                  scale,
                                                  grad);
        }

        /* grad.data() is valid even for an empty vector, unlike &grad[0] */
        MPI_Allreduce(MPI_IN_PLACE,
                      grad.data(),
                      static_cast<int>(grad.size()),
                      MPI_DOUBLE,
                      MPI_SUM,
                      lib4neuro::mpi_active_comm);
    }

    /**
     * @brief Residual of a single entry: the negated squared error from eval_on_single_input.
     *
     * The vectors must be allocated before calling this function.
     */
    double MSE::calculate_single_residual(std::vector<double>* input,
                                          std::vector<double>* output,
                                          std::vector<double>* parameters) {
        //TODO maybe move to the general ErrorFunction
        //TODO check input vector sizes - they HAVE TO be allocated before calling this function
        const double squared_error = this->eval_on_single_input(input, output, parameters);
        return -squared_error;
    }

    /**
     * @brief Central-difference gradient of the single-entry residual.
     *
     * For each parameter p_i, the step is h * (1 + |p_i|), and
     * gradient[i] = (r(p + step e_i) - r(p - step e_i)) / (2 * step).
     * Entries whose step is 0 are left untouched.
     *
     * @param gradient [out] Must already have get_dimension() elements.
     * @param h        Relative step size.
     */
    void MSE::calculate_residual_gradient(std::vector<double>* input,
                                          std::vector<double>* output,
                                          std::vector<double>* gradient,
                                          double h) {
        //TODO check input vector sizes - they HAVE TO be allocated before calling this function
        const size_t        n_parameters = this->get_dimension();
        std::vector<double> params       = this->get_parameters();

        for (size_t idx = 0; idx < n_parameters; ++idx) {
            const double step  = h * (1 + std::abs(params.at(idx)));
            const double saved = params.at(idx);

            if (step != 0) {
                params.at(idx) = saved + step;
                const double f_plus = this->calculate_single_residual(input, output, &params);

                params.at(idx) = saved - step;
                const double f_minus = this->calculate_single_residual(input, output, &params);

                gradient->at(idx) = (f_plus - f_minus) / (2 * step);
            }

            /* Put the parameter back before moving on */
            params.at(idx) = saved;
        }
    }

    /**
     * @brief Back-propagate a given error vector into a zero-initialised gradient.
     *
     * An empty input vector is passed to add_to_gradient_single; presumably the
     * network reuses state from its last evaluation — TODO confirm.
     */
    void MSE::calculate_error_gradient_single(std::vector<double>& error_vector,
                                              std::vector<double>& gradient_vector) {
        for (double& g : gradient_vector) {
            g = 0;
        }
        std::vector<double> no_input;
        this->nets[0]->add_to_gradient_single(no_input, error_vector, 1.0, gradient_vector);
    }

    /**
     * @brief Debug helper that prints the per-entry gradient contributions to std::cout.
     *
     * The network weights and biases are written first. Then, for every entry, the
     * input, desired output, network output, error derivative and per-entry
     * gradient are printed, followed by the total gradient.
     *
     * @param grad  [out] On return, holds the gradient of the LAST processed entry.
     * @param alpha NOTE(review): not used; every entry is back-propagated with scale 1.0.
     * @param batch When > 0, a random batch of this size is analysed instead of all of `ds`.
     */
    void
    MSE::analyze_error_gradient(std::vector<double>& params,
                                std::vector<double>& grad,
                                double alpha,
                                size_t batch) {

        size_t dim_out = this->ds->get_output_dim();
        std::vector<std::pair<std::vector<double>, std::vector<double>>>* data = this->ds->get_data();

        /* The random batch goes into a local buffer; assigning it through `*data`
         * overwrote the data set's own storage. */
        std::vector<std::pair<std::vector<double>, std::vector<double>>> batch_data;
        if (batch > 0) {
            batch_data = this->ds->get_random_data_batch(batch);
            data       = &batch_data;
        }
        std::vector<double> error_derivative(dim_out);

        std::vector<double> grad_sum(grad.size(), 0.0);
        this->nets[0]->write_weights();
        this->nets[0]->write_biases();
        for (auto& el : *data) {  // Iterate through every element in the set

            /* Compute the net output and store it into 'error_derivative' */
            this->nets[0]->eval_single_debug(el.first,
                                             error_derivative,
                                             &params);
            std::cout << "Input[";
            for (auto v: el.first) {
                std::cout << v << ", ";
            }
            std::cout << "]";

            std::cout << " Desired Output[";
            for (auto v: el.second) {
                std::cout << v << ", ";
            }
            std::cout << "]";

            std::cout << " Real Output[";
            for (auto v: error_derivative) {
                std::cout << v << ", ";
            }
            std::cout << "]";

            for (size_t j = 0; j < dim_out; ++j) {
                error_derivative.at(j) = 2.0 * (error_derivative.at(j) - el.second.at(j)); //real - expected result
            }
            std::cout << " Error derivative[";
            for (auto v: error_derivative) {
                std::cout << v << ", ";
            }
            std::cout << "]";

            std::fill(grad.begin(),
                      grad.end(),
                      0.0);
            this->nets[0]->add_to_gradient_single_debug(el.first,
                                                        error_derivative,
                                                        1.0,
                                                        grad);
            for (size_t i = 0; i < grad.size(); ++i) {
                grad_sum.at(i) += grad.at(i);
            }

            std::cout << " Gradient[";
            for (auto v: grad) {
                std::cout << v << ", ";
            }
            std::cout << "]";

            std::cout << std::endl;
        }
        std::cout << " Total gradient[";
        for (auto v: grad_sum) {
            std::cout << v << ", ";
        }
        std::cout << "]" << std::endl << std::endl;
    }

    /**
     * @brief Squared error of entry @p i of `ds`, together with its error derivative.
     *
     * @param i                Index into `ds`.
     * @param parameter_vector Parameters forwarded to eval_single (nullptr is passed by some callers).
     * @param error_vector     [in,out] Receives the network output, which is then replaced
     *                         by 2 * (prediction - expected).
     * @return sum_j (prediction_j - expected_j)^2.
     */
    double MSE::eval_single_item_by_idx(size_t i,
                                        std::vector<double>* parameter_vector,
                                        std::vector<double>& error_vector) {
        auto& entry = this->ds->get_data()->at(i);

        this->nets[0]->eval_single(entry.first, error_vector, parameter_vector);

        const size_t n_out  = error_vector.size();
        double       sq_sum = 0;
        for (size_t j = 0; j < n_out; ++j) {
            const double diff = error_vector.at(j) - entry.second.at(j);
            sq_sum += diff * diff;
        }

        /* Overwrite the prediction with the derivative of the squared error */
        for (size_t j = 0; j < n_out; ++j) {
            error_vector.at(j) = 2.0 * (error_vector.at(j) - entry.second.at(j)); //real - expected result
        }

        return sq_sum;
    }


    std::vector<double> MSE::get_parameters() {
        std::vector<double> output(this->get_dimension());
        for (size_t         i = 0; i < this->nets[0]->get_n_weights(); ++i) {
            output[i] = this->nets[0]->get_parameter_ptr_weights()->at(i);
        }
        for (size_t         i = 0; i < this->nets[0]->get_n_biases(); ++i) {
            output[i + this->nets[0]->get_n_weights()] = this->nets[0]->get_parameter_ptr_biases()->at(i);
        }
        return output;
    }

    /**
     * @brief Copy @p params into the parameter space of the network.
     */
    void MSE::set_parameters(std::vector<double>& params) {
        auto* net = this->nets[0];
        net->copy_parameter_space(&params);
    }

    size_t MSE::get_n_data_set() {
        return this->ds->get_n_elements();
    }

    /**
     * @brief Number of entries in the test split `ds_test`.
     * @return 0 when no test split exists. In this file `ds_test` is only assigned
     *         by divide_data_train_test(); before that call the old code
     *         dereferenced a null pointer.
     */
    size_t MSE::get_n_test_data_set() {
        if (this->ds_test == nullptr) {
            return 0;
        }
        return this->ds_test->get_n_elements();
    }

    /// @brief Number of outputs of the evaluated network.
    size_t MSE::get_n_outputs() {
        auto* net = this->nets[0];
        return net->get_n_outputs();
    }

    /**
     * @brief Randomize the network parameters and then scale them by @p scaling.
     */
    void MSE::randomize_parameters(double scaling) {
        auto* net = this->nets[0];
        net->randomize_parameters();
        net->scale_parameters(scaling);
    }

    /**
     * @brief Empty sum: no summand list is allocated yet and the dimension is 0.
     */
    ErrorSum::ErrorSum() {
        this->dimension = 0;
        this->summand   = nullptr;
    }

    /**
     * @brief Deletes every stored summand and the summand list itself.
     *
     * ErrorSum therefore owns the functions passed to add_error_function().
     * NOTE(review): adding the same pointer twice, or copying an ErrorSum,
     * would lead to a double delete.
     */
    ErrorSum::~ErrorSum() {
        if (this->summand == nullptr) {
            return;
        }

        for (ErrorFunction* f : *this->summand) {
            if (f != nullptr) {
                delete f;
            }
        }
        delete this->summand;
    }

    double ErrorSum::eval_on_test_data(std::vector<double>* weights,
                                       bool verbose) {
        //TODO take care of the case, when there are no test data

        double output = 0.0;
        ErrorFunction* ef = nullptr;

        for (unsigned int i = 0; i < this->summand->size(); ++i) {
            ef = this->summand->at(i);

            if (ef) {
                output += ef->eval_on_test_data(weights) * this->summand_coefficient.at(i);
            }
        }

        MPI_Allreduce( MPI_IN_PLACE, &output, 1, MPI_DOUBLE, MPI_SUM, lib4neuro::mpi_active_comm );

        return output;
    }

    /// @brief Not implemented for ErrorSum; reports via THROW_NOT_IMPLEMENTED_ERROR().
    double ErrorSum::eval_on_test_data(std::string results_file_path,
                                       std::vector<double>* weights,
                                       bool verbose) {
        THROW_NOT_IMPLEMENTED_ERROR();
        return -1.0;  // only reached if the macro does not throw
    }

    /// @brief Not implemented for ErrorSum; reports via THROW_NOT_IMPLEMENTED_ERROR().
    double ErrorSum::eval_on_test_data(std::ofstream* results_file_path,
                                       std::vector<double>* weights,
                                       bool verbose) {
        THROW_NOT_IMPLEMENTED_ERROR();
        return -1.0;  // only reached if the macro does not throw
    }

    /// @brief Not implemented for ErrorSum; reports via THROW_NOT_IMPLEMENTED_ERROR().
    double ErrorSum::eval_on_data_set(lib4neuro::DataSet* data_set,
                                      std::vector<double>* weights,
                                      bool verbose) {
        THROW_NOT_IMPLEMENTED_ERROR();
        return -1.0;  // only reached if the macro does not throw
    }

    /// @brief Not implemented for ErrorSum; reports via THROW_NOT_IMPLEMENTED_ERROR().
    double ErrorSum::eval_on_data_set(lib4neuro::DataSet* data_set,
                                      std::string results_file_path,
                                      std::vector<double>* weights,
                                      bool verbose) {
        THROW_NOT_IMPLEMENTED_ERROR();
        return -1.0;  // only reached if the macro does not throw
    }

    /// @brief Not implemented for ErrorSum; reports via THROW_NOT_IMPLEMENTED_ERROR().
    double ErrorSum::eval_on_data_set(lib4neuro::DataSet* data_set,
                                      std::ofstream* results_file_path,
                                      std::vector<double>* weights,
                                      bool verbose) {
        THROW_NOT_IMPLEMENTED_ERROR();
        return -1.0;  // only reached if the macro does not throw
    }

    double ErrorSum::eval(std::vector<double>* weights,
                          bool denormalize_data,
                          bool verbose) {
        double output = 0.0;
        ErrorFunction* ef = nullptr;

        for (unsigned int i = 0; i < this->summand->size(); ++i) {
            ef = this->summand->at(i);

            if (ef) {
                output += ef->eval(weights) * this->summand_coefficient.at(i);
            }
        }


        return output;
    }

    /**
     * @brief Weighted sum over summands of the error of entry @p i, with its combined derivative.
     *
     * @param i                Entry index, forwarded to every summand.
     * @param parameter_vector Parameters forwarded to every summand.
     * @param error_vector     [out] Zero-filled, then accumulates coefficient_j * derivative_j.
     * @return sum_j coefficient_j * F_j.eval_single_item_by_idx(i).
     */
    double ErrorSum::eval_single_item_by_idx(size_t i,
                                             std::vector<double>* parameter_vector,
                                             std::vector<double>& error_vector) {
        double output = 0.0;
        std::fill(error_vector.begin(),
                  error_vector.end(),
                  0);

        if (!this->summand) {
            return output;
        }

        std::vector<double> error_vector_mem(error_vector.size());
        for (size_t j = 0; j < this->summand->size(); ++j) {
            /* Summand j, not i: `i` is the data-entry index */
            ErrorFunction* ef = this->summand->at(j);

            if (ef) {
                const double coeff = this->summand_coefficient.at(j);
                output += ef->eval_single_item_by_idx(i,
                                                      parameter_vector,
                                                      error_vector_mem) * coeff;

                /* Guard against a summand that resized the scratch vector */
                const size_t n = std::min(error_vector.size(), error_vector_mem.size());
                for (size_t k = 0; k < n; ++k) {
                    error_vector[k] += error_vector_mem[k] * coeff;
                }
            }
        }

        return output;
    }

    /**
     * @brief Gradient of the weighted sum: grad = sum_i coefficient_i * alpha * grad(F_i).
     *
     * Each summand writes into its own zeroed buffer, which is then added to
     * @p grad. MSE::calculate_error_gradient zero-fills the vector it receives, so
     * sharing @p grad between summands kept only the last summand's gradient.
     *
     * @param grad [out] Zero-filled, then accumulates every summand's gradient.
     */
    void ErrorSum::calculate_error_gradient(std::vector<double>& params,
                                            std::vector<double>& grad,
                                            double alpha,
                                            size_t batch) {
        std::fill(grad.begin(), grad.end(), 0.0);

        if (!this->summand) {
            return;
        }

        std::vector<double> summand_grad(grad.size());
        for (size_t i = 0; i < this->summand->size(); ++i) {
            ErrorFunction* ef = this->summand->at(i);

            if (ef) {
                std::fill(summand_grad.begin(), summand_grad.end(), 0.0);
                ef->calculate_error_gradient(params,
                                             summand_grad,
                                             this->summand_coefficient.at(i) * alpha,
                                             batch);

                const size_t n = std::min(grad.size(), summand_grad.size());
                for (size_t k = 0; k < n; ++k) {
                    grad[k] += summand_grad[k];
                }
            }
        }
    }

    /**
     * @brief Not implemented for ErrorSum: only prints an informational message.
     * @p gradient_vector is left untouched.
     */
    void ErrorSum::calculate_error_gradient_single(std::vector<double>& error_vector,
                                                   std::vector<double>& gradient_vector) {
        (void) error_vector;
        (void) gradient_vector;
        COUT_INFO("ErrorSum::calculate_error_gradient_single NOT YET IMPLEMENTED!!!");
    }

    /**
     * @brief Run the gradient analysis of every non-null summand.
     *
     * Forwards to each summand's analyze_error_gradient with the summand's
     * coefficient times @p alpha. The previous code called
     * calculate_error_gradient here, so no analysis was produced.
     */
    void ErrorSum::analyze_error_gradient(std::vector<double>& params,
                                          std::vector<double>& grad,
                                          double alpha,
                                          size_t batch) {
        if (!this->summand) {
            return;
        }

        for (size_t i = 0; i < this->summand->size(); ++i) {
            ErrorFunction* ef = this->summand->at(i);

            if (ef) {
                ef->analyze_error_gradient(params,
                                           grad,
                                           this->summand_coefficient.at(i) * alpha,
                                           batch);
            }
        }
    }

    /**
     * @brief Append summand @p F with weight @p alpha.
     *
     * The summand list is allocated on first use. ErrorSum takes ownership of @p F
     * (the destructor deletes it). A null @p F is stored but ignored for the
     * dimension. The sum's dimension is the largest summand dimension.
     */
    void ErrorSum::add_error_function(ErrorFunction* F,
                                      double alpha) {
        if (this->summand == nullptr) {
            this->summand = new std::vector<ErrorFunction*>();
        }

        this->summand->push_back(F);
        this->summand_coefficient.push_back(alpha);

        if (F != nullptr && F->get_dimension() > this->dimension) {
            this->dimension = F->get_dimension();
        }
    }

    /// @brief Largest dimension among the added summands (0 for an empty sum).
    size_t ErrorSum::get_dimension() {
        const size_t dim = this->dimension;
        return dim;
    }

    std::vector<double> ErrorSum::get_parameters() {
        return this->summand->at(0)->get_parameters();
    }

    void ErrorSum::set_parameters(std::vector<double>& params) {
        //TODO may cause problems for general error sum...
        for (auto n: *this->summand) {
            n->set_parameters(params);
        }
    }


    /// @brief Not implemented for ErrorSum; reports via THROW_NOT_IMPLEMENTED_ERROR().
    void ErrorSum::calculate_residual_gradient(std::vector<double>* input,
                                               std::vector<double>* output,
                                               std::vector<double>* gradient,
                                               double h) {
        (void) input;
        (void) output;
        (void) gradient;
        (void) h;
        THROW_NOT_IMPLEMENTED_ERROR();
    }

    /// @brief Not implemented for ErrorSum; reports via THROW_NOT_IMPLEMENTED_ERROR().
    double ErrorSum::calculate_single_residual(std::vector<double>* input,
                                               std::vector<double>* output,
                                               std::vector<double>* parameters) {
        THROW_NOT_IMPLEMENTED_ERROR();
        return 0.0;  // only reached if the macro does not throw
    }

    double ErrorSum::eval_on_single_input(std::vector<double>* input,
                                          std::vector<double>* output,
                                          std::vector<double>* weights) {
        double o = 0.0;

        for (size_t i = 0; i < this->summand->size(); ++i) {
            o += this->summand->at(i)->eval_on_single_input(input,
                                                            output,
                                                            weights) * this->summand_coefficient.at(i);
        }

        return o;
    }

    size_t ErrorSum::get_n_data_set() {
        size_t o = 0;

        for (size_t i = 0; i < this->summand->size(); ++i) {
            o += this->summand->at(i)->get_n_data_set();
        }

        //TODO how is this function being used? should this be across all MPI ranks?

        return o;
    }

    size_t ErrorSum::get_n_test_data_set() {
        size_t o = 0;

        for (size_t i = 0; i < this->summand->size(); ++i) {
            o += this->summand->at(i)->get_n_test_data_set();
        }

        return o;
    }

    size_t ErrorSum::get_n_outputs() {
        size_t o = 0;

        for (size_t i = 0; i < this->summand->size(); ++i) {
            o += this->summand->at(i)->get_n_outputs();
        }

        return o;
    }

    void ErrorSum::divide_data_train_test(double percent) {
        for (auto n: *this->summand) {
            n->divide_data_train_test(percent);
        }
    }

    /**
     * @brief Not implemented for ErrorSum.
     *
     * Reports via THROW_NOT_IMPLEMENTED_ERROR(), like the other unimplemented
     * ErrorSum methods. The former `assert(false)` compiled away under NDEBUG and
     * silently returned 0.
     */
    size_t ErrorSum::divide_data_worst_subset(
            std::vector<size_t> &subset_indices,
            std::vector<int> &active_subset,
            std::vector<float> &entry_errors,
            size_t expansion_len,
            float tolerance
    ) {
        THROW_NOT_IMPLEMENTED_ERROR();
        return 0;  // only reached if the macro does not throw
    }



    void ErrorSum::return_full_data_set_for_training() {
        for (auto n: *this->summand) {
            n->return_full_data_set_for_training();
        }
    }

    void ErrorSum::get_jacobian_and_rhs(std::vector<std::vector<double>>& jacobian,
                                        std::vector<double>& rhs) {
        for (auto n: *this->summand) {
            std::vector<double> rhs_loc;
            n->get_jacobian_and_rhs(jacobian,
                                    rhs_loc);

            size_t curr_size = rhs.size();
            rhs.resize(curr_size + rhs_loc.size());
            for (size_t i = 0; i < rhs_loc.size(); ++i) {
                rhs.at(i + curr_size) = rhs_loc.at(i);
            }
        }
    }

    /// @brief Not implemented for ErrorSum; reports via THROW_NOT_IMPLEMENTED_ERROR().
    void ErrorSum::get_jacobian_and_rhs(std::vector<std::vector<double>>& jacobian,
                                        std::vector<double>& rhs,
                                        std::vector<std::pair<std::vector<double>, std::vector<double>>>& data) {
        (void) jacobian;
        (void) rhs;
        (void) data;
        THROW_NOT_IMPLEMENTED_ERROR();
    }

    void ErrorSum::randomize_parameters(double scaling) {
        for (auto n: *this->summand) {
            n->randomize_parameters(scaling);
        }
    }

}