Skip to content
Snippets Groups Projects
Commit 82a94cda authored by Martin Beseda's avatar Martin Beseda
Browse files

ENH: Added serialization to Connection classes + modified acces to std::pairs

parent 255cf7c2
No related branches found
No related tags found
No related merge requests found
...@@ -31,13 +31,13 @@ double MSE::eval(std::vector<double> *weights) { ...@@ -31,13 +31,13 @@ double MSE::eval(std::vector<double> *weights) {
for(unsigned int i = 0; i < n_elements; ++i){ // Iterate through every element in the test set for(unsigned int i = 0; i < n_elements; ++i){ // Iterate through every element in the test set
this->net->eval_single(std::get<0>(data->at(i)), output, weights); // Compute the net output and store it into 'output' variable this->net->eval_single(data->at(i).first, output, weights); // Compute the net output and store it into 'output' variable
// printf("errors: "); // printf("errors: ");
for(unsigned int j = 0; j < dim_out; ++j) { // Compute difference for every element of the output vector for(unsigned int j = 0; j < dim_out; ++j) { // Compute difference for every element of the output vector
val = output[j] - std::get<1>(data->at(i))[j]; val = output[j] - data->at(i).second[j];
error += val * val; error += val * val;
// printf("%f, ", val * val); // printf("%f, ", val * val);
......
...@@ -17,9 +17,9 @@ ...@@ -17,9 +17,9 @@
#include <iterator> #include <iterator>
#include <algorithm> #include <algorithm>
#include "../Network/NeuralNetwork.h" #include "Network/NeuralNetwork.h"
#include "../DataSet/DataSet.h" #include "DataSet/DataSet.h"
#include "../ErrorFunction/ErrorFunctions.h" #include "ErrorFunction/ErrorFunctions.h"
class Particle{ class Particle{
......
...@@ -5,13 +5,9 @@ ...@@ -5,13 +5,9 @@
* @date 14.6.18 - * @date 14.6.18 -
*/ */
#include "ConnectionFunctionGeneral.h" #include "ConnectionFunctionGeneral.h"
ConnectionFunctionGeneral::ConnectionFunctionGeneral() { ConnectionFunctionGeneral::ConnectionFunctionGeneral() {}
}
ConnectionFunctionGeneral::ConnectionFunctionGeneral(std::vector<size_t > &param_indices, std::string &function_string) { ConnectionFunctionGeneral::ConnectionFunctionGeneral(std::vector<size_t > &param_indices, std::string &function_string) {
this->param_indices = param_indices; this->param_indices = param_indices;
......
...@@ -8,10 +8,22 @@ ...@@ -8,10 +8,22 @@
#ifndef INC_4NEURO_CONNECTIONWEIGHT_H #ifndef INC_4NEURO_CONNECTIONWEIGHT_H
#define INC_4NEURO_CONNECTIONWEIGHT_H #define INC_4NEURO_CONNECTIONWEIGHT_H
#include <boost/archive/text_oarchive.hpp>
#include <boost/archive/text_iarchive.hpp>
#include <boost/serialization/export.hpp>
#include <boost/serialization/vector.hpp>
#include <functional> #include <functional>
#include <vector> #include <vector>
class ConnectionFunctionGeneral { class ConnectionFunctionGeneral {
private:
friend class boost::serialization::access;
template <class Archive>
void serialize(Archive & ar, const unsigned int version) {
ar & this->param_indices;
};
protected: protected:
/** /**
...@@ -51,7 +63,6 @@ public: ...@@ -51,7 +63,6 @@ public:
*/ */
virtual void eval_partial_derivative( std::vector<double> &parameter_space, std::vector<double> &weight_gradient, double alpha ); virtual void eval_partial_derivative( std::vector<double> &parameter_space, std::vector<double> &weight_gradient, double alpha );
}; };
......
...@@ -8,10 +8,12 @@ ...@@ -8,10 +8,12 @@
#include "ConnectionFunctionIdentity.h" #include "ConnectionFunctionIdentity.h"
ConnectionFunctionIdentity::ConnectionFunctionIdentity( ) { ConnectionFunctionIdentity::ConnectionFunctionIdentity( ) {
// this->type = CONNECTION_TYPE::IDENTITY;
this->is_unitary = true; this->is_unitary = true;
} }
ConnectionFunctionIdentity::ConnectionFunctionIdentity( size_t pidx ) { ConnectionFunctionIdentity::ConnectionFunctionIdentity( size_t pidx ) {
// this->type = CONNECTION_TYPE::IDENTITY;
this->param_idx = pidx; this->param_idx = pidx;
this->is_unitary = false; this->is_unitary = false;
} }
......
...@@ -16,12 +16,23 @@ class ConnectionFunctionGeneral; ...@@ -16,12 +16,23 @@ class ConnectionFunctionGeneral;
* *
*/ */
class ConnectionFunctionIdentity:public ConnectionFunctionGeneral { class ConnectionFunctionIdentity:public ConnectionFunctionGeneral {
friend class boost::serialization::access;
friend class NeuralNetwork;
private: private:
size_t param_idx = 0; size_t param_idx = 0;
bool is_unitary = false; bool is_unitary = false;
protected:
template<class Archive>
void serialize(Archive & ar, const unsigned int version){
ar & boost::serialization::base_object<ConnectionFunctionGeneral>(*this);
ar & this->param_idx;
ar & this->is_unitary;
};
public: public:
/** /**
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment