Skip to content
Snippets Groups Projects
Commit 52ffe1d1 authored by Martin Beseda's avatar Martin Beseda
Browse files

ENH: Added automatic weight initialization

parent c5f4959c
No related branches found
No related tags found
No related merge requests found
...@@ -5,6 +5,8 @@ ...@@ -5,6 +5,8 @@
* @date 13.6.18 - * @date 13.6.18 -
*/ */
#include <boost/random/mersenne_twister.hpp>
#include <boost/random/uniform_real_distribution.hpp>
#include "NeuralNetwork.h" #include "NeuralNetwork.h"
#include "../NetConnection/ConnectionWeightIdentity.h" #include "../NetConnection/ConnectionWeightIdentity.h"
...@@ -45,7 +47,23 @@ int NeuralNetwork::add_neuron(Neuron *n) { ...@@ -45,7 +47,23 @@ int NeuralNetwork::add_neuron(Neuron *n) {
return (int)this->neurons->size() - 1; return (int)this->neurons->size() - 1;
} }
void NeuralNetwork::add_connection_simple(int n1_idx, int n2_idx) {
add_connection_simple(n1_idx, n2_idx, -1);
}
void NeuralNetwork::add_connection_simple(int n1_idx, int n2_idx, int weight_idx) {
boost::random::mt19937 gen;
// Init weight guess ("optimal" for logistic activation functions)
double r = 4 * sqrt(6./(this->n_inputs + this->n_outputs));
boost::random::uniform_real_distribution<> dist(-r, r);
add_connection_simple(n1_idx, n2_idx, weight_idx, dist(gen));
}
void NeuralNetwork::add_connection_simple(int n1_idx, int n2_idx, int weight_idx, double weight_value) { void NeuralNetwork::add_connection_simple(int n1_idx, int n2_idx, int weight_idx, double weight_value) {
// TODO generate weight_value automatically from normal distribution
if(weight_idx < 0 || weight_idx >= this->connection_weights->size()){ if(weight_idx < 0 || weight_idx >= this->connection_weights->size()){
//this weight is a new one, we add it to the system of weights //this weight is a new one, we add it to the system of weights
......
...@@ -99,6 +99,21 @@ public: ...@@ -99,6 +99,21 @@ public:
*/ */
int add_neuron(Neuron* n); int add_neuron(Neuron* n);
/**
*
* @param n1_idx
* @param n2_idx
*/
void add_connection_simple(int n1_idx, int n2_idx);
/**
*
* @param n1_idx
* @param n2_idx
* @param weight_idx
*/
void add_connection_simple(int n1_idx, int n2_idx, int weight_idx);
/** /**
* *
* @param[in] n1_idx * @param[in] n1_idx
......
//
// Created by martin on 7/16/18.
//
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment