Skip to content
Snippets Groups Projects
CrossValidator.cpp 3.09 KiB
Newer Older
  • Learn to ignore specific revisions
  • //
    // Created by martin on 14.11.18.
    //
    
    #include "CrossValidator.h"

    #include <stdexcept>
    
    
    namespace lib4neuro {
        /**
         * Binds the validator to the optimizer and error function used by every
         * subsequent cross-validation run.
         *
         * Only the pointer values are stored; neither object is copied here.
         *
         * @param optimizer Learning method invoked once per cross-validation run.
         * @param ef        Error function whose data is split and evaluated.
         */
        LIB4NEURO_API CrossValidator::CrossValidator(ILearningMethods* optimizer, ErrorFunction* ef)
            : optimizer(optimizer),
              ef(ef) {
        }
    
    
        /**
         * Runs @p tests_number independent train/test cross-validation rounds and
         * writes per-round statistics plus the mean test error to an open stream.
         *
         * Each round:
         *   - splits the error function's data with divide_data_train_test(1.0 / k);
         *   - re-randomizes the network parameters and scales them by
         *     1 / (number of weights + number of biases);
         *   - runs the optimizer;
         *   - evaluates the error on the test part;
         *   - calls return_full_data_set_for_training() before the next round.
         *
         * @param k                 Split parameter; 1.0 / k is handed to
         *                          divide_data_train_test (presumably the test-set
         *                          fraction - TODO confirm). Must be non-zero.
         * @param tests_number      Number of rounds. When 0, nothing is run and no
         *                          mean is reported.
         * @param results_file_path Output stream (despite the name, not a path).
         *                          May be nullptr. In that case only console output
         *                          is produced and the test error is evaluated
         *                          without a stream, like the empty-path case of
         *                          the std::string overload.
         * @throws std::invalid_argument if k == 0.
         */
        LIB4NEURO_API void CrossValidator::run_k_fold_test(unsigned int k, unsigned int tests_number, std::ofstream* results_file_path) {
            //TODO do not duplicate code - write in a more elegant way

            // 1.0 / 0 would hand an infinite ratio to divide_data_train_test.
            if (k == 0) {
                throw std::invalid_argument("CrossValidator::run_k_fold_test: k must be greater than zero");
            }

            // With no rounds the mean computed below would be 0.0 / 0 = NaN.
            if (tests_number == 0) {
                COUT_INFO("No cross-validation runs requested (tests_number == 0)" << std::endl);
                return;
            }

            NeuralNetwork* net = this->ef->get_network_instance();

            // The stream is optional; previously a nullptr here was dereferenced (UB).
            const bool write_results = (results_file_path != nullptr);

            double cv_err_sum = 0;

            for (unsigned int i = 0; i < tests_number; i++) {
                COUT_INFO("Cross-validation run " << i + 1 << std::endl);
                if (write_results) {
                    *results_file_path << "Cross-validation run " << i + 1 << std::endl;
                }

                this->ef->divide_data_train_test(1.0 / k);
                if (write_results) {
                    *results_file_path << "Number of train data points: " << this->ef->get_dataset()->get_n_elements() << std::endl;
                    *results_file_path << "Number of test data points: " << this->ef->get_test_dataset()->get_n_elements() << std::endl;
                }

                // Every round starts from fresh random parameters scaled by 1 / (total parameter count).
                net->randomize_parameters();
                net->scale_parameters(1.0 / (net->get_n_weights() + net->get_n_biases()));
                this->optimizer->optimize(*this->ef);

                /* Error evaluation and writing */
                const double err = write_results ? this->ef->eval_on_test_data(results_file_path)
                                                 : this->ef->eval_on_test_data();
                cv_err_sum += err;
                COUT_INFO("CV error (run " << i + 1 << "): " << err << std::endl << std::endl);

                // Per its name, hands the held-out data back for training before the next split.
                this->ef->return_full_data_set_for_training();
            }

            const double cv_err_mean = cv_err_sum / tests_number;
            COUT_INFO("CV error mean: " << cv_err_mean << std::endl);
            if (write_results) {
                *results_file_path << "CV error mean: " << cv_err_mean << std::endl;
            }
        }
    
        LIB4NEURO_API void CrossValidator::run_k_fold_test(unsigned int k, unsigned int tests_number, std::string results_file_path) {
            NeuralNetwork *net = this->ef->get_network_instance();
    
            double cv_err_sum = 0;
    
    
            for(unsigned int i = 0; i < tests_number; i++) {
    
                COUT_INFO("Cross-validation run " << i+1 << std::endl);
    
    
                this->ef->divide_data_train_test(1.0/k);
    
                COUT_DEBUG("Number of train data points: " << this->ef->get_dataset()->get_n_elements() << std::endl);
                COUT_DEBUG("Number of test data points: " << this->ef->get_test_dataset()->get_n_elements() << std::endl);
    
                net->randomize_parameters();
                net->scale_parameters( 1.0 / (net->get_n_weights() + net->get_n_biases()));
    
                this->optimizer->optimize(*this->ef);
    
    
                /* Error evaluation and writing */
                double err;
                if(results_file_path == "") {
                    err = this->ef->eval_on_test_data();
                } else {
                    err = this->ef->eval_on_test_data(results_file_path + "_cv" + std::to_string(i) + ".dat");
                }
                cv_err_sum += err;
                COUT_INFO("CV error (run " << i+1 << "): " << err << std::endl << std::endl);
    
    
                this->ef->return_full_data_set_for_training();
            }
    
    
            COUT_INFO("CV error mean: " << cv_err_sum/tests_number << std::endl);
    
    Martin Beseda's avatar
    Martin Beseda committed
    }