Skip to content
Snippets Groups Projects
ErrorFunctions.h 2.62 KiB
Newer Older
  • Learn to ignore specific revisions
  • //
    // Created by martin on 7/15/18.
    //
    
    #ifndef INC_4NEURO_ERRORFUNCTION_H
    #define INC_4NEURO_ERRORFUNCTION_H
    
    
    #include <cstddef>
    #include <vector>

    #include "../Network/NeuralNetwork.h"
    #include "../DataSet/DataSet.h"
    
    /**
     * Abstract interface for an error function that can be evaluated and
     * differentiated with respect to a vector of parameters.
     *
     * Implementations are handled through ErrorFunction pointers (for example
     * ErrorSum::add_error_function takes an ErrorFunction*, and ErrorSum keeps
     * a std::vector<ErrorFunction*>), so the destructor is virtual.
     */
    class ErrorFunction {

    public:

        /**
         * Virtual destructor: destroying an implementation through an
         * ErrorFunction pointer must run the derived class's destructor.
         */
        virtual ~ErrorFunction() = default;

        /**
         * Evaluates the error.
         *
         * @param weights parameter vector to evaluate at; defaults to nullptr
         *                (NOTE(review): presumably nullptr means "use the
         *                currently stored parameters" — confirm in the .cpp)
         * @return the error value
         */
        virtual double eval(std::vector<double>* weights = nullptr) = 0;

        /**
         * @return the dimension of this error function's parameter space
         *         (NOTE(review): the base implementation is not in this
         *         header; presumably it returns `dimension` — confirm)
         */
        virtual size_t get_dimension();

        /**
         * Computes the gradient of the error with respect to the parameters.
         *
         * @param params parameter vector at which the gradient is evaluated
         * @param grad   vector that receives the gradient
         * @param alpha  scaling factor, 1.0 by default (NOTE(review):
         *               presumably the gradient is scaled by alpha and
         *               accumulated into grad — confirm)
         */
        virtual void calculate_error_gradient(std::vector<double> &params, std::vector<double> &grad, double alpha = 1.0) = 0;

        /**
         * @return pointer to the parameter vector used by this error function
         *         (NOTE(review): ownership/lifetime of the pointee is not
         *         expressed here — confirm whether callers may keep or free it)
         */
        virtual std::vector<double>* get_parameters() = 0;

        /**
         * //TODO delete after gradient learning is debugged
         * @return the data set associated with this error function
         */
        virtual DataSet* get_dataset() = 0;

        /**
         * Dimension of the parameter space (presumably what get_dimension()
         * reports); initialized to 0.
         */
        size_t dimension = 0;

    };
    
    class MSE : public ErrorFunction {
    
    public:
        /**
         * Constructor for single neural network
         * @param net
         * @param ds
         */
    
    
        /**
         * Evaluates this error function.
         *
         * @param weights parameters to evaluate at, or nullptr (the default)
         * @return the error value
         */
        double eval(std::vector<double> *weights = nullptr) override;

        /**
         * Computes the error gradient at the given parameters.
         *
         * @param params parameters at which to differentiate
         * @param grad   vector that receives the gradient
         * @param alpha  scaling factor, 1.0 by default
         */
        void calculate_error_gradient(std::vector<double> &params,
                                      std::vector<double> &grad,
                                      double alpha = 1.0) override;

        /**
         * @return pointer to the parameter vector of this error function
         */
        std::vector<double> *get_parameters() override;

        /**
         * @return the stored data set pointer `ds`, returned as-is
         */
        DataSet *get_dataset() override { return ds; }
    
    
    private:

        /**
         * Data set returned by get_dataset(). Defaults to nullptr so the
         * pointer is never indeterminate, even if a constructor leaves it unset.
         * NOTE(review): presumably non-owning (no destructor is visible in
         * this header) — confirm.
         */
        DataSet* ds = nullptr;
    
    Michal Kravcenko's avatar
    Michal Kravcenko committed
    class ErrorSum : public ErrorFunction{
    
         /**
          * Evaluates the combined error of all summands.
          *
          * NOTE(review): no access specifier appears between the class head and
          * this declaration in this listing, so as shown these members have the
          * class-default private access — confirm a `public:` is present upstream.
          *
          * @param weights parameters to evaluate at, or nullptr (the default)
          * @return the combined error value
          */
         double eval(std::vector<double> *weights = nullptr) override;

         /**
          * Adds an error function to this sum.
          *
          * @param F     error function to add (NOTE(review): ownership of F is
          *              not visible here — confirm who deletes it)
          * @param alpha presumably F's coefficient in the sum (cf.
          *              summand_coefficient); 1.0 by default
          */
         void add_error_function(ErrorFunction *F, double alpha = 1.0);
    
    Martin Beseda's avatar
    Martin Beseda committed
        /**
         * @return the dimension of the combined parameter space
         *         (NOTE(review): defined outside this header — confirm how the
         *         summands' dimensions are combined)
         */
        size_t get_dimension() override;

        /**
         * Computes the gradient of the combined error.
         *
         * @param params parameters at which to differentiate
         * @param grad   vector that receives the gradient
         * @param alpha  scaling factor, 1.0 by default
         */
        void calculate_error_gradient(std::vector<double> &params, std::vector<double> &grad, double alpha = 1.0) override;

        /**
         * @return pointer to the parameter vector
         */
        std::vector<double>* get_parameters() override;

        /**
         * Returns the data set of the first summand.
         *
         * @return the first summand's data set, or nullptr when the summand
         *         vector is missing, empty, or its first entry is null
         */
        DataSet* get_dataset() override {
            // Guard every pointer on the path instead of dereferencing blindly.
            if (this->summand == nullptr || this->summand->empty()) {
                return nullptr;
            }
            ErrorFunction* first = this->summand->front();
            return first != nullptr ? first->get_dataset() : nullptr;
        }
    
    Martin Beseda's avatar
    Martin Beseda committed
        /**
         * The error functions being summed. Defaults to nullptr so the pointer
         * is never indeterminate. NOTE(review): allocation and ownership are
         * not visible in this header — confirm against the constructor/destructor.
         */
        std::vector<ErrorFunction*>* summand = nullptr;

        /**
         * Coefficients for the summands (presumably one per entry of summand,
         * taken from add_error_function's alpha — confirm). Defaults to nullptr.
         */
        std::vector<double> *summand_coefficient = nullptr;
    
    
    #endif //INC_4NEURO_ERRORFUNCTION_H