//
// Created by martin on 7/15/18.
//

#ifndef INC_4NEURO_ERRORFUNCTION_H
#define INC_4NEURO_ERRORFUNCTION_H

#include "../settings.h"
#include "../Network/NeuralNetwork.h"
#include "../DataSet/DataSet.h"


/**
 * Identifiers of the error-function kinds known to this header.
 */
enum ErrorFunctionType {
    ErrorFuncMSE = 0 ///< Mean-squared-error kind (see class MSE below).
};


/**
 * Abstract interface of an error function evaluated over a parameter vector.
 *
 * Concrete implementations declared in this header are MSE (one network with
 * a data set) and ErrorSum (a combination of other error functions).
 */
class ErrorFunction {
public:

    /**
     * Virtual destructor.
     *
     * This class is used polymorphically and derived classes (e.g. ErrorSum)
     * declare their own destructors. Deleting a derived object through an
     * ErrorFunction* is undefined behavior unless this destructor is virtual.
     */
    virtual ~ErrorFunction() = default;

    /**
     * Evaluates the error function.
     * @param weights Parameter vector to evaluate at; nullptr by default.
     *                NOTE(review): the meaning of nullptr (presumably "use the
     *                current parameters") is defined by implementations — confirm.
     * @return Value of the error function.
     */
    virtual double eval(std::vector<double>* weights = nullptr) = 0;

    /**
     * Returns the dimension of the error function.
     * @return The dimension (defined out of line). Presumably this is the length
     *         of the parameter vector — confirm against the definition.
     */
     LIB4NEURO_API virtual size_t get_dimension();

     /**
      * Computes the gradient of the error function.
      * @param params Parameter vector at which the gradient is evaluated.
      * @param grad   Gradient output vector. NOTE(review): whether it is
      *               overwritten or accumulated into is implementation-defined — confirm.
      * @param alpha  Scaling coefficient, 1.0 by default (presumably multiplies
      *               this function's gradient contribution — confirm).
      */
     virtual void calculate_error_gradient(std::vector<double> &params, std::vector<double> &grad, double alpha = 1.0 ) = 0;


     /**
      * Returns the parameters of the error function.
      * @return Pointer to a parameter vector. NOTE(review): ownership of the
      *         returned vector is not stated here — confirm whether callers must free it.
      */
     virtual std::vector<double>* get_parameters( ) = 0;

     /**
      * Returns the data set associated with this error function.
      * //TODO delete after gradient learning is debugged
      * @return Data set pointer (may be nullptr, depending on the implementation).
      */
     virtual DataSet* get_dataset() = 0;

protected:

    /**
     * Dimension of the error function; 0 until a derived class sets it.
     */
    size_t dimension = 0;

    /**
     * Neural network associated with this error function; nullptr by default.
     * NOTE(review): appears to be a non-owning pointer — confirm.
     */
    NeuralNetwork* net = nullptr;
};

/**
 * Error function of kind MSE, tied to a single neural network and a data set.
 */
class MSE : public ErrorFunction {
public:
    /**
     * Creates the error function for one neural network (defined out of line).
     * @param net Neural network the error is evaluated for.
     * @param ds  Data set; presumably stored and later returned by get_dataset().
     */
    LIB4NEURO_API MSE(NeuralNetwork* net, DataSet* ds);

    /**
     * Evaluates the error.
     * @param weights Parameter vector to evaluate at (nullptr by default).
     * @return Error value.
     */
    LIB4NEURO_API double eval(std::vector<double>* weights = nullptr) override;

    /**
     * Computes the error gradient.
     * @param params Parameter vector at which to evaluate the gradient.
     * @param grad   Gradient output vector.
     * @param alpha  Scaling coefficient (1.0 by default).
     */
    LIB4NEURO_API void calculate_error_gradient(std::vector<double>& params,
                                                std::vector<double>& grad,
                                                double alpha = 1.0) override;

    /**
     * @return Pointer to the parameter vector of this error function.
     */
    LIB4NEURO_API std::vector<double>* get_parameters() override;

    /**
     * @return The data set pointer held by this object.
     */
    LIB4NEURO_API DataSet* get_dataset() override { return ds; }

private:
    /** Data set associated with this error function. */
    DataSet* ds;
};

/**
 * Combination of several error functions.
 *
 * Each summand is registered with a coefficient through add_error_function().
 * Summands and coefficients are kept in containers reached through raw pointers.
 */
class ErrorSum : public ErrorFunction{
public:
    /**
     * Constructs the sum (defined out of line; presumably with no summands).
     */
     LIB4NEURO_API ErrorSum();

    /**
     * Destructor (defined out of line).
     * NOTE(review): presumably releases the summand/coefficient containers — confirm.
     */
     LIB4NEURO_API ~ErrorSum();

    /**
     * Copying is disabled.
     *
     * This class holds raw pointers to its containers and has a user-provided
     * destructor. The implicit member-wise copy would make two objects share
     * (and potentially both release) the same containers.
     */
    ErrorSum(const ErrorSum&) = delete;
    ErrorSum& operator=(const ErrorSum&) = delete;

    /**
     * Evaluates the combined error.
     * @param weights Parameter vector to evaluate at (nullptr by default).
     * @return Value of the combined error function.
     */
     LIB4NEURO_API double eval(std::vector<double>* weights = nullptr) override;

    /**
     * Adds an error function to the sum.
     * @param F     Error function to add. NOTE(review): ownership is not stated
     *              here; it appears to be non-owning — confirm.
     * @param alpha Coefficient associated with F (1.0 by default).
     */
     LIB4NEURO_API void add_error_function( ErrorFunction *F, double alpha = 1.0 );

    /**
     * @return The dimension of the combined error function (defined out of line).
     */
     LIB4NEURO_API size_t get_dimension( ) override;

     /**
      * Computes the gradient of the combined error.
      * @param params Parameter vector at which to evaluate the gradient.
      * @param grad   Gradient output vector.
      * @param alpha  Scaling coefficient (1.0 by default).
      */
    LIB4NEURO_API void calculate_error_gradient( std::vector<double> &params, std::vector<double> &grad, double alpha = 1.0 ) override;

    /**
     * @return Pointer to the parameter vector of the combined error function.
     */
    LIB4NEURO_API std::vector<double>* get_parameters( ) override;

    /**
     * Returns the data set of the first summand.
     * @return That summand's data set, or nullptr when no summand is registered.
     */
    LIB4NEURO_API DataSet* get_dataset() override {
        // Without this guard, a missing container would be dereferenced as null,
        // and an empty one would make at(0) throw std::out_of_range.
        if (this->summand == nullptr || this->summand->empty() || this->summand->front() == nullptr) {
            return nullptr;
        }
        return this->summand->front()->get_dataset();
    }

private:
    /** Error functions forming the sum. */
    std::vector<ErrorFunction*>* summand = nullptr;

    /** Coefficients of the summands (presumably index-aligned with summand). */
    std::vector<double> *summand_coefficient = nullptr;
};


#endif //INC_4NEURO_ERRORFUNCTION_H