Skip to content
Snippets Groups Projects
CrossValidator.cpp 945 B
Newer Older
  • Learn to ignore specific revisions
  • //
    // Created by martin on 14.11.18.
    //
    
    #include "CrossValidator.h"

    #include <iostream>
    #include <stdexcept>

    #include "../message.h"
    
    namespace lib4neuro {
        /**
         * Builds a cross-validator from a learning method and an error function.
         *
         * Both pointers are stored exactly as passed; this constructor neither
         * copies the pointed-to objects nor checks the pointers for null.
         *
         * @param optimizer learning method stored for later use by run_k_fold_test()
         * @param ef        error function stored for later use by run_k_fold_test()
         */
        LIB4NEURO_API CrossValidator::CrossValidator(ILearningMethods* optimizer, ErrorFunction* ef) {
            // The two stores are independent of each other; order is irrelevant.
            this->ef        = ef;
            this->optimizer = optimizer;
        }
    
        /**
         * Runs a series of train/test experiments on the stored error function.
         *
         * Every run performs, in this order:
         *   1. divide_data_train_test(1.0 / k) on the error function
         *      (presumably 1/k is the share of data held out for testing --
         *      confirm against ErrorFunction),
         *   2. prints the network weights, randomizes them, prints them again,
         *   3. runs the stored optimizer on the error function,
         *   4. calls return_full_data_set_for_training() on the error function.
         *
         * @param k            divisor of the data split; must be greater than zero
         * @param tests_number number of runs to perform; with 0 no run is executed
         * @throws std::invalid_argument if k is 0, because 1.0 / k would then be
         *         an infinite fraction
         */
        LIB4NEURO_API void CrossValidator::run_k_fold_test(unsigned int k, unsigned int tests_number) {
            // Reject k == 0 up front: 1.0 / 0 evaluates to infinity, which is not
            // a meaningful split fraction.
            if (k == 0) {
                throw std::invalid_argument("CrossValidator::run_k_fold_test: k must be greater than zero");
            }

            // The fraction does not change between runs, so compute it once.
            const double split_fraction = 1.0 / k;

            for (unsigned int i = 0; i < tests_number; i++) {
                std::cout << "Cross-validation run " << i + 1 << std::endl;

                this->ef->divide_data_train_test(split_fraction);

                // Weights are printed both before and after randomization so the
                // two states can be compared in the output.
                this->ef->get_network_instance()->print_weights();
                this->ef->get_network_instance()->randomize_weights();
    //            this->ef->get_network_instance()->print_stats();
                this->ef->get_network_instance()->print_weights();

                this->optimizer->optimize(*this->ef);

                // Undo the split before the next run (or before returning).
                this->ef->return_full_data_set_for_training();
            }
        }
    }