|
| 1 | +#pragma once |
| 2 | + |
| 3 | +#include <algorithm> |
| 4 | +#include <charconv> |
| 5 | +#include <stdexcept> |
| 6 | +#include <string> |
| 7 | +#include <string_view> |
| 8 | +#include <system_error> |
| 9 | + |
| 10 | +namespace examples |
| 11 | +{ |
| 12 | +namespace cli |
| 13 | +{ |
| 14 | + |
| 15 | +inline std::string |
| 16 | +normalize_option( std::string option ) |
| 17 | +{ |
| 18 | + if( option.rfind( "--", 0 ) != 0 ) |
| 19 | + return option; |
| 20 | + |
| 21 | + const auto eq_pos = option.find( '=' ); |
| 22 | + const auto end = eq_pos == std::string::npos ? option.size() : eq_pos; |
| 23 | + std::replace( option.begin() + 2, option.begin() + static_cast<std::ptrdiff_t>( end ), '_', '-' ); |
| 24 | + return option; |
| 25 | +} |
| 26 | + |
| 27 | +inline bool |
| 28 | +is_positional( std::string_view arg ) |
| 29 | +{ |
| 30 | + return !arg.empty() && arg.front() != '-'; |
| 31 | +} |
| 32 | + |
| 33 | +inline int |
| 34 | +parse_int( const std::string& label, const std::string& value ) |
| 35 | +{ |
| 36 | + int result = 0; |
| 37 | + const char* begin = value.data(); |
| 38 | + const char* end = begin + value.size(); |
| 39 | + const auto [ptr, ec] = std::from_chars( begin, end, result ); |
| 40 | + if( ec != std::errc() || ptr != end ) |
| 41 | + throw std::invalid_argument( "Invalid value for " + label + ": '" + value + "'" ); |
| 42 | + return result; |
| 43 | +} |
| 44 | + |
| 45 | +class ArgParser |
| 46 | +{ |
| 47 | +public: |
| 48 | + ArgParser( int argc, char** argv ) : argc_( argc ), argv_( argv ) {} |
| 49 | + |
| 50 | + bool |
| 51 | + empty() const |
| 52 | + { |
| 53 | + return index_ >= argc_; |
| 54 | + } |
| 55 | + |
| 56 | + std::string_view |
| 57 | + peek() const |
| 58 | + { |
| 59 | + if( empty() ) |
| 60 | + return std::string_view{}; |
| 61 | + return argv_[index_]; |
| 62 | + } |
| 63 | + |
| 64 | + std::string |
| 65 | + take() |
| 66 | + { |
| 67 | + if( empty() ) |
| 68 | + throw std::out_of_range( "No more arguments to consume" ); |
| 69 | + return std::string( argv_[index_++] ); |
| 70 | + } |
| 71 | + |
| 72 | + bool |
| 73 | + consume_flag( std::string_view long_name, std::string_view short_name = {} ) |
| 74 | + { |
| 75 | + if( empty() ) |
| 76 | + return false; |
| 77 | + |
| 78 | + std::string current = normalize_option( std::string( peek() ) ); |
| 79 | + if( current == long_name || ( !short_name.empty() && current == short_name ) ) |
| 80 | + { |
| 81 | + ++index_; |
| 82 | + return true; |
| 83 | + } |
| 84 | + return false; |
| 85 | + } |
| 86 | + |
| 87 | + bool |
| 88 | + consume_option( std::string_view long_name, std::string& value ) |
| 89 | + { |
| 90 | + if( empty() ) |
| 91 | + return false; |
| 92 | + |
| 93 | + std::string current = normalize_option( std::string( peek() ) ); |
| 94 | + const std::string prefix = std::string( long_name ) + "="; |
| 95 | + if( current == long_name ) |
| 96 | + { |
| 97 | + ++index_; |
| 98 | + if( empty() ) |
| 99 | + throw std::invalid_argument( "Missing value for option '" + std::string( long_name ) + "'" ); |
| 100 | + value = take(); |
| 101 | + return true; |
| 102 | + } |
| 103 | + if( current.rfind( prefix, 0 ) == 0 ) |
| 104 | + { |
| 105 | + value = current.substr( prefix.size() ); |
| 106 | + ++index_; |
| 107 | + return true; |
| 108 | + } |
| 109 | + return false; |
| 110 | + } |
| 111 | + |
| 112 | +private: |
| 113 | + int argc_ = 0; |
| 114 | + char** argv_ = nullptr; |
| 115 | + int index_ = 1; |
| 116 | +}; |
| 117 | + |
| 118 | +} // namespace cli |
| 119 | +} // namespace examples |
| 120 | + |
| 121 | +namespace examples |
| 122 | +{ |
| 123 | +namespace cli |
| 124 | +{ |
| 125 | + |
| 126 | +struct SolverOptions |
| 127 | +{ |
| 128 | + bool show_help = false; |
| 129 | + std::string solver = "ilqr"; |
| 130 | +}; |
| 131 | + |
| 132 | +inline SolverOptions |
| 133 | +parse_solver_options( int argc, char** argv, std::string default_solver = "ilqr" ) |
| 134 | +{ |
| 135 | + SolverOptions options; |
| 136 | + options.solver = std::move( default_solver ); |
| 137 | + ArgParser args( argc, argv ); |
| 138 | + |
| 139 | + while( !args.empty() ) |
| 140 | + { |
| 141 | + const std::string raw_arg = std::string( args.peek() ); |
| 142 | + if( args.consume_flag( "--help", "-h" ) ) |
| 143 | + { |
| 144 | + options.show_help = true; |
| 145 | + continue; |
| 146 | + } |
| 147 | + |
| 148 | + std::string value; |
| 149 | + if( args.consume_option( "--solver", value ) ) |
| 150 | + { |
| 151 | + options.solver = value; |
| 152 | + continue; |
| 153 | + } |
| 154 | + |
| 155 | + throw std::invalid_argument( "Unknown argument '" + raw_arg + "'" ); |
| 156 | + } |
| 157 | + |
| 158 | + return options; |
| 159 | +} |
| 160 | + |
| 161 | +struct MultiAgentOptions |
| 162 | +{ |
| 163 | + bool show_help = false; |
| 164 | + int agents = 10; |
| 165 | + int max_outer = 10; |
| 166 | + std::string solver = "ilqr"; |
| 167 | + std::string strategy = "centralized"; |
| 168 | +}; |
| 169 | + |
| 170 | +inline MultiAgentOptions |
| 171 | +parse_multi_agent_options( int argc, char** argv, MultiAgentOptions defaults = {} ) |
| 172 | +{ |
| 173 | + MultiAgentOptions options = std::move( defaults ); |
| 174 | + ArgParser args( argc, argv ); |
| 175 | + bool positional_agents = false; |
| 176 | + |
| 177 | + while( !args.empty() ) |
| 178 | + { |
| 179 | + const std::string raw_arg = std::string( args.peek() ); |
| 180 | + if( args.consume_flag( "--help", "-h" ) ) |
| 181 | + { |
| 182 | + options.show_help = true; |
| 183 | + continue; |
| 184 | + } |
| 185 | + |
| 186 | + std::string value; |
| 187 | + if( args.consume_option( "--agents", value ) ) |
| 188 | + { |
| 189 | + options.agents = parse_int( "--agents", value ); |
| 190 | + continue; |
| 191 | + } |
| 192 | + if( args.consume_option( "--solver", value ) ) |
| 193 | + { |
| 194 | + options.solver = value; |
| 195 | + continue; |
| 196 | + } |
| 197 | + if( args.consume_option( "--strategy", value ) ) |
| 198 | + { |
| 199 | + options.strategy = value; |
| 200 | + continue; |
| 201 | + } |
| 202 | + if( args.consume_option( "--max-outer", value ) ) |
| 203 | + { |
| 204 | + options.max_outer = parse_int( "--max-outer", value ); |
| 205 | + continue; |
| 206 | + } |
| 207 | + |
| 208 | + if( is_positional( raw_arg ) && !positional_agents ) |
| 209 | + { |
| 210 | + args.take(); |
| 211 | + options.agents = parse_int( "agents", raw_arg ); |
| 212 | + positional_agents = true; |
| 213 | + continue; |
| 214 | + } |
| 215 | + |
| 216 | + throw std::invalid_argument( "Unknown argument '" + raw_arg + "'" ); |
| 217 | + } |
| 218 | + |
| 219 | + return options; |
| 220 | +} |
| 221 | + |
| 222 | +struct RocketOptions |
| 223 | +{ |
| 224 | + bool show_help = false; |
| 225 | + bool dump_traces = false; |
| 226 | + std::string solver = "osqp"; |
| 227 | +}; |
| 228 | + |
| 229 | +inline RocketOptions |
| 230 | +parse_rocket_options( int argc, char** argv, RocketOptions defaults = {} ) |
| 231 | +{ |
| 232 | + RocketOptions options = std::move( defaults ); |
| 233 | + ArgParser args( argc, argv ); |
| 234 | + |
| 235 | + while( !args.empty() ) |
| 236 | + { |
| 237 | + const std::string raw_arg = std::string( args.peek() ); |
| 238 | + if( args.consume_flag( "--help", "-h" ) ) |
| 239 | + { |
| 240 | + options.show_help = true; |
| 241 | + continue; |
| 242 | + } |
| 243 | + if( args.consume_flag( "--dump" ) ) |
| 244 | + { |
| 245 | + options.dump_traces = true; |
| 246 | + continue; |
| 247 | + } |
| 248 | + |
| 249 | + std::string value; |
| 250 | + if( args.consume_option( "--solver", value ) ) |
| 251 | + { |
| 252 | + options.solver = value; |
| 253 | + continue; |
| 254 | + } |
| 255 | + |
| 256 | + throw std::invalid_argument( "Unknown argument '" + raw_arg + "'" ); |
| 257 | + } |
| 258 | + |
| 259 | + return options; |
| 260 | +} |
| 261 | + |
| 262 | +} // namespace cli |
| 263 | +} // namespace examples |
| 264 | + |
0 commit comments