Skip to content
Snippets Groups Projects
ErrorFunctions.h 1002 B
Newer Older
//
// Created by martin on 7/15/18.
//

#ifndef INC_4NEURO_ERRORFUNCTION_H
#define INC_4NEURO_ERRORFUNCTION_H

#include <cstddef>

#include "../Network/NeuralNetwork.h"
#include "../DataSet/DataSet.h"

/**
 * Abstract interface for an error function evaluated over an array of
 * weights.
 *
 * Concrete subclasses implement eval(). The protected `dimension` member is
 * meant to be filled in by subclasses (presumably with the number of weights
 * eval() reads — TODO confirm against the implementations).
 */
class ErrorFunction {
public:

    /**
     * Starts `dimension` at 0 so that get_dimension() never exposes an
     * indeterminate value when a subclass does not assign it.
     */
    ErrorFunction() : dimension(0) {}

    /**
     * Virtual destructor. This class is abstract (eval() is pure virtual),
     * so objects only exist as subclasses. Deleting one through an
     * ErrorFunction* must therefore dispatch to the subclass destructor.
     * Without a virtual destructor that delete is undefined behavior.
     */
    virtual ~ErrorFunction() {}

    /**
     * Evaluates the error for the given weights.
     * @param weights Pointer to the weight values. Presumably this points to
     *                at least get_dimension() doubles
     *                (NOTE(review): confirm with callers).
     * @return The error value computed by the concrete subclass.
     */
    virtual double eval(double* weights) = 0;

    /**
     * Returns the dimension this error function works in.
     * Defined out of line; presumably it returns the `dimension` member
     * (confirm in the .cpp).
     * @return The dimension value.
     */
    virtual size_t get_dimension();

protected:

    /**
     * Dimension of the weight array accepted by eval().
     * It is 0 until a subclass assigns it, presumably in the subclass
     * constructor (TODO confirm).
     */
    size_t dimension;
};

/**
 * ErrorFunction bound to one NeuralNetwork and one DataSet.
 * By its name this is presumably the mean squared error; confirm in the .cpp.
 *
 * Only raw pointers to the network and data set are kept. The class declares
 * no destructor, so it never deletes them. The caller has to keep both
 * objects alive while this instance is in use.
 */
class MSE : public ErrorFunction {

private:

    /** Network handed to the constructor (not owned by this class). */
    NeuralNetwork *net;

    /** Data set handed to the constructor (not owned by this class). */
    DataSet *ds;

public:

    /**
     * Builds the error function for a single network.
     * @param net Network to evaluate; must outlive this object.
     * @param ds  Data set to evaluate on; must outlive this object.
     */
    MSE(NeuralNetwork *net, DataSet *ds);

    /*
     * Planned, currently disabled: a constructor that sums several error
     * functions.
     * NOTE(review): ErrorFunction is abstract (pure virtual eval), so a
     * std::vector<ErrorFunction> cannot hold it. A container of pointers
     * would be needed before this can be enabled.
     */
    //MSE(std::vector<ErrorFunction> func_vec);

    /**
     * Computes the error of `net` on `ds` using the given weights
     * (implementation lives out of line).
     * @param weights Weight values to evaluate with.
     * @return The resulting error value.
     */
    virtual double eval(double *weights);
};


#endif //INC_4NEURO_ERRORFUNCTION_H