Skip to content
Snippets Groups Projects
ErrorFunctions.h 1.58 KiB
Newer Older
//
// Created by martin on 7/15/18.
//

#ifndef INC_4NEURO_ERRORFUNCTION_H
#define INC_4NEURO_ERRORFUNCTION_H

#include "../Network/NeuralNetwork.h"
#include "../DataSet/DataSet.h"
David Vojtek's avatar
David Vojtek committed

/**
 * Identifiers of the available kinds of error function.
 *
 * Kept as an unscoped enum so existing code can keep using the bare
 * enumerator names.
 */
enum ErrorFunctionType {
    ErrorFuncMSE = 0 ///< mean squared error (presumably corresponds to class MSE)
};


/**
 * Abstract interface of an error function evaluated on a vector of weights.
 *
 * Objects of this hierarchy are handled through base-class pointers (ErrorSum
 * stores ErrorFunction*), so the destructor is virtual to make destruction
 * through an ErrorFunction pointer well-defined.
 */
class ErrorFunction {
public:

    /**
     * Virtual destructor: derived error functions may be destroyed through
     * an ErrorFunction pointer.
     */
    virtual ~ErrorFunction() = default;

    /**
     * Evaluates the error function.
     * @param weights Weight vector to evaluate at; defaults to nullptr.
     *        NOTE(review): what nullptr means (presumably "use the current
     *        weights") is decided by each implementation — confirm.
     * @return Value of the error function.
     */
    virtual double eval(std::vector<double>* weights = nullptr) = 0;

    /**
     * Returns the dimension of the error function.
     * @return Dimension; presumably the value of the protected member
     *         `dimension` (definition lives in the .cpp — confirm).
     */
    LIB4NEURO_API virtual size_t get_dimension();

protected:

    /**
     * Dimension of the error function (presumably the number of weights it
     * depends on); initialized to 0.
     */
    size_t dimension = 0;
};

/**
 * Mean squared error (MSE) of a single neural network over a data set.
 *
 * The network and data set are referenced through raw pointers. MSE declares
 * no destructor, so it never frees them: the caller keeps ownership and must
 * keep both alive for as long as the MSE object is used.
 */
class MSE : public ErrorFunction {

public:
    /**
     * Constructor for single neural network
     * @param net Network whose error is evaluated (not owned).
     * @param ds Data set the error is evaluated on (not owned).
     */
    LIB4NEURO_API MSE(NeuralNetwork* net, DataSet* ds);

    /**
     * Evaluates the mean squared error.
     * @param weights Weight vector to evaluate at; defaults to nullptr.
     *        NOTE(review): handling of nullptr is in the .cpp — confirm.
     * @return The computed error value.
     */
    LIB4NEURO_API double eval(std::vector<double>* weights = nullptr) override;

private:

    /** Network being evaluated; non-owning. */
    NeuralNetwork* net = nullptr;

    /** Data set the network is evaluated on; non-owning. */
    DataSet* ds = nullptr;
};

/**
 * Error function composed of several error functions, each with a
 * coefficient (presumably evaluated as a weighted sum — see the .cpp).
 *
 * The summands and their coefficients live in two vectors held through raw
 * pointers. These are presumably allocated in the constructor and released in
 * the destructor, so copying is disabled: an implicit member-wise copy would
 * make two objects share (and release) the same vectors.
 */
class ErrorSum : public ErrorFunction{
public:
    /**
     * Creates an empty sum.
     */
    LIB4NEURO_API ErrorSum();

    /**
     * Releases the resources held by the sum (see the .cpp for which ones).
     */
    LIB4NEURO_API ~ErrorSum();

    /** Copying would alias the internally held vectors; disallowed. */
    ErrorSum(const ErrorSum&) = delete;

    /** Copy assignment would alias the internally held vectors; disallowed. */
    ErrorSum& operator=(const ErrorSum&) = delete;

    /**
     * Evaluates the combined error.
     * @param weights Weight vector passed on for evaluation; defaults to nullptr.
     * @return Combined value of the summands (presumably
     *         sum_i coefficient_i * summand_i->eval(weights) — confirm in the .cpp).
     */
    LIB4NEURO_API double eval(std::vector<double>* weights = nullptr) override;

    /**
     * Appends an error function to the sum.
     * @param F Error function to add.
     *        NOTE(review): whether ErrorSum takes ownership of F is decided in
     *        the .cpp — confirm before deleting F elsewhere.
     * @param alpha Coefficient of F (default 1.0); presumably stored in
     *        summand_coefficient.
     */
    LIB4NEURO_API void add_error_function( ErrorFunction *F, double alpha = 1.0 );

    /**
     * Returns the dimension of the combined error function.
     * @return Dimension (computation defined in the .cpp).
     */
    LIB4NEURO_API size_t get_dimension() override;

private:
    /** Error functions that make up the sum. */
    std::vector<ErrorFunction*>* summand = nullptr;

    /** Coefficients of the summands, presumably index-aligned with `summand`. */
    std::vector<double> *summand_coefficient = nullptr;
};


#endif //INC_4NEURO_ERRORFUNCTION_H