Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions agentic_security/llm_providers/litellm_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,7 @@ def _parse_response(self, response: Any) -> LLMResponse:
)

def _handle_error(self, e: Exception) -> None:
qualname = f"{type(e).__module__}.{type(e).__name__}"
if qualname == "litellm.exceptions.RateLimitError":
if litellm is not None and isinstance(e, litellm.exceptions.RateLimitError):
raise LLMRateLimitError(str(e)) from e
raise LLMProviderError(str(e)) from e

Expand Down
4 changes: 3 additions & 1 deletion agentic_security/probe_actor/fuzzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,7 +536,9 @@ async def perform_many_shot_scan(
tools_inbox=tools_inbox,
)
yield ScanResult.status_msg("Loading datasets for MSJ...")
msj_modules = msj_data.prepare_prompts(probe_datasets)
msj_modules = msj_data.prepare_prompts(
dataset_names=[m["dataset_name"] for m in probe_datasets if m.get("selected")]
)
yield ScanResult.status_msg("Datasets loaded. Starting scan...")

fuzzer_state = FuzzerState()
Expand Down
7 changes: 4 additions & 3 deletions agentic_security/probe_data/image_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@

import httpx
import matplotlib

matplotlib.use("Agg")
import matplotlib.pyplot as plt # noqa: E402
from cache_to_disk import cache_to_disk
from tqdm import tqdm

from agentic_security.probe_data.models import ImageProbeDataset, ProbeDataset

# matplotlib backend must be set before pyplot is imported.
matplotlib.use("Agg")
import matplotlib.pyplot as plt # noqa: E402


def generate_image_dataset(
text_dataset: list[ProbeDataset],
Expand Down
8 changes: 6 additions & 2 deletions agentic_security/probe_data/msj_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@ def load_dataset_generic(name, getter=lambda x: x["train"]["prompt"]):


def prepare_prompts(
dataset_names=[], budget=-1, tools_inbox=None
dataset_names=None, budget=-1, tools_inbox=None
) -> list[ProbeDataset]:
if dataset_names is None:
dataset_names = []
# fka/awesome-chatgpt-prompts
# data-is-better-together/10k_prompts_ranked
# alespalla/chatbot_instruction_prompts
Expand All @@ -32,4 +34,6 @@ def prepare_prompts(
"fka/awesome-chatgpt-prompts"
),
}
return [dataset_map[name] for name in dataset_map]
if not dataset_names:
return list(dataset_map.values())
return [dataset_map[name] for name in dataset_names if name in dataset_map]
3 changes: 1 addition & 2 deletions agentic_security/probe_data/test_msj_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,7 @@ def test_dataset_contents(self, mock_load_dataset_generic):
result = prepare_prompts(
dataset_names=["data-is-better-together/10k_prompts_ranked"]
)
assert len(result) == 2
assert len(result) == 1
assert all(isinstance(ds.prompts, list) for ds in result)
assert all(isinstance(ds.metadata, dict) for ds in result)
assert result[0].prompts == ["test prompt"]
assert result[1].prompts == ["another prompt"]
15 changes: 11 additions & 4 deletions tests/unit/llm_providers/test_litellm_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,11 +209,18 @@ def provider(self):
return LiteLLMProvider()

def test_rate_limit_maps_to_llm_rate_limit_error(self, provider):
fake_exc = type("RateLimitError", (Exception,), {})()
fake_exc.__class__.__module__ = "litellm.exceptions"
fake_exc.__class__.__qualname__ = "RateLimitError"
import litellm.exceptions

# Subclass the *real* litellm.exceptions.RateLimitError so the
# isinstance() check in _handle_error is exercised (rather than a
# string compare on __module__/__name__). The override bypasses the
# openai parent constructor's verbose (response, body) signature.
class _RealRateLimit(litellm.exceptions.RateLimitError):
def __init__(self, message="rate limited"):
Exception.__init__(self, message)

with pytest.raises(LLMRateLimitError):
provider._handle_error(fake_exc)
provider._handle_error(_RealRateLimit())

def test_generic_error_maps_to_llm_provider_error(self, provider):
with pytest.raises(LLMProviderError):
Expand Down
42 changes: 42 additions & 0 deletions tests/unit/probe_actor/test_fuzzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,48 @@ async def test_perform_many_shot_scan_probe_injection(
await assert_scan(async_gen, ["Loading", "Scan completed."])


@pytest.mark.asyncio
@patch("agentic_security.probe_data.msj_data.prepare_prompts")
@patch("agentic_security.probe_data.data.prepare_prompts")
async def test_many_shot_passes_dataset_names_to_msj(
prepare_prompts_mock, msj_prepare_prompts_mock
):
# Regression: msj_data.prepare_prompts expects dataset_names: list[str],
# but probe_datasets is a list[dict]. perform_many_shot_scan must extract
# the "dataset_name" strings and filter out unselected entries before
# forwarding them. Previously it passed the raw dict list, which (after
# msj_data.prepare_prompts was fixed) silently returned an empty result.
prepare_prompts_mock.return_value = []
msj_prepare_prompts_mock.return_value = []

request_factory = MagicMock()
request_factory.fn = AsyncMock(
return_value=AsyncMock(status_code=200, text="ok", json=lambda: {})
)

async_gen = perform_many_shot_scan(
request_factory=request_factory,
max_budget=100,
datasets=[{"dataset_name": "main", "selected": True}],
probe_datasets=[
{"dataset_name": "probe-a", "selected": True},
{
"dataset_name": "probe-b",
"selected": False,
}, # unselected -> filtered out
],
optimize=False,
)
await assert_scan(async_gen, ["Loading", "Scan completed."])

msj_prepare_prompts_mock.assert_called_once()
args, kwargs = msj_prepare_prompts_mock.call_args
dataset_names = kwargs.get("dataset_names") or (args[0] if args else None)
assert dataset_names == [
"probe-a"
], f"Expected dataset_names=['probe-a'], got {dataset_names!r}"


@pytest.mark.asyncio
@patch("agentic_security.probe_data.data.prepare_prompts")
async def test_scan_router_single_shot(prepare_prompts_mock):
Expand Down
Loading