Skip to content
Snippets Groups Projects
Commit 44ee2589 authored by Martin Beseda's avatar Martin Beseda
Browse files

ENH: Method run_k_fold_test overloaded, so that it can output not only to STDOUT, but also to a file.

ENH: Method run_k_fold_test overloaded, so that it can output not only to STDOUT, but also to the file, either according to its name or to the open handler.
parent cfc23b21
No related branches found
No related tags found
No related merge requests found
......@@ -3,7 +3,7 @@
//
#include "CrossValidator.h"
#include "../message.h"
#include "message.h"
namespace lib4neuro {
LIB4NEURO_API CrossValidator::CrossValidator(ILearningMethods* optimizer, ErrorFunction* ef) {
......@@ -11,19 +11,63 @@ namespace lib4neuro {
this->ef = ef;
}
/**
 * Runs a k-fold cross-validation test and writes the progress and the
 * resulting errors into an already-open output file stream (in addition
 * to the STDOUT messages).
 *
 * @param k Number of folds the data set is divided into (a 1/k part is
 *        held out for testing in every run)
 * @param tests_number Number of cross-validation runs to perform
 * @param results_file_path Open output stream the results are written to;
 *        may be nullptr, in which case only STDOUT output is produced
 */
LIB4NEURO_API void CrossValidator::run_k_fold_test(unsigned int k, unsigned int tests_number, std::ofstream* results_file_path) {
    //TODO do not duplicate code - write in a more elegant way
    NeuralNetwork *net = this->ef->get_network_instance();
    double cv_err_sum = 0;

    for(unsigned int i = 0; i < tests_number; i++) {
        COUT_INFO("Cross-validation run " << i+1 << std::endl);
        if(results_file_path) {
            *results_file_path << "Cross-validation run " << i+1 << std::endl;
        }

        /* Hold out a 1/k part of the data for testing */
        this->ef->divide_data_train_test(1.0/k);
        if(results_file_path) {
            *results_file_path << "Number of train data points: " << this->ef->get_dataset()->get_n_elements() << std::endl;
            *results_file_path << "Number of test data points: " << this->ef->get_test_dataset()->get_n_elements() << std::endl;
        }

        /* Re-initialize the network parameters before every run */
        net->randomize_parameters();
        net->scale_parameters( 1.0 / (net->get_n_weights() + net->get_n_biases()));

        this->optimizer->optimize(*this->ef);

        /* Error evaluation and writing.
         * NOTE(review): assumes eval_on_test_data tolerates a null stream
         * pointer — confirm against ErrorFunction's implementation. */
        double err = this->ef->eval_on_test_data(results_file_path);
        cv_err_sum += err;
        COUT_INFO("CV error (run " << i+1 << "): " << err << std::endl << std::endl);

        /* Restore the full data set so the next run samples a fresh test part */
        this->ef->return_full_data_set_for_training();
    }

    COUT_INFO("CV error mean: " << cv_err_sum/tests_number << std::endl);
    if(results_file_path) {
        *results_file_path << "CV error mean: " << cv_err_sum/tests_number << std::endl;
    }
}
/**
 * Runs a k-fold cross-validation test, writing progress to STDOUT and,
 * when a non-empty base path is given, the per-run test-data evaluation
 * into files named "<results_file_path>_cv<i>.dat".
 *
 * @param k Number of folds the data set is divided into (a 1/k part is
 *        held out for testing in every run)
 * @param tests_number Number of cross-validation runs to perform
 * @param results_file_path Base name of the per-run output files; an
 *        empty string means results go to STDOUT only
 */
LIB4NEURO_API void CrossValidator::run_k_fold_test(unsigned int k, unsigned int tests_number, std::string results_file_path) {
    NeuralNetwork *net = this->ef->get_network_instance();
    double cv_err_sum = 0;

    for(unsigned int i = 0; i < tests_number; i++) {
        COUT_INFO("Cross-validation run " << i+1 << std::endl);

        /* Hold out a 1/k part of the data for testing */
        this->ef->divide_data_train_test(1.0/k);
        COUT_DEBUG("Number of train data points: " << this->ef->get_dataset()->get_n_elements() << std::endl);
        COUT_DEBUG("Number of test data points: " << this->ef->get_test_dataset()->get_n_elements() << std::endl);

        /* Re-initialize the network parameters before every run */
        net->randomize_parameters();
        net->scale_parameters( 1.0 / (net->get_n_weights() + net->get_n_biases()));
        net->print_weights();

        this->optimizer->optimize(*this->ef);

        /* Error evaluation and writing */
        double err;
        if(results_file_path.empty()) {
            err = this->ef->eval_on_test_data();
        } else {
            err = this->ef->eval_on_test_data(results_file_path + "_cv" + std::to_string(i) + ".dat");
        }
        cv_err_sum += err;
        COUT_INFO("CV error (run " << i+1 << "): " << err << std::endl << std::endl);

        /* Restore the full data set so the next run samples a fresh test part */
        this->ef->return_full_data_set_for_training();
    }

    COUT_INFO("CV error mean: " << cv_err_sum/tests_number << std::endl);
}
}
......@@ -36,12 +36,22 @@ namespace lib4neuro {
*/
LIB4NEURO_API CrossValidator(ILearningMethods* optimizer, ErrorFunction* ef);
/**
 * Runs a k-fold cross-validation test, optionally writing the per-run
 * test-data evaluation into files derived from the given base path.
 *
 * @param k Number of folds the data set is divided into
 * @param tests_number Number of cross-validation runs to perform
 * @param results_file_path Base name of the per-run output files; an
 *        empty string (the default) means results go to STDOUT only
 */
LIB4NEURO_API void
run_k_fold_test(unsigned int k, unsigned int tests_number, std::string results_file_path = "");

/**
 * Runs a k-fold cross-validation test and writes the results into an
 * already-open output file stream.
 *
 * @param k Number of folds the data set is divided into
 * @param tests_number Number of cross-validation runs to perform
 * @param results_file_path Open output stream the results are written to
 */
LIB4NEURO_API void run_k_fold_test(unsigned int k, unsigned int tests_number, std::ofstream* results_file_path);
};
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment