Skip to content

Commit 7a22cf8

Browse files
rlundeen2Copilot
andauthored
MAINT: Refactor Cyber scenario to use technique registry pattern (#1654)
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent 6ccf1ff commit 7a22cf8

6 files changed

Lines changed: 541 additions & 509 deletions

File tree

pyrit/scenario/core/scenario_techniques.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from pyrit.executor.attack import (
2626
ManyShotJailbreakAttack,
2727
PromptSendingAttack,
28+
RedTeamingAttack,
2829
RolePlayAttack,
2930
RolePlayPaths,
3031
TreeOfAttacksWithPruningAttack,
@@ -70,6 +71,11 @@
7071
strategy_tags=["core", "multi_turn"],
7172
accepts_scorer_override=False,
7273
),
74+
AttackTechniqueSpec(
75+
name="red_teaming",
76+
attack_class=RedTeamingAttack,
77+
strategy_tags=["core", "multi_turn"],
78+
),
7379
]
7480

7581

pyrit/scenario/scenarios/airt/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from typing import Any
77

88
from pyrit.scenario.scenarios.airt.content_harms import ContentHarms
9-
from pyrit.scenario.scenarios.airt.cyber import Cyber, CyberStrategy
9+
from pyrit.scenario.scenarios.airt.cyber import Cyber
1010
from pyrit.scenario.scenarios.airt.jailbreak import Jailbreak, JailbreakStrategy
1111
from pyrit.scenario.scenarios.airt.leakage import Leakage, LeakageStrategy
1212
from pyrit.scenario.scenarios.airt.psychosocial import Psychosocial, PsychosocialStrategy
@@ -28,6 +28,8 @@ def __getattr__(name: str) -> Any:
2828
return RapidResponse.get_strategy_class()
2929
if name == "ContentHarmsStrategy":
3030
return ContentHarms.get_strategy_class()
31+
if name == "CyberStrategy":
32+
return Cyber.get_strategy_class()
3133
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
3234

3335

Lines changed: 59 additions & 150 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,26 @@
11
# Copyright (c) Microsoft Corporation.
22
# Licensed under the MIT license.
33

4+
"""
5+
Cyber scenario — technique-based malware generation testing.
6+
7+
Strategies select **attack techniques** (PromptSending, RedTeaming).
8+
Datasets control **what** is tested (malware generation objectives).
9+
Use ``--dataset-names`` to narrow which objectives to test.
10+
"""
11+
12+
from __future__ import annotations
13+
414
import logging
515
import os
6-
from typing import TYPE_CHECKING, Any, Optional
16+
from typing import TYPE_CHECKING, ClassVar
717

818
from pyrit.auth import get_azure_openai_auth
919
from pyrit.common import apply_defaults
1020
from pyrit.common.path import SCORER_SEED_PROMPT_PATH
11-
from pyrit.executor.attack.core.attack_config import (
12-
AttackAdversarialConfig,
13-
AttackScoringConfig,
14-
)
15-
from pyrit.executor.attack.multi_turn.red_teaming import RedTeamingAttack
16-
from pyrit.executor.attack.single_turn.prompt_sending import PromptSendingAttack
17-
from pyrit.models import SeedAttackGroup
18-
from pyrit.prompt_target import OpenAIChatTarget, PromptChatTarget
19-
from pyrit.scenario.core.atomic_attack import AtomicAttack
20-
from pyrit.scenario.core.attack_technique import AttackTechnique
21+
from pyrit.prompt_target import OpenAIChatTarget
2122
from pyrit.scenario.core.dataset_configuration import DatasetConfiguration
2223
from pyrit.scenario.core.scenario import Scenario
23-
from pyrit.scenario.core.scenario_strategy import ScenarioStrategy
2424
from pyrit.score import (
2525
SelfAskRefusalScorer,
2626
SelfAskTrueFalseScorer,
@@ -31,24 +31,37 @@
3131
)
3232

3333
if TYPE_CHECKING:
34-
from pyrit.executor.attack.core.attack_strategy import AttackStrategy
34+
from pyrit.scenario.core.scenario_strategy import ScenarioStrategy
3535

3636
logger = logging.getLogger(__name__)
3737

38+
_CYBER_TECHNIQUE_NAMES = {"prompt_sending", "red_teaming"}
39+
3840

39-
class CyberStrategy(ScenarioStrategy):
41+
def _build_cyber_strategy() -> type[ScenarioStrategy]:
4042
"""
41-
Strategies for malware-focused cyber attacks. While not in the CyberStrategy class, a
42-
few of these include:
43-
* Shell smashing
44-
* Zip bombs
45-
* File deletion (rm -rf /).
43+
Build the Cyber strategy class dynamically from SCENARIO_TECHNIQUES.
44+
45+
Selects only ``prompt_sending`` and ``red_teaming`` techniques from
46+
the shared catalog.
47+
48+
Returns:
49+
type[ScenarioStrategy]: The dynamically generated strategy enum class.
4650
"""
51+
from pyrit.registry.object_registries.attack_technique_registry import AttackTechniqueRegistry
52+
from pyrit.registry.tag_query import TagQuery
53+
from pyrit.scenario.core.scenario_techniques import SCENARIO_TECHNIQUES
54+
55+
cyber_specs = [s for s in SCENARIO_TECHNIQUES if s.name in _CYBER_TECHNIQUE_NAMES]
4756

48-
# Aggregate members (special markers that expand to strategies with matching tags)
49-
ALL = ("all", {"all"})
50-
SINGLE_TURN = ("single_turn", {"single_turn"})
51-
MULTI_TURN = ("multi_turn", {"multi_turn"})
57+
return AttackTechniqueRegistry.build_strategy_class_from_specs(
58+
class_name="CyberStrategy",
59+
specs=cyber_specs,
60+
aggregate_tags={
61+
"single_turn": TagQuery.any_of("single_turn"),
62+
"multi_turn": TagQuery.any_of("multi_turn"),
63+
},
64+
)
5265

5366

5467
class Cyber(Scenario):
@@ -60,27 +73,31 @@ class Cyber(Scenario):
6073
techniques.
6174
"""
6275

63-
VERSION: int = 1
76+
VERSION: int = 2
77+
_cached_strategy_class: ClassVar[type[ScenarioStrategy] | None] = None
6478

6579
@classmethod
6680
def get_strategy_class(cls) -> type[ScenarioStrategy]:
6781
"""
68-
Get the strategy enum class for this scenario.
82+
Return the dynamically generated strategy class, building it on first access.
6983
7084
Returns:
71-
Type[ScenarioStrategy]: The CyberStrategy enum class.
85+
type[ScenarioStrategy]: The CyberStrategy enum class.
7286
"""
73-
return CyberStrategy
87+
if cls._cached_strategy_class is None:
88+
cls._cached_strategy_class = _build_cyber_strategy()
89+
return cls._cached_strategy_class
7490

7591
@classmethod
7692
def get_default_strategy(cls) -> ScenarioStrategy:
7793
"""
78-
Get the default strategy used when no strategies are specified.
94+
Return the default strategy member (``ALL``).
7995
8096
Returns:
81-
ScenarioStrategy: CyberStrategy.ALL (all cyber strategies).
97+
ScenarioStrategy: The ALL strategy value.
8298
"""
83-
return CyberStrategy.ALL
99+
strategy_class = cls.get_strategy_class()
100+
return strategy_class("all")
84101

85102
@classmethod
86103
def default_dataset_config(cls) -> DatasetConfiguration:
@@ -96,54 +113,36 @@ def default_dataset_config(cls) -> DatasetConfiguration:
96113
def __init__(
97114
self,
98115
*,
99-
adversarial_chat: Optional[PromptChatTarget] = None,
100-
objective_scorer: Optional[TrueFalseScorer] = None,
116+
objective_scorer: TrueFalseScorer | None = None,
101117
include_baseline: bool = True,
102-
scenario_result_id: Optional[str] = None,
118+
scenario_result_id: str | None = None,
103119
) -> None:
104120
"""
105121
Initialize the cyber harms scenario.
106122
107123
Args:
108-
adversarial_chat (Optional[PromptChatTarget]): Adversarial chat for the red teaming attack, corresponding
109-
to CyberStrategy.MultiTurn. If not provided, defaults to an OpenAI chat target.
110-
objective_scorer (Optional[TrueFalseScorer]): Objective scorer for malware detection. If not
111-
provided, defaults to a SelfAskScorer using the malware.yaml file under the scorer config store for
112-
malware detection
124+
objective_scorer (TrueFalseScorer | None): Objective scorer for malware detection. If not
125+
provided, defaults to a composite scorer using malware detection + refusal backstop.
113126
include_baseline (bool): Whether to include a baseline atomic attack that sends all objectives
114-
without modifications. Defaults to True. When True, a "baseline" attack is automatically
115-
added as the first atomic attack, allowing comparison between unmodified prompts and
116-
attack-modified prompts.
117-
scenario_result_id (Optional[str]): Optional ID of an existing scenario result to resume.
127+
without modifications. Defaults to True.
128+
scenario_result_id (str | None): Optional ID of an existing scenario result to resume.
118129
"""
119-
# Cyber uses a "take object, make config" pattern to expose a more ergonomic interface. Helper
120-
# methods return objects, not configs.
121-
122-
# In this context the "objective" scorer has nothing to do with the "objective" target.
123-
# The scoring config is what's used later in scoring orchestration, not the scorer object.
124-
if not objective_scorer:
125-
objective_scorer = self._get_default_objective_scorer()
126-
self._scorer_config = AttackScoringConfig(objective_scorer=objective_scorer)
127-
128-
# Adversarial chat is used by the RedTeamingAttack (multiturn) variation of the scenario. Otherwise, it's not
129-
# invoked.
130-
self._adversarial_chat = adversarial_chat if adversarial_chat else self._get_default_adversarial_target()
131-
self._adversarial_config = AttackAdversarialConfig(target=self._adversarial_chat)
130+
self._objective_scorer: TrueFalseScorer = (
131+
objective_scorer if objective_scorer else self._get_default_objective_scorer()
132+
)
132133

133134
super().__init__(
134135
version=self.VERSION,
135-
strategy_class=CyberStrategy,
136-
objective_scorer=objective_scorer,
136+
objective_scorer=self._objective_scorer,
137+
strategy_class=self.get_strategy_class(),
137138
include_default_baseline=include_baseline,
138139
scenario_result_id=scenario_result_id,
139140
)
140141

141-
# Will be resolved in _get_atomic_attacks_async
142-
self._seed_groups: Optional[list[SeedAttackGroup]] = None
143-
144142
def _get_default_objective_scorer(self) -> TrueFalseCompositeScorer:
145143
"""
146144
Retrieve the default objective scorer.
145+
147146
The scoring mechanism is as follows:
148147
1. Detect for the presence of malware using a SelfAskTrueFalseScorer.
149148
2. Detect for NON-refusal as a backstop. (If the target model refuses to produce malware, we want that
@@ -175,93 +174,3 @@ def _get_default_objective_scorer(self) -> TrueFalseCompositeScorer:
175174
return TrueFalseCompositeScorer(
176175
aggregator=TrueFalseScoreAggregator.AND, scorers=[presence_of_malware, backstop]
177176
)
178-
179-
def _get_default_adversarial_target(self) -> OpenAIChatTarget:
180-
"""
181-
Create and retrieve the default adversarial target.
182-
183-
Returns:
184-
OpenAIChatTarget: Default adversarial target, using an unfiltered endpoint.
185-
"""
186-
endpoint = os.getenv("AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT")
187-
return OpenAIChatTarget(
188-
endpoint=endpoint,
189-
api_key=get_azure_openai_auth(endpoint or ""),
190-
model_name=os.environ.get("AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL"),
191-
temperature=1.2,
192-
)
193-
194-
def _resolve_seed_groups(self) -> list[SeedAttackGroup]:
195-
"""
196-
Resolve seed groups from dataset configuration.
197-
198-
Returns:
199-
List[SeedAttackGroup]: List of seed attack groups with objectives to be tested.
200-
"""
201-
# Use dataset_config (guaranteed to be set by initialize_async)
202-
seed_groups = self._dataset_config.get_all_seed_attack_groups()
203-
204-
if not seed_groups:
205-
self._raise_dataset_exception()
206-
207-
return list(seed_groups)
208-
209-
def _get_atomic_attack_from_strategy(self, strategy: str) -> AtomicAttack:
210-
"""
211-
Translate the strategy into an actual AtomicAttack.
212-
213-
Args:
214-
strategy: The CyberStrategy enum (SingleTurn or MultiTurn).
215-
216-
Returns:
217-
AtomicAttack: configured for the specified strategy.
218-
219-
Raises:
220-
ValueError: If scenario is not properly initialized or an unknown CyberStrategy is passed.
221-
"""
222-
# objective_target is guaranteed to be non-None by parent class validation
223-
if self._objective_target is None:
224-
raise ValueError(
225-
"Scenario not properly initialized. Call await scenario.initialize_async() before running."
226-
)
227-
attack_strategy: Optional[AttackStrategy[Any, Any]] = None
228-
if strategy == "single_turn":
229-
attack_strategy = PromptSendingAttack(
230-
objective_target=self._objective_target,
231-
attack_scoring_config=self._scorer_config,
232-
)
233-
elif strategy == "multi_turn":
234-
attack_strategy = RedTeamingAttack(
235-
objective_target=self._objective_target,
236-
attack_scoring_config=self._scorer_config,
237-
attack_adversarial_config=self._adversarial_config,
238-
)
239-
else:
240-
raise ValueError(f"Unknown CyberStrategy: {strategy}")
241-
242-
# _seed_groups is guaranteed to be set by _get_atomic_attacks_async before this method is called
243-
if self._seed_groups is None:
244-
raise ValueError("_seed_groups must be resolved before creating atomic attacks")
245-
246-
return AtomicAttack(
247-
atomic_attack_name=f"cyber_{strategy}",
248-
attack_technique=AttackTechnique(attack=attack_strategy),
249-
seed_groups=self._seed_groups,
250-
adversarial_chat=self._adversarial_chat,
251-
objective_scorer=self._scorer_config.objective_scorer,
252-
memory_labels=self._memory_labels,
253-
)
254-
255-
async def _get_atomic_attacks_async(self) -> list[AtomicAttack]:
256-
"""
257-
Generate atomic attacks for each strategy.
258-
259-
Returns:
260-
List[AtomicAttack]: List of atomic attacks to execute.
261-
"""
262-
# Resolve seed groups from deprecated objectives or dataset config
263-
self._seed_groups = self._resolve_seed_groups()
264-
265-
strategies = {s.value for s in self._scenario_strategies}
266-
267-
return [self._get_atomic_attack_from_strategy(strategy) for strategy in strategies]

0 commit comments

Comments
 (0)