From 52ffe1d138257c6bf724e6e5b7dd0d160d0889b6 Mon Sep 17 00:00:00 2001
From: Martin Beseda <martinbeseda@seznam.cz>
Date: Mon, 16 Jul 2018 01:05:52 +0200
Subject: [PATCH] ENH: Added automatic weight initialization

---
 src/Network/NeuralNetwork.cpp | 18 ++++++++++++++++++
 src/Network/NeuralNetwork.h   | 15 +++++++++++++++
 src/net_test_1.cpp            |  4 ++++
 3 files changed, 37 insertions(+)
 create mode 100644 src/net_test_1.cpp

diff --git a/src/Network/NeuralNetwork.cpp b/src/Network/NeuralNetwork.cpp
index 402584a3..dd1ec510 100644
--- a/src/Network/NeuralNetwork.cpp
+++ b/src/Network/NeuralNetwork.cpp
@@ -5,6 +5,8 @@
  * @date 13.6.18 - 
  */
 
+#include <boost/random/mersenne_twister.hpp>
+#include <boost/random/uniform_real_distribution.hpp>
 #include "NeuralNetwork.h"
 #include "../NetConnection/ConnectionWeightIdentity.h"
 
@@ -45,7 +47,23 @@ int NeuralNetwork::add_neuron(Neuron *n) {
     return (int)this->neurons->size() - 1;
 }
 
+void NeuralNetwork::add_connection_simple(int n1_idx, int n2_idx) {
+    add_connection_simple(n1_idx, n2_idx, -1);
+}
+
+void NeuralNetwork::add_connection_simple(int n1_idx, int n2_idx, int weight_idx) {
+    boost::random::mt19937 gen;
+
+    // Init weight guess ("optimal" for logistic activation functions)
+    double r = 4 * sqrt(6./(this->n_inputs + this->n_outputs));
+
+    boost::random::uniform_real_distribution<> dist(-r, r);
+
+    add_connection_simple(n1_idx, n2_idx, weight_idx, dist(gen));
+}
+
 void NeuralNetwork::add_connection_simple(int n1_idx, int n2_idx, int weight_idx, double weight_value) {
+    // TODO generate weight_value automatically from normal distribution
 
     if(weight_idx < 0 || weight_idx >= this->connection_weights->size()){
         //this weight is a new one, we add it to the system of weights
diff --git a/src/Network/NeuralNetwork.h b/src/Network/NeuralNetwork.h
index 11019b78..02f7e46a 100644
--- a/src/Network/NeuralNetwork.h
+++ b/src/Network/NeuralNetwork.h
@@ -99,6 +99,21 @@ public:
      */
     int add_neuron(Neuron* n);
 
+    /**
+     *
+     * @param n1_idx
+     * @param n2_idx
+     */
+    void add_connection_simple(int n1_idx, int n2_idx);
+
+    /**
+     *
+     * @param n1_idx
+     * @param n2_idx
+     * @param weight_idx
+     */
+    void add_connection_simple(int n1_idx, int n2_idx, int weight_idx);
+
     /**
      *
      * @param[in] n1_idx
diff --git a/src/net_test_1.cpp b/src/net_test_1.cpp
new file mode 100644
index 00000000..e049e775
--- /dev/null
+++ b/src/net_test_1.cpp
@@ -0,0 +1,4 @@
+//
+// Created by martin on 7/16/18.
+//
+
-- 
GitLab