diff --git a/src/ErrorFunction/ErrorFunctions.cpp b/src/ErrorFunction/ErrorFunctions.cpp index db09859032b596ecaaa9a7260d67908b690bc794..81cf3daaafd03ad8e20123ba1f0b5e482da9af6a 100644 --- a/src/ErrorFunction/ErrorFunctions.cpp +++ b/src/ErrorFunction/ErrorFunctions.cpp @@ -31,13 +31,13 @@ double MSE::eval(std::vector<double> *weights) { for(unsigned int i = 0; i < n_elements; ++i){ // Iterate through every element in the test set - this->net->eval_single(std::get<0>(data->at(i)), output, weights); // Compute the net output and store it into 'output' variable + this->net->eval_single(data->at(i).first, output, weights); // Compute the net output and store it into 'output' variable // printf("errors: "); for(unsigned int j = 0; j < dim_out; ++j) { // Compute difference for every element of the output vector - val = output[j] - std::get<1>(data->at(i))[j]; + val = output[j] - data->at(i).second[j]; error += val * val; // printf("%f, ", val * val); diff --git a/src/LearningMethods/ParticleSwarm.h b/src/LearningMethods/ParticleSwarm.h index c3bdf0fe2a1f1e89e52de88e90ea4593cf529193..2e91e8efe42ad122305ef8f2420987d15b0c4057 100644 --- a/src/LearningMethods/ParticleSwarm.h +++ b/src/LearningMethods/ParticleSwarm.h @@ -17,9 +17,9 @@ #include <iterator> #include <algorithm> -#include "../Network/NeuralNetwork.h" -#include "../DataSet/DataSet.h" -#include "../ErrorFunction/ErrorFunctions.h" +#include "Network/NeuralNetwork.h" +#include "DataSet/DataSet.h" +#include "ErrorFunction/ErrorFunctions.h" class Particle{ diff --git a/src/NetConnection/ConnectionFunctionGeneral.cpp b/src/NetConnection/ConnectionFunctionGeneral.cpp index 74093c7ebe3a38bd4d0986bd821cfb5cd26c20e0..9b1b3e645541284985ea7011158d0033ca59a5aa 100644 --- a/src/NetConnection/ConnectionFunctionGeneral.cpp +++ b/src/NetConnection/ConnectionFunctionGeneral.cpp @@ -5,13 +5,9 @@ * @date 14.6.18 - */ - #include "ConnectionFunctionGeneral.h" -ConnectionFunctionGeneral::ConnectionFunctionGeneral() { - - -} +ConnectionFunctionGeneral::ConnectionFunctionGeneral() {} ConnectionFunctionGeneral::ConnectionFunctionGeneral(std::vector<size_t > ¶m_indices, std::string &function_string) { this->param_indices = param_indices; diff --git a/src/NetConnection/ConnectionFunctionGeneral.h b/src/NetConnection/ConnectionFunctionGeneral.h index 0c4afe0a5b63d7c80b873824baf299eec332c5d9..8f25e89ffaaccf98f5f4514bdd86406d0df6d3b0 100644 --- a/src/NetConnection/ConnectionFunctionGeneral.h +++ b/src/NetConnection/ConnectionFunctionGeneral.h @@ -8,10 +8,22 @@ #ifndef INC_4NEURO_CONNECTIONWEIGHT_H #define INC_4NEURO_CONNECTIONWEIGHT_H +#include <boost/archive/text_oarchive.hpp> +#include <boost/archive/text_iarchive.hpp> +#include <boost/serialization/export.hpp> +#include <boost/serialization/vector.hpp> #include <functional> #include <vector> class ConnectionFunctionGeneral { +private: + friend class boost::serialization::access; + + template <class Archive> + void serialize(Archive & ar, const unsigned int version) { + ar & this->param_indices; + }; + protected: /** @@ -51,7 +63,6 @@ public: */ virtual void eval_partial_derivative( std::vector<double> ¶meter_space, std::vector<double> &weight_gradient, double alpha ); - }; diff --git a/src/NetConnection/ConnectionFunctionIdentity.cpp b/src/NetConnection/ConnectionFunctionIdentity.cpp index 3716ea75dc8ff4beb89e11de357dc1df5339ccd6..bd535cbf94af0e1b019abc5c5dee4592d9c67c43 100644 --- a/src/NetConnection/ConnectionFunctionIdentity.cpp +++ b/src/NetConnection/ConnectionFunctionIdentity.cpp @@ -8,10 +8,12 @@ #include "ConnectionFunctionIdentity.h" ConnectionFunctionIdentity::ConnectionFunctionIdentity( ) { +// this->type = CONNECTION_TYPE::IDENTITY; this->is_unitary = true; } ConnectionFunctionIdentity::ConnectionFunctionIdentity( size_t pidx ) { +// this->type = CONNECTION_TYPE::IDENTITY; this->param_idx = pidx; this->is_unitary = false; } diff --git a/src/NetConnection/ConnectionFunctionIdentity.h b/src/NetConnection/ConnectionFunctionIdentity.h index 104fc58049a137f240ecabacb3b0d4a7bba6ec5b..000279cc683048fce04d94f7ddfb53d411a5d705 100644 --- a/src/NetConnection/ConnectionFunctionIdentity.h +++ b/src/NetConnection/ConnectionFunctionIdentity.h @@ -16,12 +16,23 @@ class ConnectionFunctionGeneral; * */ class ConnectionFunctionIdentity:public ConnectionFunctionGeneral { + friend class boost::serialization::access; + friend class NeuralNetwork; + private: size_t param_idx = 0; bool is_unitary = false; +protected: + template<class Archive> + void serialize(Archive & ar, const unsigned int version){ + ar & boost::serialization::base_object<ConnectionFunctionGeneral>(*this); + ar & this->param_idx; + ar & this->is_unitary; + }; + public: /**