Skip to content
Snippets Groups Projects
ErrorFunctions.h 1002 B
Newer Older
  • Learn to ignore specific revisions
    //
    // Created by martin on 7/15/18.
    //
    
    #ifndef INC_4NEURO_ERRORFUNCTION_H
    #define INC_4NEURO_ERRORFUNCTION_H
    
    
    #include "../Network/NeuralNetwork.h"
    #include "../DataSet/DataSet.h"
    
    
    /**
     * Abstract interface for an error function evaluated on an array of
     * network weights.
     *
     * Concrete error functions (e.g. MSE) derive from this class and are
     * meant to be used through ErrorFunction pointers/references, because
     * eval() is pure virtual. For that reason the destructor is virtual:
     * deleting a derived object through an ErrorFunction* would otherwise be
     * undefined behaviour.
     */
    class ErrorFunction {

    public:

        /**
         * Initializes `dimension` to 0 so it never holds an indeterminate
         * value. NOTE(review): derived classes presumably overwrite it in
         * their own constructors — confirm in the implementations.
         */
        ErrorFunction() : dimension(0) {}

        /**
         * Virtual so that destroying a derived error function through a
         * base-class pointer runs the derived destructor.
         */
        virtual ~ErrorFunction() {}

        /**
         * Evaluates the error for the given weights.
         * @param weights Pointer to the weight values. NOTE(review): the
         *                number of values read is presumably get_dimension()
         *                — confirm against the implementations.
         * @return The error value computed by the concrete error function.
         */
        virtual double eval(double* weights) = 0;

        /**
         * Accessor for the dimension of this error function. Its definition
         * is out of line (not in this header). NOTE(review): presumably it
         * returns `dimension` — confirm.
         * @return The dimension as a size_t.
         */
        virtual size_t get_dimension();

    protected:

        /**
         * Dimension of the error function. Zero until a derived class sets
         * it. NOTE(review): presumably the number of weights eval() expects.
         */
        size_t dimension;
    };
    
    /**
     * Error function bound to a single neural network and a single data set.
     * NOTE(review): by its name this is presumably the mean squared error;
     * the actual formula lives in the out-of-line definition of eval().
     *
     * The network and data set are held as raw pointers. No destructor is
     * declared, so destroying an MSE never frees them; their lifetime is
     * the caller's responsibility.
     */
    class MSE : public ErrorFunction {

    private:

        // Network whose error is measured (not freed by MSE).
        NeuralNetwork* net;

        // Data set the error is measured on (not freed by MSE).
        DataSet* ds;

    public:

        /**
         * Binds the error function to one network and one data set.
         * @param net Network to evaluate; stored as-is in `net`.
         * @param ds  Data set to evaluate against; stored as-is in `ds`.
         */
        MSE(NeuralNetwork* net, DataSet* ds);

        // Not implemented yet: a constructor taking several error functions
        // whose values would be summed.
        //MSE(std::vector<ErrorFunction> func_vec);

        /**
         * Overrides ErrorFunction::eval for this network/data-set pair.
         * @param weights Pointer to the weight values to evaluate with.
         * @return The error value; see the out-of-line definition.
         */
        virtual double eval(double* weights);
    };
    
    
    #endif //INC_4NEURO_ERRORFUNCTION_H