Skip to content
Snippets Groups Projects
LearningSequence.cpp 1.81 KiB
Newer Older
  • Learn to ignore specific revisions
  • /**
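     * Implementation of lib4neuro::LearningSequence, which holds an ordered list of
     * learning methods and iterates over them in cycles during optimization.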
     *
     * @author Michal Kravčenko
     * @date 19.2.19 -
     */
    
    #include "LearningSequence.h"
    
    #include <utility>
    
    #include "../message.h"
    
    
    namespace lib4neuro {
    
        /**
         * Constructs a learning sequence and records its two stopping settings.
         *
         * @param tolerance value stored in `tol`; optimize() keeps cycling only
         *                  while the current error is greater than this value
         * @param max_n_cycles value stored in `max_number_of_cycles`; used by
         *                     optimize() as the countdown of remaining cycles
         */
        LearningSequence::LearningSequence(double tolerance, int max_n_cycles) {
            max_number_of_cycles = max_n_cycles;
            tol = tolerance;
        }
    
    
        // Compiler-generated destructor, defined out of line in this translation unit.
        LearningSequence::~LearningSequence() = default;
    
        /**
         * Appends a learning method to the end of the sequence.
         *
         * @param method shared handle to the learning method. It is received by
         *               value and then moved into the container, which avoids the
         *               extra atomic reference-count increment/decrement that
         *               copying the handle a second time would cost.
         */
        void LearningSequence::add_learning_method(std::shared_ptr<LearningMethod> method) {
            this->learning_sequence.push_back(std::move(method));
        }
    
        void LearningSequence::optimize(lib4neuro::ErrorFunction &ef, std::ofstream *ofs) {
            double error = ef.eval();
    
            this->optimal_parameters = ef.get_parameters();
    
            double the_best_error = error;
            int mcycles = this->max_number_of_cycles, cycle_idx = 0;
    
    
            std::vector<double> params;
    
            while( error > this->tol && mcycles != 0){
                mcycles--;
                cycle_idx++;
    
    
                for(auto m: this->learning_sequence ){
    
                    //TODO do NOT copy vectors if not needed
                    params = *m->get_parameters();
    
                    error = ef.eval(&params);
    
                    ef.get_network_instance()->copy_parameter_space(&params);
    
                    if( error < the_best_error ){
                        the_best_error = error;
    
                        this->optimal_parameters = ef.get_parameters();
    
                        ef.get_network_instance()->copy_parameter_space( &this->optimal_parameters );
    
                COUT_DEBUG("Cycle: " << cycle_idx << ", the lowest error: " << the_best_error << std::endl );
    
            ef.get_network_instance()->copy_parameter_space( &this->optimal_parameters );