Skip to content
Snippets Groups Projects
ErrorFunctions.h 4.76 KiB
Newer Older
//
// Created by martin on 7/15/18.
//

#ifndef INC_4NEURO_ERRORFUNCTION_H
#define INC_4NEURO_ERRORFUNCTION_H

#include "../Network/NeuralNetwork.h"
#include "../DataSet/DataSet.h"
namespace lib4neuro {

    /**
     * Tag identifying the kind of error function.
     */
    enum ErrorFunctionType {
        ErrorFuncMSE = 0  ///< mean squared error (class MSE)
    };


    /**
     * Abstract interface of an error (loss) function evaluated over a neural
     * network's parameters and a data set.
     *
     * The class is used polymorphically (e.g. ErrorSum keeps ErrorFunction*
     * summands), so it has a virtual destructor.
     */
    class ErrorFunction {
    public:

        /**
         * Virtual destructor: without it, deleting a derived error function
         * through an ErrorFunction* is undefined behavior.
         */
        virtual ~ErrorFunction() = default;

        /**
         * Evaluates the error.
         * @param weights parameters to evaluate with; the meaning of nullptr
         *                (presumably "use the current parameters") is defined
         *                by each implementation -- confirm
         * @return error value
         */
        virtual double eval(std::vector<double> *weights = nullptr) = 0;

        /**
         * @return dimension of the error function's input -- presumably the
         *         number of network parameters; confirm in the implementation
         */
        LIB4NEURO_API virtual size_t get_dimension();

        /**
         * Computes the gradient of the error with respect to params.
         * @param params parameter vector at which the gradient is taken
         * @param grad   gradient output (whether it is overwritten or
         *               accumulated into is implementation-defined -- confirm)
         * @param alpha  scaling coefficient (presumably applied to the gradient)
         * @param batch  batch size; the meaning of 0 (presumably "whole data
         *               set") is implementation-defined -- confirm
         */
        virtual void
        calculate_error_gradient(std::vector<double> &params,
                                 std::vector<double> &grad,
                                 double alpha = 1.0,
                                 size_t batch = 0) = 0;

        /**
         * @return pointer to the parameter vector (ownership is not visible
         *         from this header -- confirm in the implementation)
         */
        virtual std::vector<double> *get_parameters() = 0;

        /**
         * //TODO delete after gradient learning is debugged
         * @return the data set the error function evaluates on
         */
        virtual DataSet* get_dataset() = 0;

        /**
         * @return the network this error function evaluates (presumably the
         *         `net` member)
         */
        NeuralNetwork* get_network_instance();

        /**
         * Splits the data set into a training part and a test part.
         * @param percent_test share of the data to move to the test set
         *                     (whether this is a fraction or a percentage is
         *                     defined in the implementation -- confirm)
         */
        void divide_data_train_test(double percent_test);

        /**
         * Undoes divide_data_train_test() so the full data set is used for
         * training again (presumably restores `ds` from `ds_full` -- confirm).
         */
        void return_full_data_set_for_training();

        /**
         * Evaluates the error on the test data (presumably `ds_test`).
         * @param weights parameters to evaluate with; nullptr semantics are
         *                implementation-defined
         * @return error value on the test data
         */
        virtual double eval_on_test_data(std::vector<double>* weights = nullptr) = 0;

        /**
         * Network evaluated by this error function (ownership not visible here).
         */
        NeuralNetwork* net = nullptr;

        /**
         * Data set currently used for evaluation/training.
         */
        DataSet* ds = nullptr;

        /**
         * Full data set kept while a train/test split is active (presumably
         * set by divide_data_train_test -- confirm).
         */
        DataSet* ds_full = nullptr;

        /**
         * Test part of the data set after a train/test split.
         */
        DataSet* ds_test = nullptr;
    };

    /**
     * Mean-squared-error (MSE) error function of a single neural network
     * evaluated on a single data set.
     */
    class MSE : public ErrorFunction {
    public:
        /**
         * Creates the error function for one network and one data set.
         * @param net network being evaluated
         * @param ds  data set the error is computed on
         */
        LIB4NEURO_API MSE(NeuralNetwork* net, DataSet* ds);

        /**
         * Evaluates the error on the current data set.
         * @param weights parameters to evaluate with; the meaning of nullptr
         *                (presumably "use the network's current parameters")
         *                is defined in the implementation -- confirm
         * @return error value
         */
        LIB4NEURO_API double eval(std::vector<double>* weights = nullptr) override;

