diff --git a/src/ErrorFunction/ErrorFunctions.cpp b/src/ErrorFunction/ErrorFunctions.cpp index 72490ffc00e97d5fde1baa6e5a61082d99e4722a..7602ee5c83f2b20d150154822534a47cfd899898 100644 --- a/src/ErrorFunction/ErrorFunctions.cpp +++ b/src/ErrorFunction/ErrorFunctions.cpp @@ -367,12 +367,14 @@ namespace lib4neuro { } } - double MSE::calculate_single_residual(std::vector<double>* input, std::vector<double>* output) { + double MSE::calculate_single_residual(std::vector<double>* input, + std::vector<double>* output, + std::vector<double>* parameters) { //TODO maybe move to the general ErrorFunction //TODO check input vector sizes - they HAVE TO be allocated before calling this function - return -this->eval_on_single_input(input, output); + return -this->eval_on_single_input(input, output, parameters); } void MSE::calculate_residual_gradient(std::vector<double>* input, @@ -397,11 +399,11 @@ namespace lib4neuro { if(delta != 0) { /* Computation of f_val1 = f(x + delta) */ parameters->at(i) = former_parameter_value + delta; - f_val1 = -1 * this->eval_on_single_input(input, output, parameters); + f_val1 = this->calculate_single_residual(input, output, parameters); /* Computation of f_val2 = f(x - delta) */ parameters->at(i) = former_parameter_value - delta; - f_val2 = -1 * this->eval_on_single_input(input, output, parameters); + f_val2 = this->calculate_single_residual(input, output, parameters); gradient->at(i) = (f_val1 - f_val2) / (2*delta); } diff --git a/src/ErrorFunction/ErrorFunctions.h b/src/ErrorFunction/ErrorFunctions.h index 2549c9f1988dd1ab0cdbd63b303bc1f5572353f7..262559db1a8f5181c94ea32ce2a04de0103b20c2 100644 --- a/src/ErrorFunction/ErrorFunctions.h +++ b/src/ErrorFunction/ErrorFunctions.h @@ -183,7 +183,8 @@ class ErrorFunctionDifferentiable : public ErrorFunction { double h = 1e-3) = 0; virtual double calculate_single_residual(std::vector<double>* input, - std::vector<double>* output) = 0; + std::vector<double>* output, + std::vector<double>* parameters = nullptr) = 0; }; class MSE : public ErrorFunctionDifferentiable { @@ -226,7 +227,8 @@ class ErrorFunctionDifferentiable : public ErrorFunction { * @return */ virtual double calculate_single_residual(std::vector<double>* input, - std::vector<double>* output) override ; + std::vector<double>* output, + std::vector<double>* parameters) override; /** * Compute gradient of the residual function f(x) = 0 - MSE(x) for a specific input x.