Commit b4c6705d authored by Martin Beseda's avatar Martin Beseda

ENH: Implemented serialization in Network classes.

parent 82a94cda
......@@ -5,11 +5,17 @@
* @date 13.6.18 -
*/
#include <iostream>
#include <boost/random/mersenne_twister.hpp>
#include <boost/random/uniform_real_distribution.hpp>
#include "NeuralNetwork.h"
BOOST_CLASS_EXPORT(NeuronBinary);
BOOST_CLASS_EXPORT(NeuronConstant);
BOOST_CLASS_EXPORT(NeuronLinear);
BOOST_CLASS_EXPORT(NeuronLogistic);
BOOST_CLASS_EXPORT(NeuronLogistic_d1);
BOOST_CLASS_EXPORT(NeuronLogistic_d2);
BOOST_CLASS_EXPORT(ConnectionFunctionGeneral);
BOOST_CLASS_EXPORT(ConnectionFunctionIdentity);
NeuralNetwork::NeuralNetwork() {
this->neurons = new std::vector<Neuron*>(0);
this->neuron_biases = new std::vector<double>(0);
......@@ -32,10 +38,17 @@ NeuralNetwork::NeuralNetwork() {
this->layers_analyzed = false;
}
NeuralNetwork::NeuralNetwork(std::string filepath) {
std::ifstream ifs(filepath);
boost::archive::text_iarchive ia(ifs);
ia >> *this;
ifs.close();
}
NeuralNetwork::~NeuralNetwork() {
if(this->neurons){
for( auto n: *this->neurons ){
for( auto n: *(this->neurons) ){
delete n;
n = nullptr;
}
......@@ -589,7 +602,10 @@ void NeuralNetwork::print_weights() {
}
void NeuralNetwork::print_stats(){
printf("Number of neurons: %d, number of active weights: %d, number of active biases: %d\n", (int)this->neurons->size(), (int)this->connection_weights->size(), (int)this->neuron_biases->size());
std::cout << "Number of neurons: " << this->neurons->size() << std::endl
<< "Number of connections: " << this->connection_list->size() << std::endl
<< "Number of active weights: " << this->connection_weights->size() << std::endl
<< "Number of active biases: " << this->neuron_biases->size() << std::endl;
}
std::vector<double>* NeuralNetwork::get_parameter_ptr_biases() {
......@@ -762,4 +778,13 @@ void NeuralNetwork::analyze_layer_structure() {
}
this->layers_analyzed = true;
}
void NeuralNetwork::save_text(std::string filepath) {
std::ofstream ofs(filepath);
{
boost::archive::text_oarchive oa(ofs);
oa << *this;
ofs.close();
}
}
\ No newline at end of file
......@@ -11,19 +11,39 @@
#ifndef INC_4NEURO_NEURALNETWORK_H
#define INC_4NEURO_NEURALNETWORK_H
#include <iostream>
#include <vector>
#include <algorithm>
#include <utility>
#include "../Neuron/Neuron.h"
#include "../NetConnection/ConnectionFunctionGeneral.h"
#include "../NetConnection/ConnectionFunctionIdentity.h"
#include "../settings.h"
#include <fstream>
enum NET_TYPE{GENERAL};
#include <boost/random/mersenne_twister.hpp>
#include <boost/random/uniform_real_distribution.hpp>
enum BIAS_TYPE{NEXT_BIAS, NO_BIAS, EXISTING_BIAS};
#include <boost/archive/text_oarchive.hpp>
#include <boost/archive/text_iarchive.hpp>
#include <boost/serialization/list.hpp>
#include <boost/serialization/string.hpp>
#include <boost/serialization/version.hpp>
#include <boost/serialization/split_member.hpp>
#include <boost/serialization/export.hpp>
#include <boost/serialization/vector.hpp>
#include <boost/serialization/utility.hpp>
enum SIMPLE_CONNECTION_TYPE{NEXT_WEIGHT, UNITARY_WEIGHT, EXISTING_WEIGHT};
#include "Neuron/Neuron.h"
#include "Neuron/NeuronConstant.h"
#include "Neuron/NeuronBinary.h"
#include "Neuron/NeuronLinear.h"
#include "Neuron/NeuronLogistic.h"
#include "NetConnection/ConnectionFunctionGeneral.h"
#include "NetConnection/ConnectionFunctionIdentity.h"
#include "settings.h"
enum class BIAS_TYPE{NEXT_BIAS, NO_BIAS, EXISTING_BIAS};
enum class SIMPLE_CONNECTION_TYPE{NEXT_WEIGHT, UNITARY_WEIGHT, EXISTING_WEIGHT};
/**
......@@ -31,10 +51,7 @@ enum SIMPLE_CONNECTION_TYPE{NEXT_WEIGHT, UNITARY_WEIGHT, EXISTING_WEIGHT};
*/
class NeuralNetwork {
private:
/**
*
*/
NET_TYPE network_type = GENERAL;
friend class boost::serialization::access;
/**
*
......@@ -140,13 +157,36 @@ private:
*/
void analyze_layer_structure( );
template<class Archive>
void serialize(Archive & ar, const unsigned int version) {
ar & this->neurons;
ar & this->input_neuron_indices;
ar & this->output_neuron_indices;
ar & this->connection_list;
ar & this->neuron_biases;
ar & this-> neuron_bias_indices;
ar & this->neuron_potentials;
ar & this->connection_weights;
ar & this->inward_adjacency;
ar & this->outward_adjacency;
ar & this->neuron_layers_feedforward;
ar & this->neuron_layers_feedbackward;
ar & this->layers_analyzed;
ar & this->delete_weights;
ar & this->delete_biases;
};
public:
/**
*
*/
NeuralNetwork();
explicit NeuralNetwork();
/**
*
*/
explicit NeuralNetwork(std::string filepath);
/**
*
......@@ -184,14 +224,12 @@ public:
*/
virtual void eval_single(std::vector<double> &input, std::vector<double> &output, std::vector<double> *custom_weights_and_biases = nullptr);
/**
* Adds a new neuron to the list of neurons. Also assigns a valid bias value to its activation function
* @param[in] n
* @return
*/
size_t add_neuron(Neuron* n, BIAS_TYPE bt = NEXT_BIAS, size_t bias_idx = 0);
size_t add_neuron(Neuron* n, BIAS_TYPE bt = BIAS_TYPE::NEXT_BIAS, size_t bias_idx = 0);
/**
*
......@@ -199,7 +237,7 @@ public:
* @param n2_idx
* @return
*/
size_t add_connection_simple(size_t n1_idx, size_t n2_idx, SIMPLE_CONNECTION_TYPE sct = NEXT_WEIGHT, size_t weight_idx = 0 );
size_t add_connection_simple(size_t n1_idx, size_t n2_idx, SIMPLE_CONNECTION_TYPE sct = SIMPLE_CONNECTION_TYPE::NEXT_WEIGHT, size_t weight_idx = 0 );
/**
* Take the existing connection with index 'connection_idx' in 'parent_network' and adds it to the structure of this
......@@ -291,7 +329,13 @@ public:
* @return
*/
std::vector<double>* get_parameter_ptr_biases();
};
/**
*
* @param filepath
*/
void save_text(std::string filepath);
};
#endif //INC_4NEURO_NEURALNETWORK_H
......@@ -12,9 +12,18 @@
class NeuralNetworkSum : public NeuralNetwork {
private:
friend class boost::serialization::access;
std::vector<NeuralNetwork*> * summand;
std::vector<double> * summand_coefficient;
template <class Archive>
void serialize(Archive & ar, const unsigned int version) {
ar & boost::serialization::base_object<NeuralNetwork>(*this);
ar & this->summand;
ar & this->summand_coefficient;
};
public:
NeuralNetworkSum( );
virtual ~NeuralNetworkSum( );
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment