From 52ffe1d138257c6bf724e6e5b7dd0d160d0889b6 Mon Sep 17 00:00:00 2001 From: Martin Beseda <martinbeseda@seznam.cz> Date: Mon, 16 Jul 2018 01:05:52 +0200 Subject: [PATCH] ENH: Added automatic weight initialization --- src/Network/NeuralNetwork.cpp | 18 ++++++++++++++++++ src/Network/NeuralNetwork.h | 15 +++++++++++++++ src/net_test_1.cpp | 4 ++++ 3 files changed, 37 insertions(+) create mode 100644 src/net_test_1.cpp diff --git a/src/Network/NeuralNetwork.cpp b/src/Network/NeuralNetwork.cpp index 402584a3..dd1ec510 100644 --- a/src/Network/NeuralNetwork.cpp +++ b/src/Network/NeuralNetwork.cpp @@ -5,6 +5,8 @@ * @date 13.6.18 - */ +#include <boost/random/mersenne_twister.hpp> +#include <boost/random/uniform_real_distribution.hpp> #include "NeuralNetwork.h" #include "../NetConnection/ConnectionWeightIdentity.h" @@ -45,7 +47,23 @@ int NeuralNetwork::add_neuron(Neuron *n) { return (int)this->neurons->size() - 1; } +void NeuralNetwork::add_connection_simple(int n1_idx, int n2_idx) { + add_connection_simple(n1_idx, n2_idx, -1); +} + +void NeuralNetwork::add_connection_simple(int n1_idx, int n2_idx, int weight_idx) { + boost::random::mt19937 gen; + + // Init weight guess ("optimal" for logistic activation functions) + double r = 4 * sqrt(6./(this->n_inputs + this->n_outputs)); + + boost::random::uniform_real_distribution<> dist(-r, r); + + add_connection_simple(n1_idx, n2_idx, weight_idx, dist(gen)); +} + void NeuralNetwork::add_connection_simple(int n1_idx, int n2_idx, int weight_idx, double weight_value) { + // TODO generate weight_value automatically from normal distribution if(weight_idx < 0 || weight_idx >= this->connection_weights->size()){ //this weight is a new one, we add it to the system of weights diff --git a/src/Network/NeuralNetwork.h b/src/Network/NeuralNetwork.h index 11019b78..02f7e46a 100644 --- a/src/Network/NeuralNetwork.h +++ b/src/Network/NeuralNetwork.h @@ -99,6 +99,21 @@ public: */ int add_neuron(Neuron* n); + /** + * + * @param n1_idx + * @param n2_idx + */ + void add_connection_simple(int n1_idx, int n2_idx); + + /** + * + * @param n1_idx + * @param n2_idx + * @param weight_idx + */ + void add_connection_simple(int n1_idx, int n2_idx, int weight_idx); + /** * * @param[in] n1_idx diff --git a/src/net_test_1.cpp b/src/net_test_1.cpp new file mode 100644 index 00000000..e049e775 --- /dev/null +++ b/src/net_test_1.cpp @@ -0,0 +1,4 @@ +// +// Created by martin on 7/16/18. +// + -- GitLab