Skip to content

Commit 75dcdfe

Browse files
biefanrlundeen2Copilot
authored
FIX: Support quoted initializer paths in shell (#2093)
Co-authored-by: Richard Lundeen <rlundeen@microsoft.com> Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent c299739 commit 75dcdfe

2 files changed

Lines changed: 109 additions & 1 deletion

File tree

pyrit/cli/pyrit_shell.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
import concurrent.futures
1616
import contextlib
1717
import logging
18+
import os
19+
import shlex
1820
import sys
1921
import threading
2022
from pathlib import Path
@@ -28,6 +30,38 @@
2830
_T = TypeVar("_T")
2931

3032

33+
def _split_initializer_paths(arg: str) -> list[str]:
34+
"""
35+
Split a command-line argument string into individual file paths.
36+
37+
Supports quoting paths that contain spaces. On Windows, backslashes are treated
38+
as literal path separators (not escape characters) so that unquoted paths such as
39+
``C:\\Users\\me\\init.py`` are preserved; surrounding quotes are stripped from each
40+
token. On POSIX systems, standard ``shlex`` parsing is used.
41+
42+
Args:
43+
arg: The raw argument string passed to the ``add-initializer`` command.
44+
45+
Returns:
46+
The list of individual file path strings parsed from ``arg``.
47+
48+
Raises:
49+
ValueError: If the argument contains unbalanced quotes.
50+
"""
51+
if os.name == "nt":
52+
lexer = shlex.shlex(arg, posix=False)
53+
lexer.whitespace_split = True
54+
tokens = list(lexer)
55+
return [_strip_surrounding_quotes(token) for token in tokens]
56+
return shlex.split(arg)
57+
58+
59+
def _strip_surrounding_quotes(token: str) -> str:
60+
if len(token) >= 2 and token[0] == token[-1] and token[0] in ("'", '"'):
61+
return token[1:-1]
62+
return token
63+
64+
3165
class PyRITShell(cmd.Cmd):
3266
"""
3367
Interactive shell for PyRIT (thin REST client).
@@ -249,7 +283,13 @@ def do_add_initializer(self, arg: str) -> None:
249283

250284
from pyrit.cli.api_client import ServerNotAvailableError
251285

252-
for script_path_str in arg.split():
286+
try:
287+
script_path_strings = _split_initializer_paths(arg)
288+
except ValueError as exc:
289+
print(f"Error parsing initializer paths: {exc}")
290+
return
291+
292+
for script_path_str in script_path_strings:
253293
script_path = Path(script_path_str).resolve()
254294
if not script_path.exists():
255295
print(f"Error: File not found: {script_path}")

tests/unit/cli/test_pyrit_shell.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -358,6 +358,45 @@ def test_success_path(self, shell, tmp_path, capsys):
358358
assert "Registered initializer 'my_init'" in capsys.readouterr().out
359359
client.register_initializer_async.assert_awaited_once()
360360

361+
def test_success_with_quoted_path_containing_spaces(self, shell, tmp_path, capsys):
362+
s, client = shell
363+
script_dir = tmp_path / "initializer scripts"
364+
script_dir.mkdir()
365+
script = script_dir / "my_init.py"
366+
script.write_text("def init(): pass")
367+
client.register_initializer_async = AsyncMock(return_value={"status": "ok"})
368+
369+
s.do_add_initializer(f'"{script}"')
370+
371+
assert "Registered initializer 'my_init'" in capsys.readouterr().out
372+
client.register_initializer_async.assert_awaited_once_with(name="my_init", script_content="def init(): pass")
373+
374+
def test_malformed_path_quote(self, shell, capsys):
375+
s, client = shell
376+
client.register_initializer_async = AsyncMock(return_value={"status": "ok"})
377+
378+
s.do_add_initializer('"unterminated')
379+
380+
assert "Error parsing initializer paths" in capsys.readouterr().out
381+
client.register_initializer_async.assert_not_called()
382+
383+
def test_success_with_multiple_quoted_paths(self, shell, tmp_path, capsys):
384+
s, client = shell
385+
script_dir = tmp_path / "initializer scripts"
386+
script_dir.mkdir()
387+
first = script_dir / "first_init.py"
388+
second = script_dir / "second_init.py"
389+
first.write_text("def init(): pass")
390+
second.write_text("def init(): pass")
391+
client.register_initializer_async = AsyncMock(return_value={"status": "ok"})
392+
393+
s.do_add_initializer(f'"{first}" "{second}"')
394+
395+
out = capsys.readouterr().out
396+
assert "Registered initializer 'first_init'" in out
397+
assert "Registered initializer 'second_init'" in out
398+
assert client.register_initializer_async.await_count == 2
399+
361400
def test_server_not_available_error(self, shell, tmp_path, capsys):
362401
from pyrit.cli.api_client import ServerNotAvailableError
363402

@@ -782,3 +821,32 @@ def test_shell_choices_rejected_before_request(self, shell, capsys):
782821
# do_run surfaces these as "Error: ...".
783822
assert "Error" in out
784823
client.start_scenario_run_async.assert_not_called()
824+
825+
826+
class TestSplitInitializerPaths:
827+
def test_posix_splits_on_whitespace(self):
828+
with patch.object(pyrit_shell.os, "name", "posix"):
829+
assert pyrit_shell._split_initializer_paths("/a/one.py /b/two.py") == ["/a/one.py", "/b/two.py"]
830+
831+
def test_posix_respects_quotes_with_spaces(self):
832+
with patch.object(pyrit_shell.os, "name", "posix"):
833+
assert pyrit_shell._split_initializer_paths('"/a b/one.py"') == ["/a b/one.py"]
834+
835+
def test_windows_preserves_unquoted_backslash_path(self):
836+
with patch.object(pyrit_shell.os, "name", "nt"):
837+
assert pyrit_shell._split_initializer_paths(r"C:\Users\me\init.py") == [r"C:\Users\me\init.py"]
838+
839+
def test_windows_quoted_path_with_spaces_strips_quotes(self):
840+
with patch.object(pyrit_shell.os, "name", "nt"):
841+
assert pyrit_shell._split_initializer_paths(r'"C:\a b\one.py"') == [r"C:\a b\one.py"]
842+
843+
def test_windows_multiple_paths(self):
844+
with patch.object(pyrit_shell.os, "name", "nt"):
845+
result = pyrit_shell._split_initializer_paths(r'"C:\a b\one.py" C:\c\two.py')
846+
assert result == [r"C:\a b\one.py", r"C:\c\two.py"]
847+
848+
@pytest.mark.parametrize("os_name", ["posix", "nt"])
849+
def test_unterminated_quote_raises(self, os_name):
850+
with patch.object(pyrit_shell.os, "name", os_name):
851+
with pytest.raises(ValueError):
852+
pyrit_shell._split_initializer_paths('"unterminated')

0 commit comments

Comments
 (0)