Skip to content
Snippets Groups Projects
LearningSequence.cpp 1.90 KiB
/**
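 * Implementation of lib4neuro::LearningSequence: runs a list of learning methods one after
 * another, in repeated cycles, until the error function reaches the tolerance or the cycle
 * limit is used up.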
 *
 * @author Michal Kravčenko
 * @date 19.2.19 -
 */

#include "LearningSequence.h"

#include <utility>

#include "../message.h"

namespace lib4neuro {

    /**
     * Creates an empty learning sequence.
     *
     * @param tolerance    optimize() stops as soon as the error function evaluates to a value
     *                     less than or equal to this tolerance
     * @param max_n_cycles maximal number of passes over the whole sequence; optimize() counts it
     *                     down to zero, so a negative value never reaches zero and the passes are
     *                     then bounded only by the tolerance
     */
    LearningSequence::LearningSequence(double tolerance,
                                       int max_n_cycles) {
        this->tol = tolerance;
        this->max_number_of_cycles = max_n_cycles;
    }

    LearningSequence::~LearningSequence() = default;

    /**
     * Appends a learning method to the end of the sequence.
     *
     * @param method method to run after all previously added ones (shared ownership is taken)
     */
    void LearningSequence::add_learning_method(std::shared_ptr<LearningMethod> method) {
        // `method` is already our own copy; move it instead of bumping the refcount again
        this->learning_sequence.push_back(std::move(method));
    }

    /**
     * Runs every method of the sequence in order, cycle after cycle, until the error drops to
     * the tolerance or the cycle limit is reached.
     *
     * Each method starts from the parameters the previous one left in the network. Whenever an
     * evaluation yields a new lowest error, ef.get_parameters() is stored in optimal_parameters;
     * those stored parameters are copied back into the network before returning.
     *
     * @param ef  error function to minimize; the parameters of its network instance are modified
     * @param ofs stream forwarded unchanged to each method's optimize()
     */
    void LearningSequence::optimize(lib4neuro::ErrorFunction& ef,
                                    std::ofstream* ofs) {
        double error = ef.eval();
        this->optimal_parameters = ef.get_parameters();
        double the_best_error = error;

        // With no methods nothing can lower the error: the loop below would only spin
        // (forever when max_number_of_cycles is negative, since it never counts down to 0).
        if (this->learning_sequence.empty()) {
            ef.get_network_instance()->copy_parameter_space(&this->optimal_parameters);
            return;
        }

        int mcycles = this->max_number_of_cycles;
        int cycle_idx = 0;

        // hoisted out of the loop so the copy-assignment below can reuse its capacity
        std::vector<double> params;
        while (error > this->tol && mcycles != 0) {
            mcycles--;
            cycle_idx++;

            // iterate by reference: copying each shared_ptr costs an atomic refcount update
            for (const auto& m : this->learning_sequence) {
                m->optimize(ef,
                            ofs);

                //TODO do NOT copy vectors if not needed
                params = *m->get_parameters();
                error = ef.eval(&params);

                // the next method in the sequence continues from this method's result
                ef.get_network_instance()->copy_parameter_space(&params);

                if (error < the_best_error) {
                    the_best_error = error;
                    this->optimal_parameters = ef.get_parameters();
                }

                // tolerance reached: restore the best parameters seen and stop immediately
                if (error <= this->tol) {
                    ef.get_network_instance()->copy_parameter_space(&this->optimal_parameters);
                    return;
                }
            }
            COUT_DEBUG("Cycle: " << cycle_idx << ", the lowest error: " << the_best_error << std::endl);
        }
        ef.get_network_instance()->copy_parameter_space(&this->optimal_parameters);
    }
}