Skip to content
Snippets Groups Projects
Commit 6389037f authored by Martin Beseda's avatar Martin Beseda
Browse files

ENH: Added 'parameters' parameter to the method calculate_single_residual.

parent 4d04630e
No related branches found
No related tags found
No related merge requests found
......@@ -367,12 +367,14 @@ namespace lib4neuro {
}
}
double MSE::calculate_single_residual(std::vector<double>* input, std::vector<double>* output) {
double MSE::calculate_single_residual(std::vector<double>* input,
std::vector<double>* output,
std::vector<double>* parameters) {
//TODO maybe move to the general ErrorFunction
//TODO check input vector sizes - they HAVE TO be allocated before calling this function
return -this->eval_on_single_input(input, output);
return -this->eval_on_single_input(input, output, parameters);
}
void MSE::calculate_residual_gradient(std::vector<double>* input,
......@@ -397,11 +399,11 @@ namespace lib4neuro {
if(delta != 0) {
/* Computation of f_val1 = f(x + delta) */
parameters->at(i) = former_parameter_value + delta;
f_val1 = -1 * this->eval_on_single_input(input, output, parameters);
f_val1 = this->calculate_single_residual(input, output, parameters);
/* Computation of f_val2 = f(x - delta) */
parameters->at(i) = former_parameter_value - delta;
f_val2 = -1 * this->eval_on_single_input(input, output, parameters);
f_val2 = this->calculate_single_residual(input, output, parameters);
gradient->at(i) = (f_val1 - f_val2) / (2*delta);
}
......
......@@ -183,7 +183,8 @@ class ErrorFunctionDifferentiable : public ErrorFunction {
double h = 1e-3) = 0;
virtual double calculate_single_residual(std::vector<double>* input,
std::vector<double>* output) = 0;
std::vector<double>* output,
std::vector<double>* parameters = nullptr) = 0;
};
class MSE : public ErrorFunctionDifferentiable {
......@@ -226,7 +227,8 @@ class ErrorFunctionDifferentiable : public ErrorFunction {
* @return
*/
virtual double calculate_single_residual(std::vector<double>* input,
std::vector<double>* output) override ;
std::vector<double>* output,
std::vector<double>* parameters) override;
/**
* Compute gradient of the residual function f(x) = 0 - MSE(x) for a specific input x.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment