Skip to content

Commit e1574b8

Browse files
committed
fix: correct MSJ call site and harden litellm rate-limit detection
Two related fixes uncovered during an audit of the msj_data fix (0944ac6), plus a pre-commit formatting/lint fix: 1. perform_many_shot_scan passed wrong type to msj_data.prepare_prompts - File: agentic_security/probe_actor/fuzzer.py - Bug: probe_datasets (list[dict], e.g. {"dataset_name": ..., "selected": ...}) was forwarded directly, but msj_data.prepare_prompts expects list[str]. - Effect: After 0944ac6 made prepare_prompts honor its dataset_names param, every MSJ multi-step scan silently loaded an empty dataset (the lookup `name in dataset_map` is always False when name is a dict). - Fix: extract the dataset_name strings and drop unselected entries, matching the existing data.prepare_prompts call a few lines above. - Test: add test_many_shot_passes_dataset_names_to_msj, which fails on the buggy code (asserts the mock receives ['probe-a'], not the raw dict list). 2. litellm rate-limit detection switched from string compare to isinstance - File: agentic_security/llm_providers/litellm_provider.py - Bug: _handle_error detected rate limits by comparing type(e).__module__ + __name__ to 'litellm.exceptions.RateLimitError'. Fragile (breaks on subclassing/module renames) and inconsistent with openai_provider.py and anthropic_provider.py, which both use isinstance. - Fix: use isinstance(e, litellm.exceptions.RateLimitError), guarded by `litellm is not None` since litellm is an optional import. - Test: replace the fabricated fake exception (monkeypatched __module__) with a real subclass of litellm.exceptions.RateLimitError so the isinstance path is genuinely exercised. 3. Pre-commit lint fixes (unblock CI on this branch) - Apply black formatting to fuzzer.py and test_fuzzer.py. - agentic_security/probe_data/image_generator.py: move `matplotlib.use("Agg")` and the `import matplotlib.pyplot as plt # noqa: E402` that depends on it below the other imports (cache_to_disk, tqdm, agentic_security.probe_data.models), so those imports no longer follow module-level code and only the pyplot import needs the existing noqa. This E402 also exists on main and was surfaced because the CI runs `pre-commit run --all-files`.
1 parent 0944ac6 commit e1574b8

5 files changed

Lines changed: 61 additions & 10 deletions

File tree

agentic_security/llm_providers/litellm_provider.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,7 @@ def _parse_response(self, response: Any) -> LLMResponse:
7979
)
8080

8181
def _handle_error(self, e: Exception) -> None:
82-
qualname = f"{type(e).__module__}.{type(e).__name__}"
83-
if qualname == "litellm.exceptions.RateLimitError":
82+
if litellm is not None and isinstance(e, litellm.exceptions.RateLimitError):
8483
raise LLMRateLimitError(str(e)) from e
8584
raise LLMProviderError(str(e)) from e
8685

agentic_security/probe_actor/fuzzer.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -536,7 +536,9 @@ async def perform_many_shot_scan(
536536
tools_inbox=tools_inbox,
537537
)
538538
yield ScanResult.status_msg("Loading datasets for MSJ...")
539-
msj_modules = msj_data.prepare_prompts(probe_datasets)
539+
msj_modules = msj_data.prepare_prompts(
540+
dataset_names=[m["dataset_name"] for m in probe_datasets if m.get("selected")]
541+
)
540542
yield ScanResult.status_msg("Datasets loaded. Starting scan...")
541543

542544
fuzzer_state = FuzzerState()

agentic_security/probe_data/image_generator.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,15 @@
44

55
import httpx
66
import matplotlib
7-
8-
matplotlib.use("Agg")
9-
import matplotlib.pyplot as plt # noqa: E402
107
from cache_to_disk import cache_to_disk
118
from tqdm import tqdm
129

1310
from agentic_security.probe_data.models import ImageProbeDataset, ProbeDataset
1411

12+
# matplotlib backend must be set before pyplot is imported.
13+
matplotlib.use("Agg")
14+
import matplotlib.pyplot as plt # noqa: E402
15+
1516

1617
def generate_image_dataset(
1718
text_dataset: list[ProbeDataset],

tests/unit/llm_providers/test_litellm_provider.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -209,11 +209,18 @@ def provider(self):
209209
return LiteLLMProvider()
210210

211211
def test_rate_limit_maps_to_llm_rate_limit_error(self, provider):
212-
fake_exc = type("RateLimitError", (Exception,), {})()
213-
fake_exc.__class__.__module__ = "litellm.exceptions"
214-
fake_exc.__class__.__qualname__ = "RateLimitError"
212+
import litellm.exceptions
213+
214+
# Subclass the *real* litellm.exceptions.RateLimitError so the
215+
# isinstance() check in _handle_error is exercised (rather than a
216+
# string compare on __module__/__name__). The override bypasses the
217+
# openai parent constructor's verbose (response, body) signature.
218+
class _RealRateLimit(litellm.exceptions.RateLimitError):
219+
def __init__(self, message="rate limited"):
220+
Exception.__init__(self, message)
221+
215222
with pytest.raises(LLMRateLimitError):
216-
provider._handle_error(fake_exc)
223+
provider._handle_error(_RealRateLimit())
217224

218225
def test_generic_error_maps_to_llm_provider_error(self, provider):
219226
with pytest.raises(LLMProviderError):

tests/unit/probe_actor/test_fuzzer.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,48 @@ async def test_perform_many_shot_scan_probe_injection(
113113
await assert_scan(async_gen, ["Loading", "Scan completed."])
114114

115115

116+
@pytest.mark.asyncio
117+
@patch("agentic_security.probe_data.msj_data.prepare_prompts")
118+
@patch("agentic_security.probe_data.data.prepare_prompts")
119+
async def test_many_shot_passes_dataset_names_to_msj(
120+
prepare_prompts_mock, msj_prepare_prompts_mock
121+
):
122+
# Regression: msj_data.prepare_prompts expects dataset_names: list[str],
123+
# but probe_datasets is a list[dict]. perform_many_shot_scan must extract
124+
# the "dataset_name" strings and filter out unselected entries before
125+
# forwarding them. Previously it passed the raw dict list, which (after
126+
# msj_data.prepare_prompts was fixed) silently returned an empty result.
127+
prepare_prompts_mock.return_value = []
128+
msj_prepare_prompts_mock.return_value = []
129+
130+
request_factory = MagicMock()
131+
request_factory.fn = AsyncMock(
132+
return_value=AsyncMock(status_code=200, text="ok", json=lambda: {})
133+
)
134+
135+
async_gen = perform_many_shot_scan(
136+
request_factory=request_factory,
137+
max_budget=100,
138+
datasets=[{"dataset_name": "main", "selected": True}],
139+
probe_datasets=[
140+
{"dataset_name": "probe-a", "selected": True},
141+
{
142+
"dataset_name": "probe-b",
143+
"selected": False,
144+
}, # unselected -> filtered out
145+
],
146+
optimize=False,
147+
)
148+
await assert_scan(async_gen, ["Loading", "Scan completed."])
149+
150+
msj_prepare_prompts_mock.assert_called_once()
151+
args, kwargs = msj_prepare_prompts_mock.call_args
152+
dataset_names = kwargs.get("dataset_names") or (args[0] if args else None)
153+
assert dataset_names == [
154+
"probe-a"
155+
], f"Expected dataset_names=['probe-a'], got {dataset_names!r}"
156+
157+
116158
@pytest.mark.asyncio
117159
@patch("agentic_security.probe_data.data.prepare_prompts")
118160
async def test_scan_router_single_shot(prepare_prompts_mock):

0 commit comments

Comments
 (0)