        /**
         * Computes the gradient of the error with respect to params.
         * @param params parameter vector at which the gradient is taken
         * @param grad   gradient output (overwrite vs. accumulate is not
         *               visible here -- confirm)
         * @param alpha  scaling coefficient (presumably applied to the gradient)
         * @param batch  batch size; the meaning of 0 (presumably "whole data
         *               set") is defined in the implementation -- confirm
         */
        LIB4NEURO_API void calculate_error_gradient(std::vector<double>& params,
                                                    std::vector<double>& grad,
                                                    double alpha = 1.0,
                                                    size_t batch = 0) override;

        /**
         * @return pointer to the parameter vector (ownership not visible here)
         */
        LIB4NEURO_API std::vector<double>* get_parameters() override;

        /**
         * @return the data set stored in the inherited `ds` member
         */
        LIB4NEURO_API DataSet* get_dataset() override { return ds; }

        /**
         * Evaluates the error on the test data (presumably `ds_test`).
         * @param weights parameters to evaluate with; nullptr semantics are
         *                implementation-defined
         * @return error value on the test data
         */
        LIB4NEURO_API double eval_on_test_data(std::vector<double>* weights = nullptr) override;

        /**
         * Evaluates the error on an arbitrary data set.
         * @param data_set data set to evaluate on
         * @param weights  parameters to evaluate with; nullptr semantics are
         *                 implementation-defined
         * @return error value on data_set
         */
        double eval_general(DataSet* data_set, std::vector<double>* weights = nullptr);
    };

    /**
     * Sum of several error functions, each paired with a coefficient
     * (presumably a weight applied to that summand's contribution -- confirm).
     *
     * Summands are registered with add_error_function().
     */
    class ErrorSum : public ErrorFunction {
    public:
        /**
         * Creates an empty sum.
         */
        LIB4NEURO_API ErrorSum();

        /**
         * Destructor (releases whatever the implementation owns).
         */
        LIB4NEURO_API ~ErrorSum();

        /*
         * Copying is disabled. The summands and coefficients are held through
         * raw pointers, and the class declares its own destructor. An implicit
         * (shallow) copy would therefore alias those containers with the
         * original. If ~ErrorSum() releases them (presumably; verify in the
         * implementation), destroying both objects would free them twice.
         */
        ErrorSum(const ErrorSum&) = delete;
        ErrorSum& operator=(const ErrorSum&) = delete;

        /**
         * Evaluates the combined error.
         * @param weights parameters to evaluate with; nullptr semantics are
         *                implementation-defined
         * @return combined error value
         */
        LIB4NEURO_API double eval(std::vector<double> *weights = nullptr) override;

        /**
         * Evaluates the combined error on the summands' test data.
         * @param weights parameters to evaluate with; nullptr semantics are
         *                implementation-defined
         * @return combined error value on the test data
         */
        LIB4NEURO_API double eval_on_test_data(std::vector<double> *weights = nullptr) override;

        /**
         * Adds a summand to the sum.
         * @param F     error function to add (ownership is not visible from
         *              this header -- confirm in the implementation)
         * @param alpha coefficient paired with F
         */
        LIB4NEURO_API void add_error_function(ErrorFunction *F, double alpha = 1.0);

        /**
         * @return dimension of the combined error function's input
         */
        LIB4NEURO_API size_t get_dimension() override;

        /**
         * Computes the gradient of the combined error with respect to params.
         * @param params parameter vector at which the gradient is taken
         * @param grad   gradient output
         * @param alpha  scaling coefficient (presumably applied to the gradient)
         * @param batch  batch size; the meaning of 0 is implementation-defined
         */
        LIB4NEURO_API void
        calculate_error_gradient(std::vector<double> &params,
                                 std::vector<double> &grad,
                                 double alpha = 1.0,
                                 size_t batch = 0) override;

        /**
         * @return pointer to the parameter vector (ownership not visible here)
         */
        LIB4NEURO_API std::vector<double> *get_parameters() override;

        /**
         * @return the data set of the first registered summand.
         * Uses std::vector::at, so this throws std::out_of_range if no summand
         * has been added yet.
         */
        LIB4NEURO_API DataSet *get_dataset() override {
            return this->summand->at(0)->get_dataset();
        }

    private:
        /// Registered summands, in insertion order.
        std::vector<ErrorFunction *> *summand;
        /// Coefficients paired index-wise with `summand` (presumably).
        std::vector<double> *summand_coefficient;
    };

}

#endif //INC_4NEURO_ERRORFUNCTION_H