Skip to content
Snippets Groups Projects
Commit 82a94cda authored by Martin Beseda's avatar Martin Beseda
Browse files

ENH: Added serialization to Connection classes + modified acces to std::pairs

parent 255cf7c2
No related branches found
No related tags found
No related merge requests found
......@@ -31,13 +31,13 @@ double MSE::eval(std::vector<double> *weights) {
for(unsigned int i = 0; i < n_elements; ++i){ // Iterate through every element in the test set
this->net->eval_single(std::get<0>(data->at(i)), output, weights); // Compute the net output and store it into 'output' variable
this->net->eval_single(data->at(i).first, output, weights); // Compute the net output and store it into 'output' variable
// printf("errors: ");
for(unsigned int j = 0; j < dim_out; ++j) { // Compute difference for every element of the output vector
val = output[j] - std::get<1>(data->at(i))[j];
val = output[j] - data->at(i).second[j];
error += val * val;
// printf("%f, ", val * val);
......
......@@ -17,9 +17,9 @@
#include <iterator>
#include <algorithm>
#include "../Network/NeuralNetwork.h"
#include "../DataSet/DataSet.h"
#include "../ErrorFunction/ErrorFunctions.h"
#include "Network/NeuralNetwork.h"
#include "DataSet/DataSet.h"
#include "ErrorFunction/ErrorFunctions.h"
class Particle{
......
......@@ -5,13 +5,9 @@
* @date 14.6.18 -
*/
#include "ConnectionFunctionGeneral.h"
ConnectionFunctionGeneral::ConnectionFunctionGeneral() {
}
ConnectionFunctionGeneral::ConnectionFunctionGeneral() {}
ConnectionFunctionGeneral::ConnectionFunctionGeneral(std::vector<size_t > &param_indices, std::string &function_string) {
this->param_indices = param_indices;
......
......@@ -8,10 +8,22 @@
#ifndef INC_4NEURO_CONNECTIONWEIGHT_H
#define INC_4NEURO_CONNECTIONWEIGHT_H
#include <boost/archive/text_oarchive.hpp>
#include <boost/archive/text_iarchive.hpp>
#include <boost/serialization/export.hpp>
#include <boost/serialization/vector.hpp>
#include <functional>
#include <vector>
class ConnectionFunctionGeneral {
private:
friend class boost::serialization::access;
template <class Archive>
void serialize(Archive & ar, const unsigned int version) {
ar & this->param_indices;
};
protected:
/**
......@@ -51,7 +63,6 @@ public:
*/
virtual void eval_partial_derivative( std::vector<double> &parameter_space, std::vector<double> &weight_gradient, double alpha );
};
......
......@@ -8,10 +8,12 @@
#include "ConnectionFunctionIdentity.h"
ConnectionFunctionIdentity::ConnectionFunctionIdentity( ) {
// this->type = CONNECTION_TYPE::IDENTITY;
this->is_unitary = true;
}
ConnectionFunctionIdentity::ConnectionFunctionIdentity( size_t pidx ) {
// this->type = CONNECTION_TYPE::IDENTITY;
this->param_idx = pidx;
this->is_unitary = false;
}
......
......@@ -16,12 +16,23 @@ class ConnectionFunctionGeneral;
*
*/
class ConnectionFunctionIdentity:public ConnectionFunctionGeneral {
friend class boost::serialization::access;
friend class NeuralNetwork;
private:
size_t param_idx = 0;
bool is_unitary = false;
protected:
template<class Archive>
void serialize(Archive & ar, const unsigned int version){
ar & boost::serialization::base_object<ConnectionFunctionGeneral>(*this);
ar & this->param_idx;
ar & this->is_unitary;
};
public:
/**
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment