Skip to content
Snippets Groups Projects
Commit 36acbb81 authored by Michal Kravcenko's avatar Michal Kravcenko
Browse files

FIX: fixed an issue with NaN in GradientDescent caused by acos(x), added more...

FIX: fixed an issue with NaN in GradientDescent caused by acos(x), added more debug info to errorfunction validation
parent 5cba0859
No related branches found
No related tags found
No related merge requests found
Subproject commit f66133343e03fad8d35e234c2f10bd791685466a Subproject commit a86cf5b7c6742da6113b598a8d9c0c3a3cdf8175
Subproject commit b3b4cee1c52baf935d68fe3bb7fb1a0ec6b79694 Subproject commit e0e880c3797ea363d24782ba63fe362f7d94f89c
Subproject commit fb8d7abab4c3de2ad2c1df0b80fdd7687628c8d6 Subproject commit 5c0f29012511339ba5cc2672f99a1356c5387b62
...@@ -105,22 +105,23 @@ namespace lib4neuro { ...@@ -105,22 +105,23 @@ namespace lib4neuro {
size_t dim_in = data_set->get_input_dim(); size_t dim_in = data_set->get_input_dim();
size_t dim_out = data_set->get_output_dim(); size_t dim_out = data_set->get_output_dim();
size_t n_elements = data_set->get_n_elements(); double error = 0.0, val, output_norm;
double error = 0.0, val;
std::vector<std::pair<std::vector<double>, std::vector<double>>>* data = data_set->get_data(); std::vector<std::pair<std::vector<double>, std::vector<double>>>* data = data_set->get_data();
size_t n_elements = data->size();
//TODO instead use something smarter //TODO instead use something smarter
std::vector<double> output(dim_out); std::vector<double> output(dim_out);
COUT_DEBUG("Evaluation of the error function MSE on the given data-set (format 'data-set element index' 'input'" COUT_DEBUG("Evaluation of the error function MSE on the given data-set (format 'data-set element index' 'input'"
" 'real output' 'predicted output'):" " 'real output' 'predicted output' 'absolute error' 'relative error'):"
<< std::endl); << std::endl);
*results_file_path << "[Data-set element index] [Input] [Real output] [Predicted output]" << std::endl; *results_file_path << "[Data-set element index] [Input] [Real output] [Predicted output] [Abs. error] [Rel. error]" << std::endl;
for (auto i = 0; i < data->size(); i++) { // Iterate through every element in the test set for (auto i = 0; i < data->size(); i++) { // Iterate through every element in the test set
error = 0.0;
output_norm = 0.0;
/* Compute the net output and store it into 'output' variable */ /* Compute the net output and store it into 'output' variable */
this->net->eval_single(data->at(i).first, this->net->eval_single(data->at(i).first,
output, output,
...@@ -146,20 +147,29 @@ namespace lib4neuro { ...@@ -146,20 +147,29 @@ namespace lib4neuro {
val = output.at(j) - data->at(i).second.at(j); val = output.at(j) - data->at(i).second.at(j);
error += val * val; error += val * val;
}
output_norm += output.at(j) * output.at(j);
}
#ifdef L4N_DEBUG
COUT_DEBUG(i << ": " COUT_DEBUG(i << ": "
<< ss_input.str() << " ; " << ss_input.str() << " ; "
<< ss_real_output.str() << " ; " << ss_real_output.str() << " ; "
<< ss_predicted_output.str() << std::endl); << ss_predicted_output.str() << "; "
<< std::sqrt(error) << "; "
<< 2 * std::sqrt(error) / (std::sqrt(error) + std::sqrt(output_norm)) <<
std::endl);
*results_file_path << i << ": " *results_file_path << i << ": "
<< ss_input.str() << " ; " << ss_input.str() << " ; "
<< ss_real_output.str() << " ; " << ss_real_output.str() << " ; "
<< ss_predicted_output.str() << std::endl; << ss_predicted_output.str() << "; "
<< std::sqrt(error) << "; "
<< 2 * std::sqrt(error) / (std::sqrt(error) + std::sqrt(output_norm)) <<
std::endl;
#endif
} }
double result = error / n_elements; double result = std::sqrt(error) / n_elements;
*results_file_path << "MSE = " << result << std::endl; *results_file_path << "MSE = " << result << std::endl;
return result; return result;
} }
......
...@@ -107,6 +107,7 @@ namespace lib4neuro { ...@@ -107,6 +107,7 @@ namespace lib4neuro {
/* step length calculation */ /* step length calculation */
if (iter_counter < 10 || iter_counter % this->restart_frequency == 0) { if (iter_counter < 10 || iter_counter % this->restart_frequency == 0) {
/* fixed step length */ /* fixed step length */
//gamma = 0.1 * this->tolerance;
gamma = 0.1 * this->tolerance; gamma = 0.1 * this->tolerance;
} else { } else {
/* angle between two consecutive gradients */ /* angle between two consecutive gradients */
...@@ -115,6 +116,12 @@ namespace lib4neuro { ...@@ -115,6 +116,12 @@ namespace lib4neuro {
sx += (gradient_current->at(i) * gradient_prev->at(i)); sx += (gradient_current->at(i) * gradient_prev->at(i));
} }
sx /= grad_norm * grad_norm_prev; sx /= grad_norm * grad_norm_prev;
if( sx < -1.0 + 5e-12 ){
sx = -1 + 5e-12;
}
else if( sx > 1.0 - 5e-12 ){
sx = 1 - 5e-12;
}
beta = std::sqrt(std::acos(sx) / lib4neuro::PI); beta = std::sqrt(std::acos(sx) / lib4neuro::PI);
eval_step_size_mk(gamma, beta, c, grad_norm_prev, grad_norm, val, prev_val); eval_step_size_mk(gamma, beta, c, grad_norm_prev, grad_norm, val, prev_val);
...@@ -140,7 +147,7 @@ namespace lib4neuro { ...@@ -140,7 +147,7 @@ namespace lib4neuro {
<< ". C: " << c << ". C: " << c
<< ". Gradient norm: " << grad_norm << ". Gradient norm: " << grad_norm
<< ". Total error: " << val << ". Total error: " << val
<< "." << std::endl); << "." << "\r");
WRITE_TO_OFS_DEBUG(ofs, "Iteration: " << (unsigned int)(iter_counter) WRITE_TO_OFS_DEBUG(ofs, "Iteration: " << (unsigned int)(iter_counter)
<< ". Step size: " << gamma << ". Step size: " << gamma
...@@ -151,6 +158,12 @@ namespace lib4neuro { ...@@ -151,6 +158,12 @@ namespace lib4neuro {
} }
COUT_DEBUG(std::string("Iteration: ") << (unsigned int)(iter_counter)
<< ". Step size: " << gamma
<< ". C: " << c
<< ". Gradient norm: " << grad_norm
<< ". Total error: " << val
<< "." << std::endl);
if(iter_idx == 0) { if(iter_idx == 0) {
COUT_INFO("Maximum number of iterations (" << this->maximum_niters << ") was reached!" << std::endl); COUT_INFO("Maximum number of iterations (" << this->maximum_niters << ") was reached!" << std::endl);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment