From 27c926931860dbba275541e5155ee6c9ebc64850 Mon Sep 17 00:00:00 2001 From: Martin Beseda <martin.beseda@vsb.cz> Date: Sat, 14 Sep 2019 15:42:59 +0200 Subject: [PATCH] [ENH] Added functions to get DataSet outputs and inputs as Armadillo matrices. --- src/DataSet/DataSet.cpp | 28 +++++++++++++++++++++++++++- src/DataSet/DataSet.h | 13 +++++++++++++ 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/src/DataSet/DataSet.cpp b/src/DataSet/DataSet.cpp index e65a67e0..8bf39621 100644 --- a/src/DataSet/DataSet.cpp +++ b/src/DataSet/DataSet.cpp @@ -1,9 +1,13 @@ #include <algorithm> + +#include <armadillo> #include <boost/serialization/export.hpp> #include "DataSetSerialization.h" #include "exceptions.h" +#include "DataSet.h" + BOOST_CLASS_EXPORT_IMPLEMENT(lib4neuro::DataSet); @@ -34,7 +38,6 @@ namespace lib4neuro { } this->normalization_strategy = std::make_shared<DoubleUnitStrategy>(DoubleUnitStrategy()); - } DataSet::DataSet(std::vector<std::pair<std::vector<double>, std::vector<double>>>* data_ptr, @@ -497,4 +500,27 @@ namespace lib4neuro { } this->output_dim += n_columns; } + + arma::Mat<double>* DataSet::get_inputs_matrix() { + this->inputs_matrix = new arma::Mat<double>(this->data.size(), this->data.at(0).first.size()); +// arma::Mat<double> m(this->data.size(), this->data.at(0).first.size()); + + for (size_t i = 0; i < this->data.size(); i++) { + this->inputs_matrix->row(i) = arma::Row<double>(this->data.at(i).first); + } + +// this->inputs_matrix = &m; + return this->inputs_matrix; + } + + arma::Mat<double>* DataSet::get_outputs_matrix() { + this->outputs_matrix = new arma::Mat<double>(this->data.size(), this->data.at(0).second.size()); + + for(size_t i = 0; i < this->data.size(); i++) { + this->outputs_matrix->row(i) = arma::Row<double>(this->data.at(i).second); + } + +// this->outputs_matrix = &m; + return this->outputs_matrix; + } } diff --git a/src/DataSet/DataSet.h b/src/DataSet/DataSet.h index f253d9da..c14311b6 100644 --- a/src/DataSet/DataSet.h +++ b/src/DataSet/DataSet.h @@ -14,6 +14,11 @@ #include "../settings.h" #include "../NormalizationStrategy/NormalizationStrategy.h" +/* Forward declaration or arma::Mat<> type */ +namespace arma { + template<class T> + class Mat; +} namespace lib4neuro { /** @@ -68,6 +73,10 @@ namespace lib4neuro { //TODO let user choose in the constructor! std::shared_ptr<NormalizationStrategy> normalization_strategy; + arma::Mat<double>* inputs_matrix; + + arma::Mat<double>* outputs_matrix; + public: @@ -314,6 +323,10 @@ namespace lib4neuro { * @param n_columns Number of columns to be inserted */ LIB4NEURO_API void add_zero_output_columns(size_t n_columns); + + [[nodiscard]] LIB4NEURO_API arma::Mat<double>* get_inputs_matrix(); + + [[nodiscard]] LIB4NEURO_API arma::Mat<double>* get_outputs_matrix(); }; } #endif //INC_4NEURO_DATASET_H -- GitLab