|
17 | 17 | import sys |
18 | 18 | from argparse import ArgumentParser, Namespace, RawDescriptionHelpFormatter |
19 | 19 | from pathlib import Path |
20 | | -from typing import Any |
| 20 | +from typing import TYPE_CHECKING, Any, get_args, get_origin |
21 | 21 |
|
22 | 22 | from pyrit.cli._cli_args import ( |
23 | 23 | ARG_HELP, |
24 | 24 | _parse_initializer_arg, |
| 25 | + build_parameters_from_api, |
25 | 26 | non_negative_int, |
26 | 27 | positive_int, |
27 | 28 | validate_log_level_argparse, |
28 | 29 | ) |
29 | 30 |
|
| 31 | +if TYPE_CHECKING: |
| 32 | + from collections.abc import Callable |
| 33 | + |
| 34 | + from pyrit.models.parameter import Parameter |
| 35 | + |
30 | 36 | _TERMINAL_STATUSES = {"COMPLETED", "FAILED", "CANCELLED"} |
31 | 37 |
|
32 | 38 |
|
@@ -248,70 +254,94 @@ def _build_base_parser(*, add_help: bool = True) -> ArgumentParser: |
248 | 254 | _SCENARIO_DEST_PREFIX = "scenario__" |
249 | 255 |
|
250 | 256 |
|
251 | | -_SCALAR_TYPE_COERCERS: dict[str, Any] = { |
252 | | - "int": int, |
253 | | - "float": float, |
254 | | - "bool": lambda v: str(v).strip().lower() in ("1", "true", "yes", "y", "on"), |
255 | | - "str": str, |
256 | | -} |
| 257 | +def _scenario_value_coercer(*, name: str, annotation: Any) -> Callable[[Any], Any] | None: |
| 258 | + """ |
| 259 | + Build an argparse ``type=`` callable that coerces a single CLI token through |
| 260 | + ``Parameter.coerce_value`` — the same coercion the shell and backend use. |
257 | 261 |
|
| 262 | + Returns ``None`` when no coercion is needed (a plain ``str`` or an untyped |
| 263 | + passthrough). Coercion/validation failures (including ``Literal`` choice |
| 264 | + membership) are re-raised as ``argparse.ArgumentTypeError`` so argparse renders |
| 265 | + them as a clean CLI error. |
| 266 | +
|
| 267 | + Args: |
| 268 | + name: Scenario parameter name (used for the flag in error messages). |
| 269 | + annotation: Scalar element type to coerce to (e.g. ``int``, ``bool``, or |
| 270 | + ``Literal[...]`` for choices), or ``None`` / ``str`` for passthrough. |
258 | 271 |
|
259 | | -def _scenario_param_kwargs(*, param: dict[str, Any]) -> dict[str, Any]: |
| 272 | + Returns: |
| 273 | + Callable[[Any], Any] | None: The coercer, or ``None`` for passthrough. |
260 | 274 | """ |
261 | | - Build argparse ``add_argument`` kwargs for a scenario-declared parameter dict. |
| 275 | + if annotation is None or annotation is str: |
| 276 | + return None |
262 | 277 |
|
263 | | - Uses ``param_type``, ``is_list`` and ``choices`` from the catalog payload |
264 | | - so list params accept ``nargs='+'`` and scalar params get client-side |
265 | | - type coercion and choice validation. |
| 278 | + from pyrit.models.parameter import Parameter |
| 279 | + |
| 280 | + element_param = Parameter(name=name, description="", param_type=annotation) |
| 281 | + |
| 282 | + def _coerce(raw: Any) -> Any: |
| 283 | + try: |
| 284 | + return element_param.coerce_value(raw) |
| 285 | + except (ValueError, TypeError) as exc: |
| 286 | + raise argparse.ArgumentTypeError(f"--{name.replace('_', '-')}: invalid value {raw!r} ({exc})") from exc |
| 287 | + |
| 288 | + return _coerce |
| 289 | + |
| 290 | + |
| 291 | +def _scenario_param_kwargs(*, parameter: Parameter) -> dict[str, Any]: |
| 292 | + """ |
| 293 | + Build argparse ``add_argument`` kwargs for a scenario-declared ``Parameter``. |
| 294 | +
|
| 295 | + List params get ``nargs='+'`` and coerce per element; scalar params coerce the |
| 296 | + single token. All coercion — including ``Literal`` choice membership — routes |
| 297 | + through ``Parameter.coerce_value`` so scan, the shell, and the backend agree on |
| 298 | + accepted values. |
266 | 299 |
|
267 | 300 | Args: |
268 | | - param: Single entry from ``RegisteredScenario.supported_parameters``. |
| 301 | + parameter: Scenario parameter built from the catalog payload via |
| 302 | + ``build_parameters_from_api``. |
269 | 303 |
|
270 | 304 | Returns: |
271 | 305 | dict[str, Any]: kwargs ready to pass to ``ArgumentParser.add_argument``. |
272 | 306 | """ |
273 | 307 | kwargs: dict[str, Any] = { |
274 | | - "dest": f"{_SCENARIO_DEST_PREFIX}{param.get('name', '')}", |
| 308 | + "dest": f"{_SCENARIO_DEST_PREFIX}{parameter.name}", |
275 | 309 | "default": argparse.SUPPRESS, |
276 | | - "help": param.get("description", ""), |
| 310 | + "help": parameter.description, |
277 | 311 | } |
278 | | - if param.get("is_list"): |
| 312 | + param_type = parameter.param_type |
| 313 | + element_type: Any |
| 314 | + if get_origin(param_type) is list: |
| 315 | + type_args = get_args(param_type) |
| 316 | + element_type = type_args[0] if type_args else str |
279 | 317 | kwargs["nargs"] = "+" |
280 | 318 | else: |
281 | | - coercer = _SCALAR_TYPE_COERCERS.get(param.get("param_type", "")) |
282 | | - if coercer is not None and coercer is not str: |
283 | | - param_name = param.get("name", "") |
284 | | - |
285 | | - def _typed(raw: str) -> Any: |
286 | | - try: |
287 | | - return coercer(raw) |
288 | | - except (ValueError, TypeError) as exc: |
289 | | - raise argparse.ArgumentTypeError( |
290 | | - f"--{param_name.replace('_', '-')}: invalid value {raw!r} ({exc})" |
291 | | - ) from exc |
292 | | - |
293 | | - kwargs["type"] = _typed |
294 | | - choices = param.get("choices") |
295 | | - if choices: |
296 | | - kwargs["choices"] = list(choices) |
| 319 | + element_type = param_type |
| 320 | + |
| 321 | + coercer = _scenario_value_coercer(name=parameter.name, annotation=element_type) |
| 322 | + if coercer is not None: |
| 323 | + kwargs["type"] = coercer |
297 | 324 | return kwargs |
298 | 325 |
|
299 | 326 |
|
300 | 327 | def _add_scenario_params_from_api(*, parser: ArgumentParser, params: list[dict[str, Any]]) -> None: |
301 | 328 | """ |
302 | 329 | Add scenario-declared parameters (from the API response) as CLI flags. |
303 | 330 |
|
| 331 | + Catalog payloads are converted to ``Parameter`` objects via |
| 332 | + ``build_parameters_from_api`` (shared with the shell) so type coercion and |
| 333 | + choice handling stay consistent across entry points. |
| 334 | +
|
304 | 335 | Args: |
305 | 336 | parser: Parser to extend. |
306 | 337 | params: List of parameter dicts from ``GET /api/scenarios/catalog/{name}``. |
307 | 338 | """ |
308 | 339 | seen_flags: set[str] = set(parser._option_string_actions.keys()) |
309 | | - for p in params: |
310 | | - name = p.get("name", "") |
311 | | - flag = f"--{name.replace('_', '-')}" |
| 340 | + for parameter in build_parameters_from_api(api_params=params) or []: |
| 341 | + flag = f"--{parameter.name.replace('_', '-')}" |
312 | 342 | if flag in seen_flags: |
313 | 343 | continue |
314 | | - parser.add_argument(flag, **_scenario_param_kwargs(param=p)) |
| 344 | + parser.add_argument(flag, **_scenario_param_kwargs(parameter=parameter)) |
315 | 345 | seen_flags.add(flag) |
316 | 346 |
|
317 | 347 |
|
|
0 commit comments