From 27c926931860dbba275541e5155ee6c9ebc64850 Mon Sep 17 00:00:00 2001
From: Martin Beseda <martin.beseda@vsb.cz>
Date: Sat, 14 Sep 2019 15:42:59 +0200
Subject: [PATCH] [ENH] Added functions to get DataSet outputs and inputs as
 Armadillo matrices.

---
 src/DataSet/DataSet.cpp | 28 +++++++++++++++++++++++++++-
 src/DataSet/DataSet.h   | 13 +++++++++++++
 2 files changed, 40 insertions(+), 1 deletion(-)

diff --git a/src/DataSet/DataSet.cpp b/src/DataSet/DataSet.cpp
index e65a67e0..8bf39621 100644
--- a/src/DataSet/DataSet.cpp
+++ b/src/DataSet/DataSet.cpp
@@ -1,9 +1,13 @@
 
 #include <algorithm>
+
+#include <armadillo>
 #include <boost/serialization/export.hpp>
 
 #include "DataSetSerialization.h"
 #include "exceptions.h"
+#include "DataSet.h"
+
 
 BOOST_CLASS_EXPORT_IMPLEMENT(lib4neuro::DataSet);
 
@@ -34,7 +38,6 @@ namespace lib4neuro {
         }
 
         this->normalization_strategy = std::make_shared<DoubleUnitStrategy>(DoubleUnitStrategy());
-
     }
 
     DataSet::DataSet(std::vector<std::pair<std::vector<double>, std::vector<double>>>* data_ptr,
@@ -497,4 +500,27 @@ namespace lib4neuro {
 		}
 		this->output_dim += n_columns;
 	}
+
+    arma::Mat<double>* DataSet::get_inputs_matrix() {
+        this->inputs_matrix = new arma::Mat<double>(this->data.size(), this->data.at(0).first.size());
+//        arma::Mat<double> m(this->data.size(), this->data.at(0).first.size());
+
+        for (size_t i = 0; i < this->data.size(); i++) {
+            this->inputs_matrix->row(i) = arma::Row<double>(this->data.at(i).first);
+        }
+
+//        this->inputs_matrix = &m;
+        return this->inputs_matrix;
+    }
+
+    arma::Mat<double>* DataSet::get_outputs_matrix() {
+        this->outputs_matrix = new arma::Mat<double>(this->data.size(), this->data.at(0).second.size());
+
+        for(size_t i = 0; i < this->data.size(); i++) {
+            this->outputs_matrix->row(i) = arma::Row<double>(this->data.at(i).second);
+        }
+
+//        this->outputs_matrix = &m;
+        return this->outputs_matrix;
+    }
 }
diff --git a/src/DataSet/DataSet.h b/src/DataSet/DataSet.h
index f253d9da..c14311b6 100644
--- a/src/DataSet/DataSet.h
+++ b/src/DataSet/DataSet.h
@@ -14,6 +14,11 @@
 #include "../settings.h"
 #include "../NormalizationStrategy/NormalizationStrategy.h"
 
+/* Forward declaration or arma::Mat<> type */
+namespace arma {
+    template<class T>
+    class Mat;
+}
 
 namespace lib4neuro {
     /**
@@ -68,6 +73,10 @@ namespace lib4neuro {
         //TODO let user choose in the constructor!
         std::shared_ptr<NormalizationStrategy> normalization_strategy;
 
+        arma::Mat<double>* inputs_matrix;
+
+        arma::Mat<double>* outputs_matrix;
+
 
     public:
 
@@ -314,6 +323,10 @@ namespace lib4neuro {
 		 * @param n_columns Number of columns to be inserted
 		 */
 		LIB4NEURO_API void add_zero_output_columns(size_t n_columns);
+
+        [[nodiscard]] LIB4NEURO_API arma::Mat<double>* get_inputs_matrix();
+
+        [[nodiscard]] LIB4NEURO_API arma::Mat<double>* get_outputs_matrix();
     };
 }
 #endif //INC_4NEURO_DATASET_H
-- 
GitLab