diff --git a/agentic_security/probe_data/msj_data.py b/agentic_security/probe_data/msj_data.py index 0b75f617..ea15b1f2 100644 --- a/agentic_security/probe_data/msj_data.py +++ b/agentic_security/probe_data/msj_data.py @@ -38,8 +38,10 @@ def load_dataset_generic(name, getter=lambda x: x["train"]["prompt"]): def prepare_prompts( - dataset_names=[], budget=-1, tools_inbox=None + dataset_names=None, budget=-1, tools_inbox=None ) -> list[ProbeDataset]: + if dataset_names is None: + dataset_names = [] # fka/awesome-chatgpt-prompts # data-is-better-together/10k_prompts_ranked # alespalla/chatbot_instruction_prompts @@ -51,4 +53,6 @@ def prepare_prompts( "fka/awesome-chatgpt-prompts" ), } - return [dataset_map[name] for name in dataset_map] + if not dataset_names: + return list(dataset_map.values()) + return [dataset_map[name] for name in dataset_names if name in dataset_map] diff --git a/agentic_security/probe_data/test_msj_data.py b/agentic_security/probe_data/test_msj_data.py index e81812b5..7f29d457 100644 --- a/agentic_security/probe_data/test_msj_data.py +++ b/agentic_security/probe_data/test_msj_data.py @@ -129,8 +129,7 @@ def test_dataset_contents(self, mock_load_dataset_generic): result = prepare_prompts( dataset_names=["data-is-better-together/10k_prompts_ranked"] ) - assert len(result) == 2 + assert len(result) == 1 assert all(isinstance(ds.prompts, list) for ds in result) assert all(isinstance(ds.metadata, dict) for ds in result) assert result[0].prompts == ["test prompt"] - assert result[1].prompts == ["another prompt"]