// ErrorFunctions.h
//
// Created by martin on 7/15/18.
//

#ifndef INC_4NEURO_ERRORFUNCTION_H
#define INC_4NEURO_ERRORFUNCTION_H

#include "../Network/NeuralNetwork.h"
#include "../DataSet/DataSet.h"
/**
 * Abstract interface shared by the error functions declared in this header
 * (MSE, ErrorSum). Implementations are handled through ErrorFunction
 * pointers (see ErrorSum), hence the virtual destructor.
 */
class ErrorFunction {
public:

    /**
     * Virtual destructor so that an implementation can be destroyed safely
     * through a pointer to ErrorFunction.
     */
    virtual ~ErrorFunction() = default;

    /**
     * Evaluates the error function.
     * @param weights optional pointer to a vector of weights; defaults to
     *        nullptr. How nullptr is interpreted is decided by the
     *        implementation (presumably "use the current parameters" —
     *        TODO confirm against the implementations).
     * @return the value of the error function
     */
    virtual double eval(std::vector<double>* weights = nullptr) = 0;

    /**
     * @return the dimension of this error function. Presumably the value of
     *         the `dimension` member — the definition is not in this header,
     *         TODO confirm.
     */
    virtual size_t get_dimension();

    /**
     * Computes the gradient of the error function.
     * @param params parameter vector passed by the caller
     * @param grad vector handed to the implementation for the gradient;
     *        whether it is overwritten or accumulated into is not visible
     *        in this header — TODO confirm
     * @param alpha scalar factor, defaults to 1.0 (presumably a scaling
     *        coefficient of the gradient contribution — TODO confirm)
     */
    virtual void calculate_error_gradient(std::vector<double> &params, std::vector<double> &grad, double alpha = 1.0 ) = 0;

    /**
     * @return pointer to the vector of parameters exposed by the
     *         implementation; ownership is not specified in this header
     */
    virtual std::vector<double>* get_parameters( ) = 0;

    /**
     * //TODO delete after gradient learning is debugged
     * @return pointer to the data set associated with this error function
     */
    virtual DataSet* get_dataset() = 0;

    /**
     * Dimension value of the error function, zero until set.
     */
    size_t dimension = 0;
};

/**
 * Error function bound to a single data set (the name suggests mean squared
 * error; the computation itself lives outside this header).
 */
class MSE : public ErrorFunction {

public:
    // Constructor for single neural network
    //   @param net
    //   @param ds
    // NOTE(review): the constructor declaration described above is not
    // present in this copy of the header; restore it from the original
    // file rather than guessing its exact signature.

    /**
     * Evaluates the error function.
     * @param weights optional pointer to a vector of weights; defaults to
     *        nullptr (interpretation is defined by the implementation)
     * @return the value of the error function
     */
    double eval(std::vector<double>* weights = nullptr) override;

    /**
     * Computes the gradient of the error function.
     * @param params parameter vector passed by the caller
     * @param grad vector handed to the implementation for the gradient
     * @param alpha scalar factor, defaults to 1.0
     */
    void calculate_error_gradient(std::vector<double> &params, std::vector<double> &grad, double alpha = 1.0 ) override;

    /**
     * @return pointer to the vector of parameters exposed by this error
     *         function
     */
    std::vector<double>* get_parameters( ) override;

    /**
     * @return the stored data set pointer (nullptr if none has been set)
     */
    DataSet* get_dataset() override {
        return this->ds;
    }

private:

    /** Data set returned by get_dataset(); null until assigned. */
    DataSet* ds = nullptr;
};
class ErrorSum : public ErrorFunction{
     double eval(std::vector<double>* weights = nullptr) override;
     void add_error_function( ErrorFunction *F, double alpha = 1.0 );
Martin Beseda's avatar
Martin Beseda committed
    /**
     *
     * @return
     */
     size_t get_dimension( ) override;

     /**
      *
      * @param params
      * @param grad
      */
    void calculate_error_gradient( std::vector<double> &params, std::vector<double> &grad, double alpha = 1.0 ) override;

    /**
     *
     * @return
     */
    std::vector<double>* get_parameters( ) override;

    DataSet* get_dataset() override {
        return this->summand->at( 0 )->get_dataset();
    };
Martin Beseda's avatar
Martin Beseda committed
    std::vector<ErrorFunction*>* summand;
    std::vector<double> *summand_coefficient;

#endif //INC_4NEURO_ERRORFUNCTION_H