    // DataSet.h
    // Author (per commit annotation): Martin Beseda
    //
    // Created by martin on 7/13/18.
    //
    
    // TODO: data generation
    
    #ifndef INC_4NEURO_DATASET_H
    #define INC_4NEURO_DATASET_H
    
    
    #include <cstddef>
    #include <functional>
    #include <stdexcept>
    #include <string>
    #include <utility>
    #include <vector>

    #include <boost/range/size_type.hpp>
    #include <boost/serialization/access.hpp>
    #include <boost/serialization/base_object.hpp>
    
    
    /**
     * Exception type signalling that an input or output dimension
     * was specified incorrectly.
     *
     * Derives from std::runtime_error; both constructors are only
     * declared here and are defined outside this header.
     */
    class InvalidDimension : public std::runtime_error
    {
    public:
        /**
         * Default constructor - uses the general error message
         */
        InvalidDimension();

        /**
         * Constructor taking a caller-supplied error message.
         * Declared explicit, so a std::string is never converted
         * to an InvalidDimension implicitly.
         *
         * @param msg Specific error message
         */
        explicit InvalidDimension(std::string msg);
    };
    
    
    /**
     * Class representing data, which can be used for training
     * and testing purposes.
     *
     * The data are kept as a vector of (input vector, output vector) pairs.
     * Only serialize() is defined in this header; every other member
     * function is declared here and defined elsewhere.
     */
    class DataSet {
        friend class boost::serialization::access;

    private:
        /**
         * Number of elements in the data set
         */
        size_t n_elements = 0;


        /**
         * Dimension of the input
         */
        unsigned int input_dim = 0;

        /**
         * Dimension of the output
         */
        unsigned int output_dim = 0;


        /**
         * Stored data in the format of pairs of corresponding
         * input and output vectors
         */
        std::vector<std::pair<std::vector<double>, std::vector<double>>> data;


        /**
         * Helper declared here and defined elsewhere.
         * NOTE(review): judging by the name it presumably builds the Cartesian
         * product of the vectors pointed to by v - confirm against the definition.
         *
         * @tparam T Element type of the combined vectors
         * @param v Pointer to the vector of vectors to be combined
         */
        template <class T>
        std::vector<std::vector<T>> cartesian_product(const std::vector<std::vector<T>>* v);


    protected:
        /**
         * Serialization function
         *
         * Archive layout (identical for loading and storing):
         *   1. number of elements,
         *   2. input dimension,
         *   3. output dimension,
         *   4. for every element: all input values followed by all output values.
         *
         * @tparam Archive Boost library template
         * @param ar Boost parameter - filled automatically during serialization!
         * @param version Boost parameter - filled automatically during serialization!
         *                It is not used by this function.
         */
        template<class Archive>
        void serialize(Archive & ar, const unsigned int version){
            (void) version;

            if(Archive::is_loading::value) {
                /* LOADING data */
                ar & this->n_elements;

                /* INPUT dimension
                 * (read into a local whose name does not shadow the member) */
                size_t dim_inp = 0;
                ar & dim_inp;

                /* OUTPUT dimension */
                size_t dim_out = 0;
                ar & dim_out;

                std::vector<std::pair<std::vector<double>, std::vector<double>>> data_tmp;
                data_tmp.reserve(this->n_elements);

                double tmp = 0.0;

                for(size_t i = 0; i < this->n_elements; ++i) {

                    /* INPUT vector - dim_inp consecutive values */
                    std::vector<double> inputs;
                    inputs.reserve(dim_inp);
                    for(size_t j = 0; j < dim_inp; ++j) {
                        ar & tmp;
                        inputs.push_back(tmp);
                    }

                    /* OUTPUT vector - dim_out consecutive values */
                    std::vector<double> outputs;
                    outputs.reserve(dim_out);
                    for(size_t j = 0; j < dim_out; ++j) {
                        ar & tmp;
                        outputs.push_back(tmp);
                    }

                    /* Append to the data vector */
                    data_tmp.emplace_back(std::move(inputs), std::move(outputs));
                }

                /* Restore the dimensions together with the data, so that
                 * the members agree with what was just loaded */
                this->input_dim = static_cast<unsigned int>(dim_inp);
                this->output_dim = static_cast<unsigned int>(dim_out);
                this->data = std::move(data_tmp);

            } else {
                /* STORING data */

                /* NOTE(review): n_elements is written as stored in the member, while
                 * the loop below iterates over data - the two are assumed to agree */
                ar & this->n_elements;

                /* INPUT dimension
                 * (taken from the first stored pair; an empty data set falls back to
                 * the member instead of indexing into an empty vector) */
                size_t dim_inp = this->data.empty() ? this->input_dim
                                                    : this->data.front().first.size();
                ar & dim_inp;

                /* OUTPUT dimension */
                size_t dim_out = this->data.empty() ? this->output_dim
                                                    : this->data.front().second.size();
                ar & dim_out;

                for(const auto& p : this->data) {
                    /* Input vector
                     * (val is a non-const copy: this branch is also compiled for
                     * loading archives, which need a modifiable lvalue) */
                    for(double val : p.first) {
                        ar & val;
                    }

                    /* Output vector */
                    for(double val : p.second) {
                        ar & val;
                    }
                }
            }
        }

    public:

        /**
         * Constructor reading data from the file
         * @param file_path Path to the file with stored data set
         */
        DataSet(std::string file_path);

        /**
         * Constructor accepting data vector
         * @param data_ptr Pointer to the vector containing data
         */
        DataSet(std::vector<std::pair<std::vector<double>, std::vector<double>>>* data_ptr);


        /**
         * Creates a new data set with input values equidistantly positioned
         * over the certain interval and the output value
         * being constant
         *
         * Both input and output are 1-dimensional
         *
         * @todo add bounds as vectors for multi-dimensional data-sets
         *
         * @param lower_bound Lower bound of the input data interval
         * @param upper_bound Upper bound of the input data interval
         * @param size Number of input-output pairs generated
         * @param output Constant output value
         */
        DataSet(double lower_bound, double upper_bound, unsigned int size, double output);


        /**
         * Constructor taking bounds, a per-dimension element count, an output
         * function and an output dimension.
         *
         * NOTE(review): the parameters mirror the multi-dimensional
         * add_isotropic_data() overload, so this presumably creates a data set
         * filled the same way - confirm against the definition.
         *
         * @param bounds Bounds of the input data intervals
         * @param no_elems_in_one_dim Element count for one dimension (as the name suggests)
         * @param output_func Function determining output value
         * @param output_dim Dimension of the output
         */
        DataSet(std::vector<double> bounds, unsigned int no_elems_in_one_dim, std::vector<double> (*output_func)(std::vector<double>), unsigned int output_dim);


        /**
         * Getter for number of elements
         * @return Number of elements in the data set
         */
        size_t get_n_elements();


        /**
         * Returns the input dimension
         * @return Input dimension
         */
        unsigned int get_input_dim();


        /**
         * Return the output dimension
         * @return Output dimension
         */
        unsigned int get_output_dim();


        /**
         * Getter for the data structure
         * @return Pointer to the vector of data
         */
        std::vector<std::pair<std::vector<double>, std::vector<double>>>* get_data();

        /**
         * Adds a new pair of data to the data set
         * @param inputs Vector of input data
         * @param outputs Vector of output data corresponding to the input data
         */
        void add_data_pair(std::vector<double> inputs, std::vector<double> outputs);


        //TODO expand method to generate multiple data types - chebyshev etc.
        /**
         * Adds a new data with input values equidistantly positioned
         * over the certain interval and the output value
         * being constant
         *
         * Both input and output are 1-dimensional
         *
         * @param lower_bound Lower bound of the input data interval
         * @param upper_bound Upper bound of the input data interval
         * @param size Number of input-output pairs generated
         * @param output Constant output value
         */
        void add_isotropic_data(double lower_bound, double upper_bound, unsigned int size, double output);


        /**
         * Adds a new data with input values equidistantly positioned
         * over the certain interval and the output value
         * determined by the given function
         *
         * Input can have arbitrary many dimensions,
         * output can be an arbitrary function
         *
         * @param bounds Odd values are lower bounds and even values are corresponding upper bounds
         * @param no_elems_in_one_dim Element count for one dimension (as the name
         *        suggests; the earlier comment described it as the number of
         *        input-output pairs generated - TODO confirm)
         * @param output_func Function determining output value
         */
        void add_isotropic_data(std::vector<double> bounds, unsigned int no_elems_in_one_dim, std::vector<double> (*output_func)(std::vector<double>));


        //TODO Chebyshev - ch. interpolation points, i-th point = cos(i*alpha) from 0 to pi


        /**
         * Prints the data set
         */
        void print_data();

        /**
         * Stores the DataSet object to a file
         *
         * NOTE(review): the earlier comment said "binary file" while the method
         * name says "text" - confirm the archive type against the definition.
         *
         * @param file_path Path to the file the data set is stored to
         */
        void store_text(std::string file_path);
    };
    
    
    #endif //INC_4NEURO_DATASET_H