Commit 0c9173d5 authored by Michal Kravcenko's avatar Michal Kravcenko
Browse files

-reworked the way weights are stored in NNs

parent a3c47005
......@@ -5,6 +5,7 @@
* @date 14.6.18 -
*/
#include "ConnectionWeight.h"
ConnectionWeight::ConnectionWeight() {
......@@ -12,42 +13,44 @@ ConnectionWeight::ConnectionWeight() {
}
ConnectionWeight::ConnectionWeight(int param_count, std::function<double(double **, int)> *f) {
this->param_ptrs = new double*[param_count];
ConnectionWeight::ConnectionWeight(int param_count, std::vector<double>* w_array, std::function<double(double *, int*, int)> *f) {
this->param_indices = new int[param_count];
this->n_params = param_count;
this->weight_array = w_array;
this->weight_function = f;
}
ConnectionWeight::~ConnectionWeight() {
if(this->param_ptrs){
delete [] this->param_ptrs;
this->param_ptrs = nullptr;
if(this->param_indices){
delete [] this->param_indices;
this->param_indices = nullptr;
}
}
void ConnectionWeight::adjust_weights(double *values) {
for(int i = 0; i < this->n_params; ++i){
(*this->param_ptrs[i]) += values[i];
this->weight_array->at(this->param_indices[i]) += values[i];
}
}
void ConnectionWeight::set_weights(double *values) {
for(int i = 0; i < this->n_params; ++i){
(*this->param_ptrs[i]) = values[i];
this->weight_array->at(this->param_indices[i]) = values[i];
}
}
void ConnectionWeight::SetParamPointer(double **param_ptr) {
void ConnectionWeight::SetParamIndices(int *param_indices) {
for(int i = 0; i < this->n_params; ++i){
this->param_ptrs[i] = param_ptr[i];
this->SetParamIndex(param_indices[i], i);
}
}
void ConnectionWeight::SetParamPointer(double *param_ptr, int idx) {
this->param_ptrs[idx] = param_ptr;
void ConnectionWeight::SetParamIndex(int value, int idx) {
this->param_indices[idx] = value;
}
double ConnectionWeight::eval() {
return (*this->weight_function)(this->param_ptrs, this->n_params);
return (*this->weight_function)(&this->weight_array->at(0),this->param_indices, this->n_params);
}
\ No newline at end of file
......@@ -9,13 +9,19 @@
#define INC_4NEURO_CONNECTIONWEIGHT_H
#include <functional>
#include <vector>
class ConnectionWeight {
protected:
/**
*
*/
double ** param_ptrs = nullptr;
std::vector<double>* weight_array = nullptr;
/**
*
*/
int* param_indices = nullptr;
/**
*
......@@ -25,7 +31,7 @@ protected:
/**
*
*/
std::function<double(double **, int)> *weight_function = nullptr;
std::function<double(double *, int*, int)> *weight_function = nullptr;
public:
......@@ -39,20 +45,20 @@ public:
* @param param_count
* @param f
*/
ConnectionWeight(int param_count, std::function<double(double **, int)> *f);
ConnectionWeight(int param_count, std::vector<double>* w_array, std::function<double(double *, int*, int)> *f);
/**
*
* @param param_ptr
* @param value
* @param idx
*/
void SetParamPointer(double *param_ptr, int idx);
void SetParamIndex(int value, int idx);
/**
*
* @param param_ptr
*/
void SetParamPointer(double **param_ptr);
void SetParamIndices(int* param_ptr);
/**
*
......
......@@ -7,11 +7,16 @@
#include "ConnectionWeightIdentity.h"
ConnectionWeightIdentity::ConnectionWeightIdentity() {
ConnectionWeightIdentity::ConnectionWeightIdentity(std::vector<double>* w_array) {
this->n_params = 1;
this->param_ptrs = new double*[1];
this->weight_array = w_array;
this->param_indices = new int[1];
}
double ConnectionWeightIdentity::eval() {
return (*this->param_ptrs[0]);
double a = this->weight_array->at(this->param_indices[0]);
return a;
}
\ No newline at end of file
......@@ -22,7 +22,7 @@ public:
/**
*
*/
ConnectionWeightIdentity();
ConnectionWeightIdentity(std::vector<double>* w_array);
/**
*
......
......@@ -11,6 +11,9 @@
NeuralNetwork::NeuralNetwork() {
this->neurons = new std::vector<Neuron*>(0);
this->connection_weights = new std::vector<double>(0);
//TODO tady pozor, pri nedostatecne alokaci se pri pridani hrany nad limit reserve prealokuje cele pole a tim padem padaji reference na vahy v ConnectionWeight
this->connection_weights->reserve(0);
}
NeuralNetwork::~NeuralNetwork() {
......@@ -53,8 +56,8 @@ void NeuralNetwork::add_connection_simple(int n1_idx, int n2_idx, int weight_idx
Neuron *neuron_out = this->neurons->at(n1_idx);
Neuron *neuron_in = this->neurons->at(n2_idx);
ConnectionWeightIdentity *con_weight_u1u2 = new ConnectionWeightIdentity();
con_weight_u1u2->SetParamPointer(&this->connection_weights->at(weight_idx), 0);
ConnectionWeightIdentity *con_weight_u1u2 = new ConnectionWeightIdentity(this->connection_weights);
con_weight_u1u2->SetParamIndex(weight_idx, 0);
Connection *u1u2 = new Connection(neuron_out, neuron_in, con_weight_u1u2);
......@@ -62,9 +65,9 @@ void NeuralNetwork::add_connection_simple(int n1_idx, int n2_idx, int weight_idx
neuron_in->add_connection_in(u1u2);
}
void NeuralNetwork::add_connection_general(int n1_idx, int n2_idx, std::function<double(double **, int)> *f, int* weight_indices, double* weight_values, int n_weights) {
void NeuralNetwork::add_connection_general(int n1_idx, int n2_idx, std::function<double(double *, int*, int)> *f, int* weight_indices, double* weight_values, int n_weights) {
ConnectionWeight *con_weight_u1u2 = new ConnectionWeight(n_weights, f);
ConnectionWeight *con_weight_u1u2 = new ConnectionWeight(n_weights, this->connection_weights, f);
//we analyze weights
int weight_idx = 0;
double weight_value = 0.0;
......@@ -78,7 +81,7 @@ void NeuralNetwork::add_connection_general(int n1_idx, int n2_idx, std::function
weight_indices[wi] = (int)this->connection_weights->size() - 1;
}
con_weight_u1u2->SetParamPointer(&this->connection_weights->at(weight_indices[wi]), wi);
con_weight_u1u2->SetParamIndex(weight_indices[wi], wi);
}
Neuron *neuron_out = this->neurons->at(n1_idx);
......
......@@ -59,11 +59,17 @@ private:
*/
bool in_out_determined = false;
/**
*
*/
std::vector<Neuron*>* active_eval_set = nullptr;
/**
*
*/
void determine_inputs_outputs();
/**
*
*/
void determine_inputs_outputs();
public:
......@@ -111,7 +117,7 @@ public:
* @param weight_values
* @param n_weights
*/
void add_connection_general(int n1_idx, int n2_idx, std::function<double(double **, int)> *f, int* weight_indices, double* weight_values, int n_weights);
void add_connection_general(int n1_idx, int n2_idx, std::function<double(double *, int*, int)> *f, int* weight_indices, double* weight_values, int n_weights);
......
......@@ -34,12 +34,11 @@ void test1(){
////////////////////// END SIMPLE EDGE WEIGHT ////////////////////////////////////////
/////////////////////////BEGIN OF COMPLEX EDGE WEIGHT//////////////////////////////
//TODO vyresit memleak
std::function<double(double **, int)> weight_function = [](double ** params, int n_params){
std::function<double(double *, int*, int)> weight_function = [](double * weight_array, int * index_array, int n_params){
//w(x, y) = x + y
double a = (*(params[0]));
double b = (*(params[1]));
printf("eval: %f, %f\n", a, b);
double a = weight_array[index_array[0]];
double b = weight_array[index_array[1]];
// printf("eval: %f, %f\n", a, b);
return (a + 0.0 * b);
};
int weight_indices [2]= {0, -1};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment