Skip to content
Snippets Groups Projects
Commit 27c92693 authored by Martin Beseda's avatar Martin Beseda
Browse files

[ENH] Added functions to get DataSet outputs and inputs as Armadillo matrices.

parent e2258c44
No related branches found
No related tags found
No related merge requests found
#include <algorithm>
#include <armadillo>
#include <boost/serialization/export.hpp>
#include "DataSetSerialization.h"
#include "exceptions.h"
#include "DataSet.h"
BOOST_CLASS_EXPORT_IMPLEMENT(lib4neuro::DataSet);
......@@ -34,7 +38,6 @@ namespace lib4neuro {
}
this->normalization_strategy = std::make_shared<DoubleUnitStrategy>(DoubleUnitStrategy());
}
DataSet::DataSet(std::vector<std::pair<std::vector<double>, std::vector<double>>>* data_ptr,
......@@ -497,4 +500,27 @@ namespace lib4neuro {
}
this->output_dim += n_columns;
}
arma::Mat<double>* DataSet::get_inputs_matrix() {
this->inputs_matrix = new arma::Mat<double>(this->data.size(), this->data.at(0).first.size());
// arma::Mat<double> m(this->data.size(), this->data.at(0).first.size());
for (size_t i = 0; i < this->data.size(); i++) {
this->inputs_matrix->row(i) = arma::Row<double>(this->data.at(i).first);
}
// this->inputs_matrix = &m;
return this->inputs_matrix;
}
arma::Mat<double>* DataSet::get_outputs_matrix() {
this->outputs_matrix = new arma::Mat<double>(this->data.size(), this->data.at(0).second.size());
for(size_t i = 0; i < this->data.size(); i++) {
this->outputs_matrix->row(i) = arma::Row<double>(this->data.at(i).second);
}
// this->outputs_matrix = &m;
return this->outputs_matrix;
}
}
......@@ -14,6 +14,11 @@
#include "../settings.h"
#include "../NormalizationStrategy/NormalizationStrategy.h"
/* Forward declaration or arma::Mat<> type */
namespace arma {
template<class T>
class Mat;
}
namespace lib4neuro {
/**
......@@ -68,6 +73,10 @@ namespace lib4neuro {
//TODO let user choose in the constructor!
std::shared_ptr<NormalizationStrategy> normalization_strategy;
arma::Mat<double>* inputs_matrix;
arma::Mat<double>* outputs_matrix;
public:
......@@ -314,6 +323,10 @@ namespace lib4neuro {
* @param n_columns Number of columns to be inserted
*/
LIB4NEURO_API void add_zero_output_columns(size_t n_columns);
[[nodiscard]] LIB4NEURO_API arma::Mat<double>* get_inputs_matrix();
[[nodiscard]] LIB4NEURO_API arma::Mat<double>* get_outputs_matrix();
};
}
#endif //INC_4NEURO_DATASET_H
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment