diff --git a/agentic_security/llm_providers/litellm_provider.py b/agentic_security/llm_providers/litellm_provider.py index 46f720af..0aa47d70 100644 --- a/agentic_security/llm_providers/litellm_provider.py +++ b/agentic_security/llm_providers/litellm_provider.py @@ -79,8 +79,7 @@ def _parse_response(self, response: Any) -> LLMResponse: ) def _handle_error(self, e: Exception) -> None: - qualname = f"{type(e).__module__}.{type(e).__name__}" - if qualname == "litellm.exceptions.RateLimitError": + if litellm is not None and isinstance(e, litellm.exceptions.RateLimitError): raise LLMRateLimitError(str(e)) from e raise LLMProviderError(str(e)) from e diff --git a/agentic_security/probe_actor/fuzzer.py b/agentic_security/probe_actor/fuzzer.py index 60b21357..823dc186 100644 --- a/agentic_security/probe_actor/fuzzer.py +++ b/agentic_security/probe_actor/fuzzer.py @@ -536,7 +536,9 @@ async def perform_many_shot_scan( tools_inbox=tools_inbox, ) yield ScanResult.status_msg("Loading datasets for MSJ...") - msj_modules = msj_data.prepare_prompts(probe_datasets) + msj_modules = msj_data.prepare_prompts( + dataset_names=[m["dataset_name"] for m in probe_datasets if m.get("selected")] + ) yield ScanResult.status_msg("Datasets loaded. Starting scan...") fuzzer_state = FuzzerState() diff --git a/agentic_security/probe_data/image_generator.py b/agentic_security/probe_data/image_generator.py index c417c816..3fb47dfe 100644 --- a/agentic_security/probe_data/image_generator.py +++ b/agentic_security/probe_data/image_generator.py @@ -4,14 +4,15 @@ import httpx import matplotlib - -matplotlib.use("Agg") -import matplotlib.pyplot as plt # noqa: E402 from cache_to_disk import cache_to_disk from tqdm import tqdm from agentic_security.probe_data.models import ImageProbeDataset, ProbeDataset +# matplotlib backend must be set before pyplot is imported. +matplotlib.use("Agg") +import matplotlib.pyplot as plt # noqa: E402 + def generate_image_dataset( text_dataset: list[ProbeDataset], diff --git a/agentic_security/probe_data/msj_data.py b/agentic_security/probe_data/msj_data.py index 79f787a2..8a520cb1 100644 --- a/agentic_security/probe_data/msj_data.py +++ b/agentic_security/probe_data/msj_data.py @@ -19,8 +19,10 @@ def load_dataset_generic(name, getter=lambda x: x["train"]["prompt"]): def prepare_prompts( - dataset_names=[], budget=-1, tools_inbox=None + dataset_names=None, budget=-1, tools_inbox=None ) -> list[ProbeDataset]: + if dataset_names is None: + dataset_names = [] # fka/awesome-chatgpt-prompts # data-is-better-together/10k_prompts_ranked # alespalla/chatbot_instruction_prompts @@ -32,4 +34,6 @@ def prepare_prompts( "fka/awesome-chatgpt-prompts" ), } - return [dataset_map[name] for name in dataset_map] + if not dataset_names: + return list(dataset_map.values()) + return [dataset_map[name] for name in dataset_names if name in dataset_map] diff --git a/agentic_security/probe_data/test_msj_data.py b/agentic_security/probe_data/test_msj_data.py index e81812b5..7f29d457 100644 --- a/agentic_security/probe_data/test_msj_data.py +++ b/agentic_security/probe_data/test_msj_data.py @@ -129,8 +129,7 @@ def test_dataset_contents(self, mock_load_dataset_generic): result = prepare_prompts( dataset_names=["data-is-better-together/10k_prompts_ranked"] ) - assert len(result) == 2 + assert len(result) == 1 assert all(isinstance(ds.prompts, list) for ds in result) assert all(isinstance(ds.metadata, dict) for ds in result) assert result[0].prompts == ["test prompt"] - assert result[1].prompts == ["another prompt"] diff --git a/tests/unit/llm_providers/test_litellm_provider.py b/tests/unit/llm_providers/test_litellm_provider.py index 4e947f10..194f07b4 100644 --- a/tests/unit/llm_providers/test_litellm_provider.py +++ b/tests/unit/llm_providers/test_litellm_provider.py @@ -209,11 +209,18 @@ def provider(self): return LiteLLMProvider() def test_rate_limit_maps_to_llm_rate_limit_error(self, provider): - fake_exc = type("RateLimitError", (Exception,), {})() - fake_exc.__class__.__module__ = "litellm.exceptions" - fake_exc.__class__.__qualname__ = "RateLimitError" + import litellm.exceptions + + # Subclass the *real* litellm.exceptions.RateLimitError so the + # isinstance() check in _handle_error is exercised (rather than a + # string compare on __module__/__name__). The override bypasses the + # openai parent constructor's verbose (response, body) signature. + class _RealRateLimit(litellm.exceptions.RateLimitError): + def __init__(self, message="rate limited"): + Exception.__init__(self, message) + with pytest.raises(LLMRateLimitError): - provider._handle_error(fake_exc) + provider._handle_error(_RealRateLimit()) def test_generic_error_maps_to_llm_provider_error(self, provider): with pytest.raises(LLMProviderError): diff --git a/tests/unit/probe_actor/test_fuzzer.py b/tests/unit/probe_actor/test_fuzzer.py index f81035d9..495f1efd 100644 --- a/tests/unit/probe_actor/test_fuzzer.py +++ b/tests/unit/probe_actor/test_fuzzer.py @@ -113,6 +113,48 @@ async def test_perform_many_shot_scan_probe_injection( await assert_scan(async_gen, ["Loading", "Scan completed."]) +@pytest.mark.asyncio +@patch("agentic_security.probe_data.msj_data.prepare_prompts") +@patch("agentic_security.probe_data.data.prepare_prompts") +async def test_many_shot_passes_dataset_names_to_msj( + prepare_prompts_mock, msj_prepare_prompts_mock +): + # Regression: msj_data.prepare_prompts expects dataset_names: list[str], + # but probe_datasets is a list[dict]. perform_many_shot_scan must extract + # the "dataset_name" strings and filter out unselected entries before + # forwarding them. Previously it passed the raw dict list, which (after + # msj_data.prepare_prompts was fixed) silently returned an empty result. + prepare_prompts_mock.return_value = [] + msj_prepare_prompts_mock.return_value = [] + + request_factory = MagicMock() + request_factory.fn = AsyncMock( + return_value=AsyncMock(status_code=200, text="ok", json=lambda: {}) + ) + + async_gen = perform_many_shot_scan( + request_factory=request_factory, + max_budget=100, + datasets=[{"dataset_name": "main", "selected": True}], + probe_datasets=[ + {"dataset_name": "probe-a", "selected": True}, + { + "dataset_name": "probe-b", + "selected": False, + }, # unselected -> filtered out + ], + optimize=False, + ) + await assert_scan(async_gen, ["Loading", "Scan completed."]) + + msj_prepare_prompts_mock.assert_called_once() + args, kwargs = msj_prepare_prompts_mock.call_args + dataset_names = kwargs.get("dataset_names") or (args[0] if args else None) + assert dataset_names == [ + "probe-a" + ], f"Expected dataset_names=['probe-a'], got {dataset_names!r}" + + @pytest.mark.asyncio @patch("agentic_security.probe_data.data.prepare_prompts") async def test_scan_router_single_shot(prepare_prompts_mock):