Skip to content
Snippets Groups Projects
main.cpp 2.16 KiB
Newer Older
  • Learn to ignore specific revisions
  • Michal Kravcenko's avatar
    Michal Kravcenko committed
    /**
     * This file serves for testing various examples. Have fun!
     *
     * @author Michal Kravčenko
     * @date 14.6.18 -
     */
    
    
    #include <iostream>
    #include <cstdio>
    #include <fstream>
    
    #include <vector>
    #include <utility>
    
    #include <algorithm>
    
    #include "../CrossValidator/CrossValidator.h"
    
        l4n::CSVReader reader("/home/martin/Desktop/ANN_DATA_1_SET.txt", "\t", true);
        reader.read();
    
    
        /* Create data set for both the training and testing of the neural network */
    
        std::vector<unsigned int> inputs = {2,3,4,5,6,7,8,26,27,28};
        std::vector<unsigned int> outputs = {17,18,19,20,21,22,23,24,25};
        l4n::DataSet ds = reader.get_data_set(&inputs, &outputs);
    
    
        /* Normalize data in the set for easier training of the network */
        ds.normalize();
    
    
        /* Neural network construction */
    
        std::vector<unsigned int> neuron_numbers_in_layers = {10,10,10,9};
        l4n::FullyConnectedFFN nn(&neuron_numbers_in_layers, l4n::NEURON_TYPE::LOGISTIC);
    
    
        /* Error function */
        l4n::MSE mse(&nn, &ds);
    
        /* Domain */
        std::vector<double> domain_bounds(2 * (nn.get_n_weights() + nn.get_n_biases()));
    
        /* Training method */
    
    //    for(size_t i = 0; i < domain_bounds.size() / 2; ++i){
    //        domain_bounds[2 * i] = -10;
    //        domain_bounds[2 * i + 1] = 10;
    //    }
    //    l4n::ParticleSwarm ps(&domain_bounds,
    //                          1.711897,
    //                          1.711897,
    //                          0.711897,
    //                          0.5,
    //                          20,
    //                          0.7,
    //                          600,
    //                          1000);
        l4n::GradientDescent gs(1e-3, 1);
    
        nn.randomize_weights();
    
    
        /* Cross - validation */
        l4n::CrossValidator cv(&gs, &mse);
        cv.run_k_fold_test(10, 3);
    
    
        /* Save network to the file */
        nn.save_text("test_net.4n");
    
        /* Check of the saved network */
        std::cout << "The original network info:" << std::endl;
        nn.print_stats();
    
        l4n::NeuralNetwork nn_loaded("test_net.4n");
        std::cout << "The loaded network info:" << std::endl;
        nn_loaded.print_stats();
    
    
    Martin Beseda's avatar
    Martin Beseda committed
    }