Skip to content
Snippets Groups Projects
ErrorFunctions.cpp 1.95 KiB
Newer Older
  • Learn to ignore specific revisions
  • //
    // Created by martin on 7/15/18.
    //
    
    
    #include <vector>
    #include <utility>
    
    
    #include "ErrorFunctions.h"
    
    
    size_t ErrorFunction::get_dimension() {
        return this->dimension;
    }
    
    /**
     * Builds a mean-squared-error function over a network and a data set.
     *
     * Both pointers are stored as given (no copy is made here), and the
     * dimension is taken from the network's weight count.
     *
     * @param net Network whose outputs are compared against the data set.
     * @param ds  Data set providing (input, expected output) pairs.
     */
    MSE::MSE(NeuralNetwork *net, DataSet *ds) {
        this->dimension = net->get_n_weights();
        this->ds = ds;
        this->net = net;
    }
    
    /**
     * Evaluates the mean squared error of the network over the stored data set.
     *
     * The supplied weights are first copied into the network. Then every
     * (input, expected output) pair of the data set is evaluated, the squared
     * differences of all output components are accumulated, and the sum is
     * divided by the number of data elements.
     *
     * @param weights Weight values passed on to NeuralNetwork::copy_weights.
     * @return Accumulated squared error divided by the number of elements.
     *         NOTE(review): for an empty data set this is 0.0 / 0, which is
     *         NaN under IEEE arithmetic - confirm that callers expect this.
     */
    double MSE::eval(double *weights) {
        const unsigned int dim_out = this->ds->get_output_dim();
        const size_t n_elements = this->ds->get_n_elements();

        std::vector<std::pair<std::vector<double>, std::vector<double>>>* data = this->ds->get_data();

        this->net->copy_weights(weights);

        std::vector<double> output( dim_out );
        double error = 0.0;

        // The index has the same type as n_elements, so the comparison cannot
        // wrap around for data sets larger than the range of unsigned int.
        for(size_t i = 0; i < n_elements; ++i) {  // Iterate through every element in the test set

            // Bounds-checked lookup done once per element instead of once per use
            auto &sample = data->at(i);

            // Compute the net output and store it into 'output'
            this->net->eval_single(sample.first, output);

            const std::vector<double> &expected = sample.second;

            for(unsigned int j = 0; j < dim_out; ++j) {  // Squared difference for every output component
                const double diff = output[j] - expected[j];
                error += diff * diff;
            }
        }

        return error / n_elements;
    }
    
    /**
     * Creates an empty sum of error functions.
     *
     * The summand container is not allocated yet; it is created on the first
     * call to add_error_function().
     */
    MSE_SUM::MSE_SUM() {
        this->summand = nullptr;

        // get_dimension() treats a zero dimension as "not computed yet", so
        // start from zero explicitly instead of depending on an initialization
        // that is not visible in this file.
        this->dimension = 0;
    }
    
    /**
     * Releases the summand container.
     *
     * Only the vector itself is deleted; the ErrorFunction objects it points
     * to are not deleted here.
     */
    MSE_SUM::~MSE_SUM(){
        // Deleting a null pointer is a no-op, so no explicit check is needed.
        delete this->summand;
    }
    
    double MSE_SUM::eval(double *weights) {
        double output = 0.0;
    
        for(ErrorFunction *f: *this->summand){
            output += f->eval( weights );
        }
    
        return output;
    }
    
    /**
     * Registers another error function as a summand.
     *
     * The pointer is stored as given; no copy of the error function is made.
     *
     * @param F Error function to be added to the sum.
     */
    void MSE_SUM::add_error_function(ErrorFunction *F) {
        // The summand container is created lazily on the first registration.
        if(this->summand == nullptr) {
            this->summand = new std::vector<ErrorFunction*>();
        }
        this->summand->push_back(F);

        // get_dimension() caches the maximum summand dimension and recomputes
        // it only while the cached value is zero; reset it so a summand added
        // after the first query is taken into account.
        this->dimension = 0;
    }
    
    size_t MSE_SUM::get_dimension() {
        if(!this->dimension) {
            size_t max = 0;
            for(auto e : *this->summand) {
                if(e->get_dimension() > max) {
                    max = e->get_dimension();
                }
            };
    
            this->dimension = max;
        }
        return this->dimension;