Skip to content
Snippets Groups Projects
LearningSequence.cpp 1.69 KiB
Newer Older
  • Learn to ignore specific revisions
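    /**
     * Implementation of lib4neuro::LearningSequence, which runs a list of learning methods
     * in repeated cycles until the error function reaches a given tolerance.
     *
     * @author Michal Kravčenko
     * @date 19.2.19 -
     */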
    
    #include "LearningSequence.h"
    
    namespace lib4neuro {

        /**
         * Creates an empty learning sequence.
         *
         * @param tolerance     error value at (or below) which optimize() stops
         * @param max_n_cycles  maximal number of passes over the whole sequence; optimize() counts it down
         *                      and stops at zero, so a negative value imposes no cycle limit
         */
        LearningSequence::LearningSequence( double tolerance, int max_n_cycles ){
            this->tol = tolerance;
            this->max_number_of_cycles = max_n_cycles;
            /* Buffer for the lowest-error parameters found by optimize(); allocated here, freed in the destructor. */
            this->best_parameters = new std::vector<double>();
        }

        /**
         * Releases the parameter buffer allocated in the constructor.
         *
         * Learning methods registered through add_learning_method() are stored as raw pointers
         * and are intentionally not deleted here.
         */
        LearningSequence::~LearningSequence() {
            /* NOTE(review): because the class holds a raw owning pointer, the header should delete (or
             * properly implement) the copy constructor and copy assignment; a shallow copy of this object
             * would otherwise delete the same buffer twice. */
            delete this->best_parameters;
        }

        /**
         * @return the parameters reported by the first learning method in the sequence,
         *         or nullptr when no method has been added
         *
         * NOTE(review): this is not necessarily the best parameter set tracked by optimize();
         * confirm this is what callers expect.
         */
        std::vector<double>* LearningSequence::get_parameters() {
            if( !this->learning_sequence.empty() ){
                return this->learning_sequence[0]->get_parameters( );
            }
            return nullptr;
        }

        /**
         * Appends a learning method to the end of the sequence.
         *
         * @param method  method to run; stored as a raw pointer and never deleted by this class,
         *                so it must stay alive for as long as optimize() may be called
         */
        void LearningSequence::add_learning_method(ILearningMethods *method) {
            this->learning_sequence.push_back( method );
        }

        /**
         * Runs the registered learning methods one after another, in cycles, until the error drops to
         * the tolerance or the cycle limit is exhausted.
         *
         * After every method the error is re-evaluated and the lowest-error parameters are remembered.
         * On return the network of `ef` holds the best parameters seen, including the starting point.
         *
         * @param ef   error function being minimized; its network's parameters are overwritten in place
         * @param ofs  output stream pointer passed through unchanged to each method's optimize()
         */
        void LearningSequence::optimize(lib4neuro::ErrorFunction &ef, std::ofstream *ofs) {

            double error = ef.eval();

            /* Nothing to do. The empty-sequence check also prevents an endless loop when
             * max_number_of_cycles is negative (no cycle limit) and no method could ever lower the error. */
            if( error <= this->tol || this->learning_sequence.empty() ){
                return;
            }

            double the_best_error = error;
            /* Seed with the starting state so a valid "best" always exists, even if no method improves on it.
             * NOTE(review): if ErrorFunction::get_parameters() returns a freshly allocated vector, each
             * dereference-and-copy in this function leaks it — confirm its ownership semantics. */
            *this->best_parameters = *ef.get_parameters();

            int mcycles = this->max_number_of_cycles, cycle_idx = 0;

            while( error > this->tol && mcycles != 0 ){
                mcycles--;
                cycle_idx++;

                for( auto m: this->learning_sequence ){
                    m->optimize( ef, ofs );
                    error = ef.eval();

                    if( error < the_best_error ){
                        the_best_error = error;
                        *this->best_parameters = *ef.get_parameters();
                    }

                    if( error <= this->tol ){
                        ef.get_network_instance()->copy_parameter_space( this->best_parameters );
                        return;
                    }
                }
                COUT_DEBUG( "Cycle: " << cycle_idx << ", the lowest error: " << the_best_error << std::endl );
            }

            /* Cycle limit exhausted without reaching the tolerance: restore the best state found
             * instead of leaving whatever the last method happened to produce. */
            ef.get_network_instance()->copy_parameter_space( this->best_parameters );
        }
    }