Commit bb473a84 authored by Martin Beseda

FIX: Restructured ErrorFunction classes

parent 57bd1808
@@ -504,24 +504,6 @@ namespace lib4neuro {
         return output;
     }

-    void ErrorSum::calculate_error_gradient(std::vector<double>& params,
-                                            std::vector<double>& grad,
-                                            double alpha,
-                                            size_t batch) {
-        ErrorFunction* ef = nullptr;
-        for (size_t i = 0; i < this->summand->size(); ++i) {
-            ef = this->summand->at(i);
-            if (ef) {
-                ef->calculate_error_gradient(params,
-                                             grad,
-                                             this->summand_coefficient->at(i) * alpha,
-                                             batch);
-            }
-        }
-    }
-
     void ErrorSum::add_error_function(ErrorFunction* F,
                                       double alpha) {
         if (!this->summand) {
@@ -562,4 +544,38 @@ namespace lib4neuro {
     DataSet* ErrorSum::get_dataset() {
         return this->summand->at(0)->get_dataset();
     };

+    void ErrorSum::calculate_error_gradient(std::vector<double>& params,
+                                            std::vector<double>& grad,
+                                            double alpha,
+                                            size_t batch) {
+        ErrorFunction* ef = nullptr;
+        for (size_t i = 0; i < this->summand->size(); ++i) {
+            ef = this->summand->at(i);
+            if (ef) {
+                ef->calculate_error_gradient(params,
+                                             grad,
+                                             this->summand_coefficient->at(i) * alpha,
+                                             batch);
+            }
+        }
+    }
+
+    void ErrorSum::calculate_residual_gradient(std::vector<double>* input,
+                                               std::vector<double>* output,
+                                               std::vector<double>* gradient,
+                                               double h) {
+        THROW_NOT_IMPLEMENTED_ERROR();
+    }
+
+    double ErrorSum::calculate_single_residual(std::vector<double>* input,
+                                               std::vector<double>* output,
+                                               std::vector<double>* parameters) {
+        THROW_NOT_IMPLEMENTED_ERROR();
+        return 0;
+    }
 }
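Context for the move: ErrorSum's gradient relies on linearity, grad(sum_i c_i * E_i) = sum_i c_i * grad(E_i). A minimal self-contained sketch of the same delegation pattern; the Summand type and sum_gradient function are illustrative stand-ins, not lib4neuro API:

#include <cstddef>
#include <vector>

// Toy stand-in for lib4neuro::ErrorFunction; names are illustrative only.
struct Summand {
    virtual ~Summand() = default;
    // Accumulates alpha * d(E_i)/d(w) into grad.
    virtual void gradient(const std::vector<double>& params,
                          std::vector<double>& grad,
                          double alpha) const = 0;
};

// grad of E(w) = sum_i c_i * E_i(w) is sum_i c_i * grad E_i(w),
// the same loop ErrorSum::calculate_error_gradient performs.
void sum_gradient(const std::vector<Summand*>& summands,
                  const std::vector<double>& coeffs,
                  const std::vector<double>& params,
                  std::vector<double>& grad,
                  double alpha = 1.0) {
    for (std::size_t i = 0; i < summands.size(); ++i) {
        if (summands[i]) {  // skip empty slots, as the original does
            summands[i]->gradient(params, grad, coeffs[i] * alpha);
        }
    }
}

Because each summand scales its contribution by coeffs[i] * alpha before accumulating into grad, nested sums compose: an ErrorSum used as a summand of another ErrorSum still receives the correctly combined coefficient.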
@@ -36,19 +36,6 @@ namespace lib4neuro {
          */
         LIB4NEURO_API virtual size_t get_dimension();

-        /**
-         *
-         * @param params
-         * @param grad
-         * @param alpha
-         * @param batch
-         */
-        virtual void
-        calculate_error_gradient(std::vector<double>& params,
-                                 std::vector<double>& grad,
-                                 double alpha = 1.0,
-                                 size_t batch = 0) = 0;
-
         /**
          *
          * @return
@@ -140,6 +127,23 @@ namespace lib4neuro {
                                           bool denormalize_data = true,
                                           bool verbose = false) = 0;

+        virtual void
+        calculate_error_gradient(std::vector<double>& params,
+                                 std::vector<double>& grad,
+                                 double alpha,
+                                 size_t batch) = 0;
+
+        virtual void
+        calculate_residual_gradient(std::vector<double>* input,
+                                    std::vector<double>* output,
+                                    std::vector<double>* gradient,
+                                    double h = 1e-3) = 0;
+
+        virtual double
+        calculate_single_residual(std::vector<double>* input,
+                                  std::vector<double>* output,
+                                  std::vector<double>* parameters = nullptr) = 0;
+
     protected:

         /**
@@ -167,27 +171,8 @@ namespace lib4neuro {
          */
         DataSet* ds_test = nullptr;
     };

-    class ErrorFunctionDifferentiable : public ErrorFunction {
-    public:
-        virtual void
-        calculate_error_gradient(std::vector<double>& params,
-                                 std::vector<double>& grad,
-                                 double alpha,
-                                 size_t batch) = 0;
-
-        virtual void
-        calculate_residual_gradient(std::vector<double>* input,
-                                    std::vector<double>* output,
-                                    std::vector<double>* gradient,
-                                    double h = 1e-3) = 0;
-
-        virtual double calculate_single_residual(std::vector<double>* input,
-                                                 std::vector<double>* output,
-                                                 std::vector<double>* parameters = nullptr) = 0;
-    };
-
-    class MSE : public ErrorFunctionDifferentiable {
+    class MSE : public ErrorFunction{
     public:

         /**
@@ -336,7 +321,7 @@ class ErrorFunctionDifferentiable : public ErrorFunction {
          */
         LIB4NEURO_API double eval(std::vector<double>* weights = nullptr,
                                   bool denormalize_data = false,
-                                  bool verbose = false) override;
+                                  bool verbose = false);

         /**
          *
@@ -424,6 +409,19 @@ class ErrorFunctionDifferentiable : public ErrorFunction {
                                  std::vector<double>& grad,
                                  double alpha = 1.0,
                                  size_t batch = 0) override;

+        LIB4NEURO_API void
+        calculate_residual_gradient(std::vector<double>* input,
+                                    std::vector<double>* output,
+                                    std::vector<double>* gradient,
+                                    double h = 1e-3) override;
+
+        LIB4NEURO_API double
+        calculate_single_residual(std::vector<double>* input,
+                                  std::vector<double>* output,
+                                  std::vector<double>* parameters = nullptr) override;
+
         /**
          *
          * @return
@@ -436,11 +434,10 @@ class ErrorFunctionDifferentiable : public ErrorFunction {
          */
         LIB4NEURO_API DataSet* get_dataset() override;

-    private:
+    protected:
         std::vector<ErrorFunction*>* summand;
         std::vector<double>* summand_coefficient;
     };
 }

 #endif //INC_4NEURO_ERRORFUNCTION_H
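The h = 1e-3 default on calculate_residual_gradient suggests a finite-difference scheme over some perturbed vector (most likely the network parameters, given the Levenberg-Marquardt usage below). A central-difference sketch under that assumption; residual_gradient_fd and its callback are hypothetical helpers, not the library's implementation:

#include <cstddef>
#include <functional>
#include <vector>

// Central-difference gradient of a scalar residual r(w) with respect to a
// parameter vector w; `h` plays the role of the default step in
// calculate_residual_gradient. Hypothetical helper, not lib4neuro code.
void residual_gradient_fd(
        const std::function<double(const std::vector<double>&)>& residual,
        std::vector<double> w,                 // by value: perturbed locally
        std::vector<double>& gradient,
        double h = 1e-3) {
    gradient.assign(w.size(), 0.0);
    for (std::size_t i = 0; i < w.size(); ++i) {
        const double saved = w[i];
        w[i] = saved + h;
        const double r_plus = residual(w);
        w[i] = saved - h;
        const double r_minus = residual(w);
        w[i] = saved;                          // restore before next coordinate
        gradient[i] = (r_plus - r_minus) / (2.0 * h);  // O(h^2) error
    }
}

Central differences cost two residual evaluations per coordinate but reduce the truncation error from O(h) to O(h^2) compared with a forward difference.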
@@ -18,6 +18,8 @@ MOCK_BASE_CLASS(mock_ErrorFunction, lib4neuro::ErrorFunction)
     MOCK_METHOD(eval, 3)
     MOCK_METHOD(get_dimension, 0)
     MOCK_METHOD(calculate_error_gradient, 4)
+    MOCK_METHOD(calculate_residual_gradient, 4)
+    MOCK_METHOD(calculate_single_residual, 3)
     MOCK_METHOD(get_parameters, 0)
     MOCK_METHOD(get_dataset, 0)
     MOCK_METHOD(get_network_instance, 0)
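The mock above is declared with the Turtle mocking framework (MOCK_BASE_CLASS / MOCK_METHOD, with the method arity as the second argument). A usage sketch for the two newly mocked methods, assuming a Boost.Test harness; the header name and expected values are hypothetical:

#define BOOST_TEST_MODULE mock_example
#include <boost/test/included/unit_test.hpp>
#include <turtle/mock.hpp>
#include "mock_ErrorFunction.h"  // hypothetical header holding the mock above

BOOST_AUTO_TEST_CASE(residual_methods_are_mockable) {
    mock_ErrorFunction ef;
    // Expected values are illustrative only.
    MOCK_EXPECT(ef.get_dimension).returns(4u);
    MOCK_EXPECT(ef.calculate_single_residual).returns(0.5);

    BOOST_CHECK_EQUAL(ef.get_dimension(), 4u);
    BOOST_CHECK_CLOSE(ef.calculate_single_residual(nullptr, nullptr, nullptr),
                      0.5, 1e-9);
}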
@@ -42,7 +42,7 @@ namespace lib4neuro {
     }

-    void GradientDescent::optimize(lib4neuro::ErrorFunctionDifferentiable &ef, std::ofstream* ofs) {
+    void GradientDescent::optimize(lib4neuro::ErrorFunction &ef, std::ofstream* ofs) {

         /* Copy data set max and min values, if it's normalized */
         if(ef.get_dataset()->is_normalized()) {
@@ -92,7 +92,7 @@ namespace lib4neuro {
      *
      * @param ef
      */
-    LIB4NEURO_API void optimize(lib4neuro::ErrorFunctionDifferentiable &ef, std::ofstream* ofs = nullptr) override;
+    LIB4NEURO_API void optimize(lib4neuro::ErrorFunction &ef, std::ofstream* ofs = nullptr) override;

     /**
      *
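A hedged call-site sketch for the relaxed signature; the constructor arguments are copied from the example file later in this commit, and the umbrella header name is an assumption:

#include <4neuro.h>  // assumption: the umbrella header used by the examples

// After the change, any ErrorFunction can be passed, including the
// ErrorSum the example file switches to below.
void run_gradient_descent(l4n::ErrorSum& ef) {
    l4n::GradientDescent gd(1e-6, 1000);
    gd.optimize(ef);  // previously required an ErrorFunctionDifferentiable&
}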
@@ -28,13 +28,6 @@ namespace lib4neuro {
     };

     class GradientLearningMethod : public LearningMethod {
-    private:
-        /**
-         * Runs the method specific learning algorithm minimizing the given error function
-         */
-        virtual void optimize(ErrorFunctionDifferentiable& ef,
-                              std::ofstream* ofs = nullptr) = 0;
-
     public:
         /**
          * Runs the method specific learning algorithm minimizing the given error function
@@ -10,7 +10,6 @@
 namespace lib4neuro {
     void GradientLearningMethod::optimize(ErrorFunction& ef,
                                           std::ofstream* ofs) {
-        auto& new_ef = dynamic_cast<ErrorFunctionDifferentiable&>(ef);
-        this->optimize(new_ef, ofs);
+        this->optimize(ef, ofs);
     }
 }
\ No newline at end of file
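The removed downcast is the heart of this restructuring: a reference dynamic_cast throws std::bad_cast whenever the concrete error function does not derive from ErrorFunctionDifferentiable, so every optimize call carried that risk. A standalone demonstration of the failure mode (generic C++, not lib4neuro code):

#include <iostream>
#include <typeinfo>

struct Base { virtual ~Base() = default; };
struct Derived : Base {};

int main() {
    Base b;
    try {
        // A reference dynamic_cast to a type the object does not have
        // throws std::bad_cast, the failure mode every call through the
        // old GradientLearningMethod::optimize risked.
        Derived& d = dynamic_cast<Derived&>(b);
        (void) d;
    } catch (const std::bad_cast& e) {
        std::cout << "bad_cast: " << e.what() << '\n';
    }
    return 0;
}

After the restructuring, an error function that cannot supply residuals fails explicitly via THROW_NOT_IMPLEMENTED_ERROR() only when a gradient-based method actually calls the missing operation.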
@@ -38,13 +38,13 @@ struct lib4neuro::LevenbergMarquardt::LevenbergMarquardtImpl {
      * @param h Step size
      * @return Jacobian matrix
      */
-    arma::Mat<double>* get_Jacobian_matrix(lib4neuro::ErrorFunctionDifferentiable& ef,
+    arma::Mat<double>* get_Jacobian_matrix(lib4neuro::ErrorFunction& ef,
                                            arma::Mat<double>* J,
                                            double h=1e-3);
 };

 arma::Mat<double>* lib4neuro::LevenbergMarquardt::LevenbergMarquardtImpl::get_Jacobian_matrix(
-    lib4neuro::ErrorFunctionDifferentiable& ef,
+    lib4neuro::ErrorFunction& ef,
     arma::Mat<double>* J,
     double h) {
@@ -84,12 +84,12 @@ namespace lib4neuro {
         this->p_impl->maximum_niters = max_iters;
     }

-    void LevenbergMarquardt::optimize(lib4neuro::ErrorFunctionDifferentiable& ef,
+    void LevenbergMarquardt::optimize(lib4neuro::ErrorFunction& ef,
                                       std::ofstream* ofs) {
         optimize(ef, LM_UPDATE_TYPE::MARQUARDT, ofs);
     }

-    void LevenbergMarquardt::optimize(lib4neuro::ErrorFunctionDifferentiable& ef,
+    void LevenbergMarquardt::optimize(lib4neuro::ErrorFunction& ef,
                                       lib4neuro::LM_UPDATE_TYPE update_type,
                                       std::ofstream* ofs) {
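get_Jacobian_matrix presumably fills J(i,j), the derivative of residual i with respect to weight j, by repeated single-residual evaluations with step h. A hedged forward-difference sketch using Armadillo; the residual callback stands in for ErrorFunction::calculate_single_residual and this is not the library's code:

#include <armadillo>
#include <cstddef>
#include <vector>

// Hedged sketch of a forward-difference Jacobian fill, J(i,j) ~ dr_i/dw_j.
// `residual(i, w)` stands in for ErrorFunction::calculate_single_residual.
template <typename Residual>
void fill_jacobian(Residual residual,
                   std::vector<double> weights,   // by value: perturbed locally
                   std::size_t n_residuals,
                   arma::Mat<double>& J,
                   double h = 1e-3) {
    J.set_size(n_residuals, weights.size());
    std::vector<double> base(n_residuals);
    for (std::size_t i = 0; i < n_residuals; ++i) {
        base[i] = residual(i, weights);            // r_i at the base point
    }
    for (std::size_t j = 0; j < weights.size(); ++j) {
        const double saved = weights[j];
        weights[j] = saved + h;                    // perturb one weight
        for (std::size_t i = 0; i < n_residuals; ++i) {
            J(i, j) = (residual(i, weights) - base[i]) / h;
        }
        weights[j] = saved;                        // restore
    }
}

Given J and the residual vector r, the MARQUARDT update type presumably solves (J^T J + lambda * diag(J^T J)) * delta = J^T r for the step, with lambda scaled by the lambda_increase or lambda_decrease factors from the header below depending on whether the error grew or shrank.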
@@ -40,9 +40,9 @@ namespace lib4neuro {
                             double lambda_increase=11,
                             double lambda_decrease=9);

-        void optimize(ErrorFunctionDifferentiable &ef, std::ofstream* ofs = nullptr);
+        void optimize(ErrorFunction &ef, std::ofstream* ofs = nullptr);

-        void optimize(ErrorFunctionDifferentiable &ef,
+        void optimize(ErrorFunction &ef,
                       LM_UPDATE_TYPE update_type,
                       std::ofstream* ofs = nullptr);
@@ -58,7 +58,7 @@ void optimize_via_particle_swarm( l4n::NeuralNetwork &net, l4n::ErrorFunction &e
     std::cout << "***********************************************************************************************************************" <<std::endl;
 }

-void optimize_via_gradient_descent( l4n::NeuralNetwork &net, l4n::ErrorFunction &ef ){
+void optimize_via_gradient_descent( l4n::NeuralNetwork &net, l4n::ErrorSum &ef ){
     l4n::GradientDescent gd( 1e-6, 1000 );
@@ -68,7 +68,7 @@ void optimize_via_gradient_descent( l4n::NeuralNetwork &net, l4n::ErrorFunction
     net.copy_parameter_space(parameters);

     /* ERROR CALCULATION */
-    std::cout << "Run finished! Error of the network[Gradient descent]: " << ef.eval( nullptr )<< std::endl;
+    std::cout << "Run finished! Error of the network[Gradient descent]: " << ef.eval()<< std::endl;
     std::cout << "***********************************************************************************************************************" <<std::endl;
 }