Skip to content
Snippets Groups Projects
LearningSequence.cpp 2.2 KiB
Newer Older
  • Learn to ignore specific revisions
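    /**
     * Implementation of lib4neuro::LearningSequence, which holds an ordered list
     * of learning methods and applies them in repeated cycles until the error
     * drops to the tolerance or the cycle limit is exhausted.
     *
     * @author Michal Kravčenko
     * @date 19.2.19 -
     */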
    
    #include "LearningSequence.h"
    
    #include <utility>
    
    #include "../message.h"
    
    
    namespace lib4neuro {
    
        /**
         * Creates an empty learning sequence.
         *
         * @param tolerance    Error threshold; optimize() keeps cycling while the
         *                     current error is above this value.
         * @param max_n_cycles Cycle budget for optimize(), which decrements it until
         *                     it reaches 0. A negative value never reaches 0, so the
         *                     loop is then bounded only by the tolerance.
         */
        LearningSequence::LearningSequence( double tolerance, int max_n_cycles ){
            this->tol = tolerance;
            this->max_number_of_cycles = max_n_cycles;
            // best_parameters is held by value (optimize() assigns to it and takes
            // its address), so no heap allocation is needed here.
        }
    
        /**
         * Defaulted destructor: no manual cleanup is performed here; all members
         * are destroyed by the compiler-generated destructor.
         */
        LearningSequence::~LearningSequence() = default;
    
    
        /**
         * Exposes the parameter vector of the first learning method in the sequence.
         *
         * @return The pointer returned by the first method's get_parameters(), or
         *         nullptr when no learning method has been added yet.
         */
        std::vector<double>* LearningSequence::get_parameters() {
            // Guard clause: an empty sequence has no parameters to report.
            if( this->learning_sequence.empty() ){
                return nullptr;
            }

            const auto& first_method = this->learning_sequence.front();
            return first_method->get_parameters( );
        }
    
    
        /**
         * Appends a learning method to the end of the sequence. optimize() iterates
         * the sequence front to back, so methods run in the order they were added.
         *
         * @param method Shared handle to the method; the sequence becomes one of its owners.
         */
        void LearningSequence::add_learning_method(std::shared_ptr<LearningMethod> method) {
            // The caller's handle was already copied into this by-value parameter;
            // moving it into the container avoids a second (atomic) refcount
            // increment/decrement pair.
            this->learning_sequence.push_back( std::move( method ) );
        }
    
        void LearningSequence::optimize(lib4neuro::ErrorFunction &ef, std::ofstream *ofs) {
    
    
            puts("*********************** 11");
    
    
    
            puts("*********************** 12");
    
    
            double the_best_error = error;
            int mcycles = this->max_number_of_cycles, cycle_idx = 0;
    
            while( error > this->tol && mcycles != 0){
                mcycles--;
                cycle_idx++;
    
    
                puts("*********************** 7");
    
    
                for( auto m: this->learning_sequence ){
    
    
                    puts("*********************** 8");
    
    
    
                    puts("*********************** 9");
    
    
                    ef.get_network_instance()->copy_parameter_space(m->get_parameters());
    
                    puts("*********************** 10");
    
    
                    if( error < the_best_error ){
                        the_best_error = error;
    
                        this->best_parameters = *ef.get_parameters();
    
                        ef.get_network_instance()->copy_parameter_space( &this->best_parameters );
    
                COUT_DEBUG("Cycle: " << cycle_idx << ", the lowest error: " << the_best_error << std::endl );
    
            ef.get_network_instance()->copy_parameter_space( &this->best_parameters );