//
// Created by martin on 14.11.18.
//
#include "CrossValidator.h"

#include <stdexcept>

#include "../message.h"
namespace lib4neuro {

    /**
     * @brief Creates a cross-validator bound to an optimizer and an error function.
     *
     * Both pointers are stored as-is. NOTE(review): presumably non-owning;
     * confirm against the class destructor in CrossValidator.h.
     *
     * @param optimizer learning method used to train the network in each run
     * @param ef        error function whose data set is split and whose
     *                  network is re-initialized and trained in each run
     */
    LIB4NEURO_API CrossValidator::CrossValidator(ILearningMethods* optimizer, ErrorFunction* ef)
        : optimizer(optimizer), ef(ef) {
    }

    /**
     * @brief Repeats a train/test split + training cycle @p tests_number times.
     *
     * Each run does the following:
     *  1. Splits the error function's data with divide_data_train_test(1.0 / k).
     *     Presumably the argument is the test fraction; confirm in ErrorFunction.
     *  2. Randomizes the network parameters, then scales them by
     *     1 / (number of weights + number of biases).
     *  3. Prints the weights and runs the optimizer on the error function.
     *  4. Restores the full data set for training.
     *
     * The full data set is restored even if a step after the split throws.
     * The exception is then re-thrown.
     *
     * This routine does not itself compute or return a test error.
     *
     * @param k            number of folds; the split fraction is 1/k
     * @param tests_number number of split/train runs to perform
     * @throws std::invalid_argument if @p k is zero (1/k would divide by zero)
     */
    LIB4NEURO_API void CrossValidator::run_k_fold_test(unsigned int k, unsigned int tests_number) {
        if (k == 0) {
            throw std::invalid_argument("CrossValidator::run_k_fold_test: k must be positive");
        }

        NeuralNetwork* net = this->ef->get_network_instance();

        // Loop-invariant: the same split fraction is used for every run.
        const double split_fraction = 1.0 / k;

        for (unsigned int i = 0; i < tests_number; ++i) {
            COUT_DEBUG(<< "Cross-validation run " << i + 1 << std::endl);

            this->ef->divide_data_train_test(split_fraction);
            COUT_DEBUG(<< "number of train data points: " << this->ef->get_dataset()->get_n_elements() << std::endl);

            try {
                net->randomize_parameters();
                net->scale_parameters(1.0 / (net->get_n_weights() + net->get_n_biases()));
                net->print_weights();
                this->optimizer->optimize(*this->ef);
            } catch (...) {
                // Leave the error function with its full data set, not just the
                // training split, before propagating the failure.
                this->ef->return_full_data_set_for_training();
                throw;
            }

            this->ef->return_full_data_set_for_training();
        }
    }

}