ErrorFunctions.cpp 1.66 KB
Newer Older
1 2 3 4
//
// Created by martin on 7/15/18.
//

Martin Beseda's avatar
Martin Beseda committed
5 6 7
#include <vector>
#include <utility>

8
#include "ErrorFunctions.h"
Martin Beseda's avatar
Martin Beseda committed
9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43

size_t ErrorFunction::get_dimension() {
    return this->dimension;
}

/**
 * Creates a mean-squared-error function for a network / data-set pair.
 *
 * Both pointers are stored as given (this constructor does not copy the
 * pointed-to objects). The dimension is taken from the network's weight count.
 *
 * @param net Network whose output is compared against the data set.
 * @param ds  Data set of (input, expected output) pairs.
 */
MSE::MSE(NeuralNetwork *net, DataSet *ds) {
    // Parameters shadow the members, hence the explicit this-> qualification.
    this->dimension = net->get_n_weights();
    this->ds = ds;
    this->net = net;
}

/**
 * Evaluates the mean squared error of the network over the whole data set.
 *
 * The weights are first copied into the network. Then every
 * (input, expected output) pair is propagated through the net, and the
 * squared differences of all output components are accumulated.
 *
 * @param weights Weight array handed to the network's copy_weights()
 *                (presumably of length get_dimension() - TODO confirm).
 * @return Accumulated squared error divided by the number of data-set
 *         elements. For an empty data set this is 0.0 / 0 (NaN).
 */
double MSE::eval(double *weights) {
    size_t dim_out = this->ds->get_output_dim();
    size_t n_elements = this->ds->get_n_elements();
    double error = 0.0;

    std::vector<std::pair<std::vector<double>, std::vector<double>>>* data = this->ds->get_data();

    this->net->copy_weights(weights);

    // Single output buffer reused for every sample.
    std::vector<double> output(dim_out);

    for (size_t i = 0; i < n_elements; ++i) {  // Iterate through every element of the data set
        // Look the sample up once instead of once per output component.
        auto &sample = data->at(i);
        std::vector<double> &expected = sample.second;

        // Compute the net output for this input and store it into 'output'.
        this->net->eval_single(sample.first, output);

        for (size_t j = 0; j < dim_out; ++j) {  // Squared difference per output component
            const double diff = output[j] - expected[j];
            error += diff * diff;
        }
    }

    return error / n_elements;
}

/**
 * Creates an empty sum of error functions.
 *
 * The summand list is not allocated here; add_error_function() allocates it
 * on first use.
 */
MSE_SUM::MSE_SUM() {
    summand = nullptr;
}

/**
 * Frees the lazily allocated summand list.
 *
 * Only the container is released; the ErrorFunction objects it points to are
 * not deleted here.
 */
MSE_SUM::~MSE_SUM() {
    // Deleting a null pointer is a no-op, so no explicit null check is needed.
    delete this->summand;
}

double MSE_SUM::eval(double *weights) {
    double output = 0.0;

    for(ErrorFunction *f: *this->summand){
        output += f->eval( weights );
    }

    return output;
}

/**
 * Registers an error function as an additional summand.
 *
 * The summand list is allocated lazily on the first call. Only the pointer is
 * stored; the function object itself is not copied.
 *
 * @param F Error function to add (stored as-is; a null pointer is not rejected).
 */
void MSE_SUM::add_error_function(ErrorFunction *F) {
    if (!this->summand) {
        this->summand = new std::vector<ErrorFunction*>(0);
    }
    this->summand->push_back(F);
}