Skip to content
Snippets Groups Projects
Commit e346029c authored by Michal Kravcenko's avatar Michal Kravcenko
Browse files

FIX: fixed an issue with NaN in GradientDescent caused by acos(x), added more...

FIX: fixed an issue with NaN in GradientDescent caused by acos(x), added more debug info to errorfunction validation
parent 3b0c0af2
No related branches found
No related tags found
No related merge requests found
Subproject commit f66133343e03fad8d35e234c2f10bd791685466a
Subproject commit a86cf5b7c6742da6113b598a8d9c0c3a3cdf8175
Subproject commit b3b4cee1c52baf935d68fe3bb7fb1a0ec6b79694
Subproject commit e0e880c3797ea363d24782ba63fe362f7d94f89c
Subproject commit fb8d7abab4c3de2ad2c1df0b80fdd7687628c8d6
Subproject commit 5c0f29012511339ba5cc2672f99a1356c5387b62
......@@ -105,22 +105,23 @@ namespace lib4neuro {
size_t dim_in = data_set->get_input_dim();
size_t dim_out = data_set->get_output_dim();
size_t n_elements = data_set->get_n_elements();
double error = 0.0, val;
double error = 0.0, val, output_norm;
std::vector<std::pair<std::vector<double>, std::vector<double>>>* data = data_set->get_data();
size_t n_elements = data->size();
//TODO instead use something smarter
std::vector<double> output(dim_out);
COUT_DEBUG("Evaluation of the error function MSE on the given data-set (format 'data-set element index' 'input'"
" 'real output' 'predicted output'):"
" 'real output' 'predicted output' 'absolute error' 'relative error'):"
<< std::endl);
*results_file_path << "[Data-set element index] [Input] [Real output] [Predicted output]" << std::endl;
*results_file_path << "[Data-set element index] [Input] [Real output] [Predicted output] [Abs. error] [Rel. error]" << std::endl;
for (auto i = 0; i < data->size(); i++) { // Iterate through every element in the test set
error = 0.0;
output_norm = 0.0;
/* Compute the net output and store it into 'output' variable */
this->net->eval_single(data->at(i).first,
output,
......@@ -146,20 +147,29 @@ namespace lib4neuro {
val = output.at(j) - data->at(i).second.at(j);
error += val * val;
}
output_norm += output.at(j) * output.at(j);
}
#ifdef L4N_DEBUG
COUT_DEBUG(i << ": "
<< ss_input.str() << " ; "
<< ss_real_output.str() << " ; "
<< ss_predicted_output.str() << std::endl);
<< ss_predicted_output.str() << "; "
<< std::sqrt(error) << "; "
<< 2 * std::sqrt(error) / (std::sqrt(error) + std::sqrt(output_norm)) <<
std::endl);
*results_file_path << i << ": "
<< ss_input.str() << " ; "
<< ss_real_output.str() << " ; "
<< ss_predicted_output.str() << std::endl;
<< ss_predicted_output.str() << "; "
<< std::sqrt(error) << "; "
<< 2 * std::sqrt(error) / (std::sqrt(error) + std::sqrt(output_norm)) <<
std::endl;
#endif
}
double result = error / n_elements;
double result = std::sqrt(error) / n_elements;
*results_file_path << "MSE = " << result << std::endl;
return result;
}
......
......@@ -107,6 +107,7 @@ namespace lib4neuro {
/* step length calculation */
if (iter_counter < 10 || iter_counter % this->restart_frequency == 0) {
/* fixed step length */
//gamma = 0.1 * this->tolerance;
gamma = 0.1 * this->tolerance;
} else {
/* angle between two consecutive gradients */
......@@ -115,6 +116,12 @@ namespace lib4neuro {
sx += (gradient_current->at(i) * gradient_prev->at(i));
}
sx /= grad_norm * grad_norm_prev;
if( sx < -1.0 + 5e-12 ){
sx = -1 + 5e-12;
}
else if( sx > 1.0 - 5e-12 ){
sx = 1 - 5e-12;
}
beta = std::sqrt(std::acos(sx) / lib4neuro::PI);
eval_step_size_mk(gamma, beta, c, grad_norm_prev, grad_norm, val, prev_val);
......@@ -140,7 +147,7 @@ namespace lib4neuro {
<< ". C: " << c
<< ". Gradient norm: " << grad_norm
<< ". Total error: " << val
<< "." << std::endl);
<< "." << "\r");
WRITE_TO_OFS_DEBUG(ofs, "Iteration: " << (unsigned int)(iter_counter)
<< ". Step size: " << gamma
......@@ -151,6 +158,12 @@ namespace lib4neuro {
}
COUT_DEBUG(std::string("Iteration: ") << (unsigned int)(iter_counter)
<< ". Step size: " << gamma
<< ". C: " << c
<< ". Gradient norm: " << grad_norm
<< ". Total error: " << val
<< "." << std::endl);
if(iter_idx == 0) {
COUT_INFO("Maximum number of iterations (" << this->maximum_niters << ") was reached!" << std::endl);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment