Skip to content
Snippets Groups Projects
network_serialization.cpp 2.46 KiB
Newer Older
  • Learn to ignore specific revisions
  • /**
     * Example of saving a neural network to a file and loading it back.
     * The network creation and training code is copied from net_test_1.
     *
     * @author Martin Beseda
     * @date 9.8.18
     */
    
    #include <iostream>
    #include <utility>
    #include <vector>
    #include "4neuro.h"
    
    /**
     * Builds a small linear network (two inputs, one biased output), runs a
     * particle swarm optimization on it, writes the network to
     * "saved_network.4nt" and loads that file into a second network.
     * Statistics of both networks are printed so they can be compared.
     *
     * @return 0 on completion
     */
    int main() {
        /* TRAIN DATA DEFINITION */
        // Each sample is an (input vector, expected output vector) pair.
        std::vector<std::pair<std::vector<double>, std::vector<double>>> data_vec;
        std::vector<double> inp, out;

        inp = {0, 1};
        out = {0.5};
        data_vec.emplace_back(inp, out);  // pair is constructed in place

        inp = {1, 0.5};
        out = {0.75};
        data_vec.emplace_back(inp, out);

        // The data set only receives a pointer, so data_vec must stay alive
        // for as long as ds is used.
        DataSet ds(&data_vec);

        /* NETWORK DEFINITION */
        NeuralNetwork net;

        /* Input neurons */
        // NOTE(review): the neurons are handed to the network as raw pointers and
        // never deleted here; presumably the network takes ownership - confirm.
        NeuronLinear *i1 = new NeuronLinear();  // f(x) = x
        NeuronLinear *i2 = new NeuronLinear();  // f(x) = x

        /* Output neuron */
        NeuronLinear *o1 = new NeuronLinear();  // f(x) = x + 1 once the bias is set

        /* Adding neurons to the net */
        size_t idx1 = net.add_neuron(i1, BIAS_TYPE::NO_BIAS);
        size_t idx2 = net.add_neuron(i2, BIAS_TYPE::NO_BIAS);
        size_t idx3 = net.add_neuron(o1, BIAS_TYPE::NEXT_BIAS);

        /* Bias of the output neuron */
        // Only the output neuron was added with a bias, so only the first bias
        // entry is set; at() throws if the bias vector is unexpectedly empty.
        const double output_bias = 1.0;
        net.get_parameter_ptr_biases()->at(0) = output_bias;

        /* Adding connections */
        net.add_connection_simple(idx1, idx3, SIMPLE_CONNECTION_TYPE::NEXT_WEIGHT);
        net.add_connection_simple(idx2, idx3, SIMPLE_CONNECTION_TYPE::NEXT_WEIGHT);

        //net.randomize_weights();

        /* specification of the input/output neurons */
        std::vector<size_t> net_input_neurons_indices = {idx1, idx2};
        std::vector<size_t> net_output_neurons_indices = {idx3};

        net.specify_input_neurons(net_input_neurons_indices);
        net.specify_output_neurons(net_output_neurons_indices);

        /* ERROR FUNCTION SPECIFICATION */
        MSE mse(&net, &ds);

        /* TRAINING METHOD SETUP */
        unsigned int max_iters = 20;

        // NOTE(review): four values look like (lower, upper) bounds for two
        // parameters, while this network appears to have two connection weights
        // plus one bias - confirm how many bounds ParticleSwarm reads.
        double domain_bounds[4] = {-800.0, 800.0, -800.0, 800.0};

        // Presumably the cognitive coefficient, social coefficient and inertia
        // weight of the swarm - confirm against the ParticleSwarm constructor.
        double c1 = 0.5, c2 = 1.5, w = 0.8;

        unsigned int n_particles = 10;

        ParticleSwarm swarm_01(&mse, domain_bounds, c1, c2, w, n_particles, max_iters);

        // NOTE(review): the meaning of the two arguments is defined by
        // ParticleSwarm::optimize, which is not visible in this file.
        swarm_01.optimize(0.5, 0.02);

        /* SAVE NETWORK TO THE FILE */
        std::cout << "Network 1" << std::endl;
        net.print_stats();
        net.save_text("saved_network.4nt");

        /* LOAD THE NETWORK BACK FROM THE FILE */
        std::cout << "Network 2" << std::endl;
        NeuralNetwork net2("saved_network.4nt");
        net2.print_stats();

        return 0;
    }