Skip to content
Snippets Groups Projects
ErrorFunctions.cpp 17.50 KiB
//
// Created by martin on 7/15/18.
//

#include <cmath>
#include <sstream>
#include <utility>
#include <vector>
#include <boost/random/mersenne_twister.hpp>
#include <boost/random/uniform_int_distribution.hpp>

#include "ErrorFunctions.h"
#include "exceptions.h"
#include "message.h"

namespace lib4neuro {

    size_t ErrorFunction::get_dimension() {
        return this->dimension;
    }

    /**
     * @brief Accessor for the `net` member of this error function.
     * @return The stored NeuralNetwork pointer, returned as-is (ownership is not
     *         transferred by this call — NOTE(review): confirm who owns `net`).
     */
    NeuralNetwork* ErrorFunction::get_network_instance() {
        NeuralNetwork* const network = net;
        return network;
    }

    /**
     * @brief Randomly splits the current data set into a training part and a test part.
     *
     * The current data set pointer is remembered in `ds_full`. `ds` is then replaced by
     * a new DataSet holding the training samples, and `ds_test` by a new DataSet holding
     * the test samples. Both new data sets reuse the normalization strategy of `ds_full`.
     *
     * The random generator is default-constructed (fixed default seed), so for a data
     * set of a given size the same split is produced on every call.
     *
     * @param percent_test Fraction of samples moved to the test set; the test set has
     *                     ceil(n * percent_test) samples. Values below 0 (or NaN) are
     *                     treated as 0 and values above 1 as 1.
     */
    void ErrorFunction::divide_data_train_test(double percent_test) {
        size_t ds_size = this->ds->get_n_elements();

        /* Store the full data set */
        this->ds_full = this->ds;

        /* Clamp the fraction so the test set is never negative or larger than the data
         * set; the negated comparison also maps NaN to 0. */
        if (!(percent_test > 0.0)) {
            percent_test = 0.0;
        } else if (percent_test > 1.0) {
            percent_test = 1.0;
        }
        size_t test_set_size = static_cast<size_t>(std::ceil(ds_size * percent_test));
        if (test_set_size > ds_size) {
            test_set_size = ds_size;
        }

        /* Choose test_set_size DISTINCT sample indices with a partial Fisher-Yates
         * shuffle. Drawing indices independently (with replacement) could pick the same
         * sample twice, which made the subsequent erase() remove the wrong element or
         * throw std::out_of_range. When test_set_size == 0 the loop never runs, so
         * `ds_size - 1` is never evaluated for an empty data set. */
        std::vector<size_t> indices(ds_size);
        for (size_t i = 0; i < ds_size; ++i) {
            indices[i] = i;
        }
        std::vector<char> is_test(ds_size, 0);
        boost::random::mt19937 gen;
        for (size_t i = 0; i < test_set_size; ++i) {
            boost::random::uniform_int_distribution<size_t> dist(i, ds_size - 1);
            std::swap(indices[i], indices[dist(gen)]);
            is_test[indices[i]] = 1;
        }

        std::vector<std::pair<std::vector<double>, std::vector<double>>> test_data, train_data;
        test_data.reserve(test_set_size);
        train_data.reserve(ds_size - test_set_size);

        /* Distribute the samples in one pass, iterating by const reference so each
         * sample is copied exactly once (into its destination). This replaces the
         * previous copy-everything-then-erase approach, which was O(n) per erase. */
        size_t idx = 0;
        for (const auto& e : *this->ds_full->get_data()) {
            if (idx < is_test.size() && is_test[idx]) {
                test_data.emplace_back(e);
            } else {
                train_data.emplace_back(e);
            }
            ++idx;
        }

        /* NOTE(review): `ds` and `ds_test` are overwritten here without releasing any
         * previous allocation, and calling this twice would store the training subset
         * in `ds_full`; confirm ownership against return_full_data_set_for_training(). */

        /* Re-initialize data set for training */
        this->ds = new DataSet(&train_data,
                               this->ds_full->get_normalization_strategy());

        /* Initialize test data */
        this->ds_test = new DataSet(&test_data,
                                    this->ds_full->get_normalization_strategy());
    }

    void ErrorFunction::return_full_data_set_for_training() {
        if (this->ds_test) {