//
// Created by martin on 7/15/18.
//
// Declares the abstract ErrorFunction interface and the MSE and ErrorSum
// classes derived from it.
//

#ifndef INC_4NEURO_ERRORFUNCTION_H
#define INC_4NEURO_ERRORFUNCTION_H

#include <cstddef>
#include <vector>

#include "../Network/NeuralNetwork.h"
#include "../DataSet/DataSet.h"
/**
 * Abstract interface for error functions that can be evaluated for a given
 * vector of weights. Concrete subclasses implement eval().
 */
class ErrorFunction {
public:

    /**
     * Virtual destructor: derived error functions are used through
     * ErrorFunction pointers (e.g. as summands), so deleting them through the
     * base pointer must invoke the derived destructor.
     */
    virtual ~ErrorFunction() = default;

    /**
     * Evaluates the error function.
     *
     * @param weights Weights at which to evaluate; may be omitted (nullptr).
     *                How nullptr is handled is left to the concrete subclass.
     * @return The value of the error function.
     */
    virtual double eval(std::vector<double>* weights = nullptr) = 0;

    /**
     * Dimension of the error function; initialised to 0.
     * NOTE(review): presumably the number of weights eval() expects — confirm
     * against the derived classes that set it.
     */
    size_t dimension = 0;
};

/**
 * Error function of a single neural network evaluated on a data set.
 * NOTE(review): by its name, presumably the mean squared error; the actual
 * computation lives in the implementation file — confirm there.
 */
class MSE : public ErrorFunction {

public:
    /**
     * Constructor for single neural network
     * @param net Network whose error is evaluated (presumably kept in `net`;
     *            NOTE(review): ownership not visible here — confirm non-owning).
     * @param ds  Data set the error is evaluated on (presumably kept in `ds`;
     *            same ownership caveat).
     */
    LIB4NEURO_API MSE(NeuralNetwork* net, DataSet* ds);

    /**
     * Evaluates the error of the network on the data set.
     * @param weights Weights at which to evaluate; may be omitted (nullptr).
     * @return The value of the error function.
     */
    LIB4NEURO_API double eval(std::vector<double>* weights = nullptr) override;

private:

    /** Network being evaluated (NOTE(review): presumably non-owning). */
    NeuralNetwork* net = nullptr;

    /** Data set used for evaluation (NOTE(review): presumably non-owning). */
    DataSet* ds = nullptr;
};

class ErrorSum : public ErrorFunction{
    LIB4NEURO_API virtual double eval(std::vector<double>* weights = nullptr);
    LIB4NEURO_API void add_error_function( ErrorFunction *F, double alpha = 1.0 );
Martin Beseda's avatar
Martin Beseda committed
    /**
     *
     * @return
     */
Martin Beseda's avatar
Martin Beseda committed
    std::vector<ErrorFunction*>* summand;
    std::vector<double> *summand_coefficient;

#endif //INC_4NEURO_ERRORFUNCTION_H