Skip to content

Commit b48085b

Browse files
biefanCopilotromanlutz
authored
FIX: Coerce typed scenario parameters in scan CLI (#2092)
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Co-authored-by: Roman Lutz <romanlutz13@gmail.com>
1 parent 202fc13 commit b48085b

2 files changed

Lines changed: 138 additions & 37 deletions

File tree

pyrit/cli/pyrit_scan.py

Lines changed: 66 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,22 @@
1717
import sys
1818
from argparse import ArgumentParser, Namespace, RawDescriptionHelpFormatter
1919
from pathlib import Path
20-
from typing import Any
20+
from typing import TYPE_CHECKING, Any, get_args, get_origin
2121

2222
from pyrit.cli._cli_args import (
2323
ARG_HELP,
2424
_parse_initializer_arg,
25+
build_parameters_from_api,
2526
non_negative_int,
2627
positive_int,
2728
validate_log_level_argparse,
2829
)
2930

31+
if TYPE_CHECKING:
32+
from collections.abc import Callable
33+
34+
from pyrit.models.parameter import Parameter
35+
3036
_TERMINAL_STATUSES = {"COMPLETED", "FAILED", "CANCELLED"}
3137

3238

@@ -248,70 +254,94 @@ def _build_base_parser(*, add_help: bool = True) -> ArgumentParser:
248254
_SCENARIO_DEST_PREFIX = "scenario__"
249255

250256

251-
_SCALAR_TYPE_COERCERS: dict[str, Any] = {
252-
"int": int,
253-
"float": float,
254-
"bool": lambda v: str(v).strip().lower() in ("1", "true", "yes", "y", "on"),
255-
"str": str,
256-
}
257+
def _scenario_value_coercer(*, name: str, annotation: Any) -> Callable[[Any], Any] | None:
258+
"""
259+
Build an argparse ``type=`` callable that coerces a single CLI token through
260+
``Parameter.coerce_value`` — the same coercion the shell and backend use.
257261
262+
Returns ``None`` when no coercion is needed (a plain ``str`` or an untyped
263+
passthrough). Coercion/validation failures (including ``Literal`` choice
264+
membership) are re-raised as ``argparse.ArgumentTypeError`` so argparse renders
265+
them as a clean CLI error.
266+
267+
Args:
268+
name: Scenario parameter name (used for the flag in error messages).
269+
annotation: Scalar element type to coerce to (e.g. ``int``, ``bool``, or
270+
``Literal[...]`` for choices), or ``None`` / ``str`` for passthrough.
258271
259-
def _scenario_param_kwargs(*, param: dict[str, Any]) -> dict[str, Any]:
272+
Returns:
273+
Callable[[Any], Any] | None: The coercer, or ``None`` for passthrough.
260274
"""
261-
Build argparse ``add_argument`` kwargs for a scenario-declared parameter dict.
275+
if annotation is None or annotation is str:
276+
return None
262277

263-
Uses ``param_type``, ``is_list`` and ``choices`` from the catalog payload
264-
so list params accept ``nargs='+'`` and scalar params get client-side
265-
type coercion and choice validation.
278+
from pyrit.models.parameter import Parameter
279+
280+
element_param = Parameter(name=name, description="", param_type=annotation)
281+
282+
def _coerce(raw: Any) -> Any:
283+
try:
284+
return element_param.coerce_value(raw)
285+
except (ValueError, TypeError) as exc:
286+
raise argparse.ArgumentTypeError(f"--{name.replace('_', '-')}: invalid value {raw!r} ({exc})") from exc
287+
288+
return _coerce
289+
290+
291+
def _scenario_param_kwargs(*, parameter: Parameter) -> dict[str, Any]:
292+
"""
293+
Build argparse ``add_argument`` kwargs for a scenario-declared ``Parameter``.
294+
295+
List params get ``nargs='+'`` and coerce per element; scalar params coerce the
296+
single token. All coercion — including ``Literal`` choice membership — routes
297+
through ``Parameter.coerce_value`` so scan, the shell, and the backend agree on
298+
accepted values.
266299
267300
Args:
268-
param: Single entry from ``RegisteredScenario.supported_parameters``.
301+
parameter: Scenario parameter built from the catalog payload via
302+
``build_parameters_from_api``.
269303
270304
Returns:
271305
dict[str, Any]: kwargs ready to pass to ``ArgumentParser.add_argument``.
272306
"""
273307
kwargs: dict[str, Any] = {
274-
"dest": f"{_SCENARIO_DEST_PREFIX}{param.get('name', '')}",
308+
"dest": f"{_SCENARIO_DEST_PREFIX}{parameter.name}",
275309
"default": argparse.SUPPRESS,
276-
"help": param.get("description", ""),
310+
"help": parameter.description,
277311
}
278-
if param.get("is_list"):
312+
param_type = parameter.param_type
313+
element_type: Any
314+
if get_origin(param_type) is list:
315+
type_args = get_args(param_type)
316+
element_type = type_args[0] if type_args else str
279317
kwargs["nargs"] = "+"
280318
else:
281-
coercer = _SCALAR_TYPE_COERCERS.get(param.get("param_type", ""))
282-
if coercer is not None and coercer is not str:
283-
param_name = param.get("name", "")
284-
285-
def _typed(raw: str) -> Any:
286-
try:
287-
return coercer(raw)
288-
except (ValueError, TypeError) as exc:
289-
raise argparse.ArgumentTypeError(
290-
f"--{param_name.replace('_', '-')}: invalid value {raw!r} ({exc})"
291-
) from exc
292-
293-
kwargs["type"] = _typed
294-
choices = param.get("choices")
295-
if choices:
296-
kwargs["choices"] = list(choices)
319+
element_type = param_type
320+
321+
coercer = _scenario_value_coercer(name=parameter.name, annotation=element_type)
322+
if coercer is not None:
323+
kwargs["type"] = coercer
297324
return kwargs
298325

299326

300327
def _add_scenario_params_from_api(*, parser: ArgumentParser, params: list[dict[str, Any]]) -> None:
301328
"""
302329
Add scenario-declared parameters (from the API response) as CLI flags.
303330
331+
Catalog payloads are converted to ``Parameter`` objects via
332+
``build_parameters_from_api`` (shared with the shell) so type coercion and
333+
choice handling stay consistent across entry points.
334+
304335
Args:
305336
parser: Parser to extend.
306337
params: List of parameter dicts from ``GET /api/scenarios/catalog/{name}``.
307338
"""
308339
seen_flags: set[str] = set(parser._option_string_actions.keys())
309-
for p in params:
310-
name = p.get("name", "")
311-
flag = f"--{name.replace('_', '-')}"
340+
for parameter in build_parameters_from_api(api_params=params) or []:
341+
flag = f"--{parameter.name.replace('_', '-')}"
312342
if flag in seen_flags:
313343
continue
314-
parser.add_argument(flag, **_scenario_param_kwargs(param=p))
344+
parser.add_argument(flag, **_scenario_param_kwargs(parameter=parameter))
315345
seen_flags.add(flag)
316346

317347

tests/unit/cli/test_pyrit_scan.py

Lines changed: 72 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -634,6 +634,56 @@ def test_int_param_invalid_value_rejected_client_side(self, capsys):
634634
parser.parse_args(["--max-turns", "not-an-int"])
635635
assert "invalid value" in capsys.readouterr().err
636636

637+
def test_bool_param_rejects_invalid_value_client_side(self, capsys):
638+
from argparse import ArgumentParser
639+
640+
parser = ArgumentParser()
641+
pyrit_scan._add_scenario_params_from_api(
642+
parser=parser,
643+
params=[{"name": "dry_run", "description": "...", "param_type": "bool"}],
644+
)
645+
646+
parsed = parser.parse_args(["--dry-run", "false"])
647+
assert parsed.scenario__dry_run is False
648+
649+
parsed = parser.parse_args(["--dry-run", "yes"])
650+
assert parsed.scenario__dry_run is True
651+
652+
with pytest.raises(SystemExit):
653+
parser.parse_args(["--dry-run", "maybe"])
654+
assert "invalid value" in capsys.readouterr().err
655+
656+
# "on"/"y" are NOT part of the canonical boolean vocabulary the shell and
657+
# backend accept (true/false, 1/0, yes/no); scan must reject them too so
658+
# the same flag never behaves differently depending on the entry point.
659+
with pytest.raises(SystemExit):
660+
parser.parse_args(["--dry-run", "on"])
661+
assert "invalid value" in capsys.readouterr().err
662+
663+
def test_list_int_param_coerces_each_value(self):
664+
from argparse import ArgumentParser
665+
666+
parser = ArgumentParser()
667+
pyrit_scan._add_scenario_params_from_api(
668+
parser=parser,
669+
params=[{"name": "sample_ids", "description": "...", "param_type": "list[int]", "is_list": True}],
670+
)
671+
672+
parsed = parser.parse_args(["--sample-ids", "1", "2", "3"])
673+
assert parsed.scenario__sample_ids == [1, 2, 3]
674+
675+
def test_typed_choices_are_compared_after_coercion(self):
676+
from argparse import ArgumentParser
677+
678+
parser = ArgumentParser()
679+
pyrit_scan._add_scenario_params_from_api(
680+
parser=parser,
681+
params=[{"name": "max_turns", "description": "...", "param_type": "int", "choices": ["1", "2"]}],
682+
)
683+
684+
parsed = parser.parse_args(["--max-turns", "1"])
685+
assert parsed.scenario__max_turns == 1
686+
637687
def test_choices_validated_client_side(self, capsys):
638688
from argparse import ArgumentParser
639689

@@ -647,7 +697,7 @@ def test_choices_validated_client_side(self, capsys):
647697

648698
with pytest.raises(SystemExit):
649699
parser.parse_args(["--mode", "warp"])
650-
assert "invalid choice" in capsys.readouterr().err
700+
assert "invalid value" in capsys.readouterr().err
651701

652702

653703
class TestMainExtraPaths:
@@ -810,6 +860,27 @@ def test_scenario_declared_flag_is_forwarded(self, _mock_prog, _mock_print, mock
810860
sent_request = client.start_scenario_run_async.call_args.kwargs["request"]
811861
assert sent_request["scenario_params"] == {"max_turns": "7"}
812862

863+
@patch("pyrit.cli._server_launcher.ServerLauncher.probe_health_async", new_callable=AsyncMock, return_value=True)
864+
@patch("pyrit.cli.api_client.PyRITApiClient")
865+
@patch("pyrit.cli._output.print_scenario_result_async", new_callable=AsyncMock)
866+
@patch("pyrit.cli._output.print_scenario_run_progress")
867+
def test_typed_scenario_flags_are_forwarded_as_typed_values(
868+
self, _mock_prog, _mock_print, mock_client_class, _mock_probe
869+
):
870+
client = self._build_mock_client(
871+
supported_params=[
872+
{"name": "dry_run", "description": "...", "param_type": "bool"},
873+
{"name": "sample_ids", "description": "...", "param_type": "list[int]", "is_list": True},
874+
]
875+
)
876+
mock_client_class.return_value = client
877+
878+
result = pyrit_scan.main(["foo", "--target", "t", "--dry-run", "yes", "--sample-ids", "1", "2"])
879+
880+
assert result == 0
881+
sent_request = client.start_scenario_run_async.call_args.kwargs["request"]
882+
assert sent_request["scenario_params"] == {"dry_run": True, "sample_ids": [1, 2]}
883+
813884
@patch("pyrit.cli._server_launcher.ServerLauncher.probe_health_async", new_callable=AsyncMock, return_value=True)
814885
@patch("pyrit.cli.api_client.PyRITApiClient")
815886
@patch("pyrit.cli._output.print_scenario_result_async", new_callable=AsyncMock)

0 commit comments

Comments
 (0)