From 5de6c3de9b8bd9842f910f9b4b6c7434872a303d Mon Sep 17 00:00:00 2001 From: mor0146 <vojtech.moravec.st@vsb.cz> Date: Fri, 2 Aug 2019 11:02:07 +0200 Subject: [PATCH] Improved the cli class. --- czi/source_code/app/main.cpp | 6 +- .../include/azgra/cli/cli_arguments.h | 6 +- .../azgra_lib/include/azgra/cli/cli_option.h | 96 +++- .../include/azgra/utilities/vector_linq.h | 499 ++++++++++-------- .../azgra_lib/src/cli/cli_arguments.cpp | 57 +- 5 files changed, 414 insertions(+), 250 deletions(-) diff --git a/czi/source_code/app/main.cpp b/czi/source_code/app/main.cpp index 8ff41a0..72a4d5f 100644 --- a/czi/source_code/app/main.cpp +++ b/czi/source_code/app/main.cpp @@ -180,8 +180,10 @@ int main(azgra::i32 argc, const char **argv) azgra::cli::CliFlag flagBzip2CompressionOption("BZIP2", "bzip2 compression", '\0', "bzip2"); azgra::cli::CliFlag flagB3dCompressionOption("B3D", "B3D cuda compression", '\0', "b3d"); - azgra::cli::CliOptionGroup compressorGroup("CompressorGroup", {&flagGzipCompressionOption, &flagLzmaCompressionOption, - &flagBzip2CompressionOption, &flagB3dCompressionOption}); + azgra::cli::CliFlagGroup compressorGroup("CompressorGroup", + {&flagGzipCompressionOption, &flagLzmaCompressionOption, + &flagBzip2CompressionOption, &flagB3dCompressionOption}, + azgra::cli::CliGroupMatchPolicy_AtLeastOneIfGroupIsRequired); // Methods azgra::cli::CliMethod methodVersion("version", diff --git a/czi/source_code/azgra_lib/include/azgra/cli/cli_arguments.h b/czi/source_code/azgra_lib/include/azgra/cli/cli_arguments.h index 0bf68a4..981e880 100644 --- a/czi/source_code/azgra_lib/include/azgra/cli/cli_arguments.h +++ b/czi/source_code/azgra_lib/include/azgra/cli/cli_arguments.h @@ -4,7 +4,6 @@ #include <azgra/string/simple_string.h> #include <vector> -#include <sstream> #include <cstdio> #include <cassert> @@ -25,7 +24,7 @@ namespace azgra string::SmartStringView<char> appDescription; int outputWidth = 80; bool someMethodMatched = false; - std::vector<CliOptionGroup> groups; + std::vector<CliFlagGroup> groups; bool process_matched_flag(const string::SmartStringView<char> &match, bool shortMatch, const char **arguments, int &parseIndex); @@ -37,6 +36,7 @@ namespace azgra void print_flags(std::stringstream &outStream, const std::vector<CliOption *> &flags) const; std::vector<CliOption*> get_flags_not_in_group() const; + void mark_required_groups(); public: std::vector<CliMethod *> methods; @@ -44,7 +44,7 @@ namespace azgra CliArguments(const string::SmartStringView<char> &name, const string::SmartStringView<char> &description, int width = 80); - void add_group(const CliOptionGroup &flagGroup); + void add_group(CliFlagGroup &flagGroup); bool parse(const int argc, const char **argv); diff --git a/czi/source_code/azgra_lib/include/azgra/cli/cli_option.h b/czi/source_code/azgra_lib/include/azgra/cli/cli_option.h index 2feb3f4..8568d94 100644 --- a/czi/source_code/azgra_lib/include/azgra/cli/cli_option.h +++ b/czi/source_code/azgra_lib/include/azgra/cli/cli_option.h @@ -1,11 +1,21 @@ #pragma once +#include <sstream> #include <azgra/string/smart_string_view.h> +#include <azgra/utilities/vector_linq.h> namespace azgra { namespace cli { + enum CliGroupMatchPolicy + { + CliGroupMatchPolicy_All, + CliGroupMatchPolicy_AtLeastOne, + CliGroupMatchPolicy_AtLeastOneIfGroupIsRequired, + CliGroupMatchPolicy_NoPolicy + }; + class CliOption { friend class CliArguments; @@ -16,7 +26,7 @@ namespace azgra bool isMatched = false; bool isRequired = false; bool hasMatchCharacter = true; - bool isInGroup = false; + CliOption *group = nullptr; string::SmartStringView<char> name; string::SmartStringView<char> description; @@ -30,19 +40,38 @@ namespace azgra isRequired = required; } + void mark_as_matched() + { + isMatched = true; + if (group) + { + group->isMatched = true; + } + } + public: virtual ~CliOption() = default;; - bool is_matched() const + bool is_matched() const noexcept { return isMatched; } - operator bool() const + operator bool() const noexcept { return isMatched; } + + bool is_required() const noexcept + { + return isRequired; + } + + bool is_grouped() const noexcept + { + return (group != nullptr); + } }; class CliMethod : public CliOption @@ -114,20 +143,77 @@ namespace azgra } }; - class CliOptionGroup : public CliOption + class CliFlagGroup : public CliOption { friend class CliArguments; private: //NOTE: We can have group rules. Like AtLeastOne, OnlyOne, All std::vector<CliOption *> options; + CliGroupMatchPolicy m_matchPolicy = CliGroupMatchPolicy_NoPolicy; public: - CliOptionGroup(const string::SmartStringView<char> &name, const std::vector<CliOption *> &flagsInGroup) : + CliFlagGroup(const string::SmartStringView<char> &name, const std::vector<CliOption *> &flagsInGroup, + CliGroupMatchPolicy matchPolicy = CliGroupMatchPolicy_NoPolicy) : CliOption(name, "", '\0', "", false) { hasMatchCharacter = false; isRequired = false; options = flagsInGroup; + m_matchPolicy = matchPolicy; + } + + bool check_group_policy(std::stringstream &errorStream) const + { + std::function<bool(CliOption *)> testCondition = [](CliOption *flag) + { return flag->is_matched(); }; + + switch (m_matchPolicy) + { + case CliGroupMatchPolicy_All: + { + bool result = linq::all(options, testCondition); + if (!result) + { + errorStream << '<' << name.string_view() + << ">: CliGroupMatchPolicy_All requires all flags in group to be matched."; + } + return result; + //return linq::all(linq::where(options, filterFunction), testCondition); + } + + case CliGroupMatchPolicy_AtLeastOne: + { + bool result = linq::any(options, testCondition); + if (!result) + { + errorStream << '<' << name.string_view() + << "> CliGroupMatchPolicy_AtLeastOne requires at least one flag to be matched."; + } + return result; + } + case CliGroupMatchPolicy_AtLeastOneIfGroupIsRequired: + { + if (isRequired) + { + bool result = linq::any(options, testCondition); + if (!result) + { + errorStream << '<' << name.string_view() << + "> CliGroupMatchPolicy_AtLeastOneIfGroupIsRequired requires at least one flag to be matched," + << " when this group is required."; + } + return result; + } + return true; + } + case CliGroupMatchPolicy_NoPolicy: + { + return true; + } + default: + INVALID_CASE; + } + return false; } }; }; diff --git a/czi/source_code/azgra_lib/include/azgra/utilities/vector_linq.h b/czi/source_code/azgra_lib/include/azgra/utilities/vector_linq.h index 167f0f1..5ea6118 100644 --- a/czi/source_code/azgra_lib/include/azgra/utilities/vector_linq.h +++ b/czi/source_code/azgra_lib/include/azgra/utilities/vector_linq.h @@ -1,4 +1,5 @@ #pragma once + #include <vector> #include <functional> #include <limits> @@ -7,224 +8,282 @@ using namespace std; namespace azgra { - namespace linq - { - template <typename T> - vector<T> where(const std::vector<T>& src, function<bool(const T&)> predicate) - { - vector<T> result; - for (size_t i = 0; i < src.size(); i++) - { - if (predicate(src[i])) - { - result.push_back(src[i]); - } - } - return result; - } - - template <typename T> - bool any(const vector<T>& src, function<bool(const T&)> predicate) - { - for (size_t i = 0; i < src.size(); i++) - { - if (predicate(src[i])) - { - return true; - } - } - return false; - } - - template <typename T> - bool all(const vector<T>& src, function<bool(const T&)> predicate) - { - for (size_t i = 0; i < src.size(); i++) - { - if (!predicate(src[i])) - { - return false; - } - } - return true; - } - - template <typename T> - bool contains(const vector<T>& src, const T& element) - { - for (size_t i = 0; i < src.size(); i++) - { - if (src[i] == element) - return true; - } - return false; - } - - template <typename T> - bool contains(const vector<T>& src, const T& element, std::function<bool(const T&, const T&)> cmp) - { - for (size_t i = 0; i < src.size(); i++) - { - if (cmp(src[i], element)) - return true; - } - return false; - } - - template <typename T> - azgra::i64 get_index(const vector<T>& src, const T& element) - { - for (size_t i = 0; i < src.size(); i++) - { - if (src[i] == element) - return i; - } - return -1; - } - - template <typename T> - size_t count(const vector<T>& src, const T& element) - { - size_t result = 0; - for (size_t i = 0; i < src.size(); i++) - { - if (src[i] == element) - result += 1; - } - return result; - } - - template <typename T> - vector<T> except(const vector<T>& src, const vector<T>& except) - { - vector<T> result; - for (size_t i = 0; i < src.size(); i++) - { - if (!contains(except, src[i])) - { - result.push_back(src[i]); - } - } - return result; - } - - template <typename T> - T min(const vector<T>& src) - { - T min = std::numeric_limits<T>::max(); - - for (size_t i = 0; i < src.size(); i++) - { - if (src[i] < min) - { - min = src[i]; - } - } - return min; - } - - template <typename T> - T max(const vector<T>& src) - { - T max = std::numeric_limits<T>::min(); - - for (size_t i = 0; i < src.size(); i++) - { - if (src[i] > max) - { - max = src[i]; - } - } - return max; - } - - template <typename T> - std::pair<T, T> min_max(const vector<T>& src) - { - T min = std::numeric_limits<T>::max(); - T max = std::numeric_limits<T>::min(); - - for (size_t i = 0; i < src.size(); i++) - { - if (src[i] < min) - { - min = src[i]; - } - if (src[i] > max) - { - max = src[i]; - } - } - return make_pair(min, max); - } - - template <typename T> - T sum(const vector<T>& src) - { - T sum = 0; - for (size_t i = 0; i < src.size(); i++) - { - sum += src[i]; - } - return sum; - } - - template <typename TargetType, typename SourceType> - vector<TargetType> cast_to(const std::vector<SourceType>& src) - { - std::vector<TargetType> result(src.size()); - for (size_t i = 0; i < src.size(); i++) - { - result[i] = static_cast<TargetType>(src[i]); - } - return result; - } - - template <typename T> - std::vector<size_t> add_together(std::vector<T>& result, const std::vector<T>& add) - { - always_assert(result.size() == add.size()); - for (size_t i = 0; i < result.size(); i++) - { - result[i] += add[i]; - } - return result; - } - - template <typename T> - std::vector<size_t> div_by(std::vector<T>& result, size_t div) - { - - for (size_t i = 0; i < result.size(); i++) - { - result[i] = result[i] / div; - } - return result; - } - - template <typename T> - bool equals(const std::vector<T>& a, const std::vector<T>& b) - { - if (a.size() != b.size()) - return false; - - for (size_t i = 0; i < a.size(); i++) - { - if (a[i] != b[i]) - return false; - } - - return true; - } - - template <typename T> - bool equals_memcmp(const std::vector<T>& a, const std::vector<T>& b) - { - if (a.size() != b.size()) - return false; - - return (memcmp(a.data(), b.data(), sizeof(T) * a.size()) == 0); - } - - }; // namespace linq + namespace linq + { + template<typename T> + vector<T> where_ref(const std::vector<T> &src, function<bool(const T &)> predicate) + { + vector<T> result; + for (size_t i = 0; i < src.size(); i++) + { + if (predicate(src[i])) + { + result.push_back(src[i]); + } + } + return result; + } + + template<typename T> + vector<T> where(const std::vector<T> &src, function<bool(const T)> predicate) + { + vector<T> result; + for (size_t i = 0; i < src.size(); i++) + { + if (predicate(src[i])) + { + result.push_back(src[i]); + } + } + return result; + } + + template<typename T> + bool any_ref(const vector<T> &src, function<bool(const T &)> predicate) + { + for (size_t i = 0; i < src.size(); i++) + { + if (predicate(src[i])) + { + return true; + } + } + return false; + } + + template<typename T> + bool any(const vector<T> &src, function<bool(const T)> predicate) + { + for (size_t i = 0; i < src.size(); i++) + { + if (predicate(src[i])) + { + return true; + } + } + return false; + } + + template<typename T> + bool for_each_ref(const vector<T> &src, function<void(const T &)> work) + { + for (size_t i = 0; i < src.size(); i++) + { + work(src[i]); + } + } + + template<typename T> + bool for_each(const vector<T> &src, function<void(const T)> work) + { + for (size_t i = 0; i < src.size(); i++) + { + work(src[i]); + } + } + + template<typename T> + bool all_ref(const vector<T> &src, function<bool(const T &)> predicate) + { + for (size_t i = 0; i < src.size(); i++) + { + if (!predicate(src[i])) + { + return false; + } + } + return true; + } + + template<typename T> + bool all(const vector<T> &src, function<bool(const T)> predicate) + { + for (size_t i = 0; i < src.size(); i++) + { + if (!predicate(src[i])) + { + return false; + } + } + return true; + } + + template<typename T> + bool contains(const vector<T> &src, const T &element) + { + for (size_t i = 0; i < src.size(); i++) + { + if (src[i] == element) + return true; + } + return false; + } + + template<typename T> + bool contains(const vector<T> &src, const T &element, std::function<bool(const T &, const T &)> cmp) + { + for (size_t i = 0; i < src.size(); i++) + { + if (cmp(src[i], element)) + return true; + } + return false; + } + + template<typename T> + azgra::i64 get_index(const vector<T> &src, const T &element) + { + for (size_t i = 0; i < src.size(); i++) + { + if (src[i] == element) + return i; + } + return -1; + } + + template<typename T> + size_t count(const vector<T> &src, const T &element) + { + size_t result = 0; + for (size_t i = 0; i < src.size(); i++) + { + if (src[i] == element) + result += 1; + } + return result; + } + + template<typename T> + vector<T> except(const vector<T> &src, const vector<T> &except) + { + vector<T> result; + for (size_t i = 0; i < src.size(); i++) + { + if (!contains(except, src[i])) + { + result.push_back(src[i]); + } + } + return result; + } + + template<typename T> + T min(const vector<T> &src) + { + T min = std::numeric_limits<T>::max(); + + for (size_t i = 0; i < src.size(); i++) + { + if (src[i] < min) + { + min = src[i]; + } + } + return min; + } + + template<typename T> + T max(const vector<T> &src) + { + T max = std::numeric_limits<T>::min(); + + for (size_t i = 0; i < src.size(); i++) + { + if (src[i] > max) + { + max = src[i]; + } + } + return max; + } + + template<typename T> + std::pair<T, T> min_max(const vector<T> &src) + { + T min = std::numeric_limits<T>::max(); + T max = std::numeric_limits<T>::min(); + + for (size_t i = 0; i < src.size(); i++) + { + if (src[i] < min) + { + min = src[i]; + } + if (src[i] > max) + { + max = src[i]; + } + } + return make_pair(min, max); + } + + template<typename T> + T sum(const vector<T> &src) + { + T sum = 0; + for (size_t i = 0; i < src.size(); i++) + { + sum += src[i]; + } + return sum; + } + + template<typename TargetType, typename SourceType> + vector<TargetType> cast_to(const std::vector<SourceType> &src) + { + std::vector<TargetType> result(src.size()); + for (size_t i = 0; i < src.size(); i++) + { + result[i] = static_cast<TargetType>(src[i]); + } + return result; + } + + template<typename T> + std::vector<size_t> add_together(std::vector<T> &result, const std::vector<T> &add) + { + always_assert(result.size() == add.size()); + for (size_t i = 0; i < result.size(); i++) + { + result[i] += add[i]; + } + return result; + } + + template<typename T> + std::vector<size_t> div_by(std::vector<T> &result, size_t div) + { + + for (size_t i = 0; i < result.size(); i++) + { + result[i] = result[i] / div; + } + return result; + } + + template<typename T> + bool equals(const std::vector<T> &a, const std::vector<T> &b) + { + if (a.size() != b.size()) + return false; + + for (size_t i = 0; i < a.size(); i++) + { + if (a[i] != b[i]) + return false; + } + + return true; + } + + template<typename T> + bool equals_memcmp(const std::vector<T> &a, const std::vector<T> &b) + { + if (a.size() != b.size()) + return false; + + return (memcmp(a.data(), b.data(), sizeof(T) * a.size()) == 0); + } + + }; // namespace linq }; // namespace azgra \ No newline at end of file diff --git a/czi/source_code/azgra_lib/src/cli/cli_arguments.cpp b/czi/source_code/azgra_lib/src/cli/cli_arguments.cpp index 63e78d9..ec20bc3 100644 --- a/czi/source_code/azgra_lib/src/cli/cli_arguments.cpp +++ b/czi/source_code/azgra_lib/src/cli/cli_arguments.cpp @@ -21,6 +21,7 @@ namespace azgra if ((!shortMatch && flag->matchString == match.substring(2)) || (shortMatch && flag->hasMatchCharacter && flag->matchCharacter == match[1])) { + flag->mark_as_matched(); matchedFlag = flag; break; } @@ -38,7 +39,6 @@ namespace azgra if (normalFlag) { //fprintf(stdout, "Captured normal flag: %s\n", normalFlag->name.data()); - normalFlag->isMatched = true; return true; } else @@ -53,7 +53,6 @@ namespace azgra bool CliArguments::process_matched_value_flag(CliOption *matchedFlag, const char *rawFlagValue) { // Supported types: int, uint, string, float - matchedFlag->isMatched = true; auto *intFlag = dynamic_cast<CliValueFlag<int> *> (matchedFlag); if (intFlag) { @@ -108,7 +107,7 @@ namespace azgra if (method->name == match) { //fprintf(stdout, "Captured method: %s\n", method->name.data()); - method->isMatched = true; + method->mark_as_matched(); someMethodMatched = true; return true; } @@ -128,7 +127,7 @@ namespace azgra if (flag->hasMatchCharacter && flag->matchCharacter == match[i]) { found = true; - flag->isMatched = true; + flag->mark_as_matched(); //fprintf(stdout, "MultiFlag: matched flag: %s\n", flag->name.data()); break; } @@ -140,29 +139,21 @@ namespace azgra std::vector<CliOption *> CliArguments::get_flags_not_in_group() const { - std::vector<CliOption *> result; - for (CliOption *flag : flags) - { - if (!flag->isInGroup) - { - result.push_back(flag); - } - } + std::function<bool(CliOption *)> filter = [](CliOption *flag) + { return !(flag->is_grouped()); }; + + std::vector<CliOption *> result = linq::where(flags, filter); return result; } void CliArguments::print_flags(std::stringstream &outStream, const std::vector<CliOption *> &flagsToPrint) const { - - for (const CliOption *flag : flagsToPrint) { const char *req = (flag->isRequired ? "Required" : "Optional"); - - if (flag->hasMatchCharacter) { - char matchCharString[2] = {flag->matchCharacter, '\0'}; + char matchCharString[2] = {flag->matchCharacter, '\0'}; string::SimpleString matcherString({"{-", matchCharString, ", --", flag->matchString.data(), "} "}); @@ -214,9 +205,31 @@ namespace azgra } } + void CliArguments::mark_required_groups() + { + function<bool(CliOption *)> filter_isGroup = [](CliOption *option) + { + CliFlagGroup *group = dynamic_cast<CliFlagGroup *>(option); + return (group != nullptr); + }; + function<void(CliOption *)> work_markAsRequired = [](CliOption *option) + { + option->isRequired = true; + }; + + for (const CliMethod *method : methods) + { + auto methodRequiredFlags = method->get_required_flags(); + auto requiredGroups = linq::where(methodRequiredFlags, filter_isGroup); + linq::for_each(requiredGroups, work_markAsRequired); + } + } bool CliArguments::parse(const int argc, const char **argv) { + // Lest mark required groups from methods. + mark_required_groups(); + const char *helpIdentifier = "--help"; const char *helpIdentifierShort = "-h"; @@ -288,6 +301,10 @@ namespace azgra } } + for (const CliFlagGroup &group : groups) + { + result &= group.check_group_policy(errorStream); + } return result; } @@ -337,7 +354,7 @@ namespace azgra helpStream << "Flags:" << std::endl; print_flags(helpStream, get_flags_not_in_group()); - for (const CliOptionGroup &group : groups) + for (const CliFlagGroup &group : groups) { helpStream << "Flags - " << group.name.string_view() << std::endl; print_flags(helpStream, group.options); @@ -357,12 +374,12 @@ namespace azgra return someMethodMatched; } - void CliArguments::add_group(const CliOptionGroup &flagGroup) + void CliArguments::add_group(CliFlagGroup &flagGroup) { groups.push_back(flagGroup); for (CliOption *flag : flagGroup.options) { - flag->isInGroup = true; + flag->group = &flagGroup; flags.push_back(flag); } } -- GitLab