Skip to content

Commit a599190

Browse files
authored
feat: add compliance framework tagging for AI red teaming (#318)
Add OWASP, ATLAS, SAIF, and NIST compliance tags to attacks and transforms. Attacks are tagged with the core jailbreak technique (LLM01); transforms are tagged with specific vulnerability categories. Includes comprehensive test coverage.
1 parent 8aaa668 commit a599190

27 files changed

+1433
-115
lines changed

dreadnode/airt/__init__.py

Lines changed: 18 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1-
from dreadnode.airt import attack, search
1+
from dreadnode.airt import attack, compliance, search
22
from dreadnode.airt.attack import (
33
Attack,
44
goat_attack,
@@ -9,20 +9,37 @@
99
tap_attack,
1010
zoo_attack,
1111
)
12+
from dreadnode.airt.compliance import (
13+
ATTACK_MAPPINGS,
14+
ATLASTechnique,
15+
NISTAIRMFFunction,
16+
OWASPCategory,
17+
SAIFCategory,
18+
tag_attack,
19+
tag_transform,
20+
)
1221
from dreadnode.airt.target import CustomTarget, LLMTarget, Target
1322

1423
__all__ = [
24+
"ATTACK_MAPPINGS",
25+
"ATLASTechnique",
1526
"Attack",
1627
"CustomTarget",
1728
"LLMTarget",
29+
"NISTAIRMFFunction",
30+
"OWASPCategory",
31+
"SAIFCategory",
1832
"Target",
1933
"attack",
34+
"compliance",
2035
"goat_attack",
2136
"hop_skip_jump_attack",
2237
"nes_attack",
2338
"prompt_attack",
2439
"search",
2540
"simba_attack",
41+
"tag_attack",
42+
"tag_transform",
2643
"tap_attack",
2744
"target",
2845
"zoo_attack",

dreadnode/airt/attack/base.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -26,6 +26,9 @@ class Attack(Study[In, Out]):
2626
tags: list[str] = Config(default_factory=lambda: ["attack"])
2727
"""A list of tags associated with the attack for logging."""
2828

29+
compliance_tags: dict[str, t.Any] = Config(default_factory=dict)
30+
"""Compliance framework tags (OWASP, ATLAS, SAIF, NIST) for this attack."""
31+
2932
hooks: list[EvalHook] = Field(default_factory=list, exclude=True, repr=False)
3033
"""Hooks to run at various points in the attack lifecycle."""
3134

dreadnode/airt/attack/crescendo.py

Lines changed: 22 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,13 @@
33
import yaml
44

55
from dreadnode.airt.attack import Attack
6+
from dreadnode.airt.compliance import (
7+
ATLASTechnique,
8+
NISTAIRMFFunction,
9+
OWASPCategory,
10+
SAIFCategory,
11+
tag_attack,
12+
)
613
from dreadnode.airt.target.base import Target
714
from dreadnode.constants import CRESCENDO_VARIANT_1
815
from dreadnode.data_types.message import Message as DnMessage
@@ -14,6 +21,20 @@
1421
from dreadnode.transforms.base import Transform
1522
from dreadnode.transforms.refine import adapt_prompt_trials, llm_refine
1623

24+
# Compliance framework tags for Crescendo attack
25+
# Core jailbreak technique tags - specific vulnerability categories (LLM02, LLM07, etc.)
26+
# are added when transforms targeting those categories are used
27+
COMPLIANCE_TAGS = tag_attack(
28+
atlas=[
29+
ATLASTechnique.PROMPT_INJECTION_DIRECT,
30+
ATLASTechnique.LLM_JAILBREAK,
31+
],
32+
owasp=OWASPCategory.LLM01_PROMPT_INJECTION,
33+
saif=SAIFCategory.INPUT_MANIPULATION,
34+
nist_function=NISTAIRMFFunction.MEASURE,
35+
nist_subcategory="MS-2.7",
36+
)
37+
1738

1839
def crescendo_attack(
1940
goal: str,
@@ -179,6 +200,7 @@ async def crescendo_refiner(trials: list[Trial[DnMessage]]) -> DnMessage:
179200
"objective": objective_judge,
180201
},
181202
hooks=hooks or [],
203+
compliance_tags=COMPLIANCE_TAGS,
182204
)
183205

184206
# Add stop condition based on early_stopping_score

dreadnode/airt/attack/goat.py

Lines changed: 23 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,13 @@
11
import typing as t
22

33
from dreadnode.airt.attack import Attack
4+
from dreadnode.airt.compliance import (
5+
ATLASTechnique,
6+
NISTAIRMFFunction,
7+
OWASPCategory,
8+
SAIFCategory,
9+
tag_attack,
10+
)
411
from dreadnode.data_types.message import Message as DnMessage
512
from dreadnode.meta.context import TrialCandidate
613
from dreadnode.optimization.search.graph import graph_neighborhood_search
@@ -18,6 +25,21 @@
1825
from dreadnode.optimization.trial import Trial
1926

2027

28+
# Compliance framework tags for GOAT attack
29+
# Core jailbreak technique tags - specific vulnerability categories (LLM02, LLM07, etc.)
30+
# are added when transforms targeting those categories are used
31+
COMPLIANCE_TAGS = tag_attack(
32+
atlas=[
33+
ATLASTechnique.PROMPT_INJECTION_DIRECT,
34+
ATLASTechnique.LLM_JAILBREAK,
35+
],
36+
owasp=OWASPCategory.LLM01_PROMPT_INJECTION,
37+
saif=SAIFCategory.INPUT_MANIPULATION,
38+
nist_function=NISTAIRMFFunction.MEASURE,
39+
nist_subcategory="MS-2.7",
40+
)
41+
42+
2143
def goat_attack(
2244
goal: str,
2345
target: "Target[DnMessage, DnMessage]",
@@ -121,6 +143,7 @@ async def message_refiner(trials: list["Trial[DnMessage]"]) -> DnMessage:
121143
},
122144
constraints=[topic_constraint],
123145
hooks=hooks or [],
146+
compliance_tags=COMPLIANCE_TAGS,
124147
)
125148

126149
if early_stopping_score is not None:

dreadnode/airt/attack/prompt.py

Lines changed: 15 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,7 @@
33
import rigging as rg
44

55
from dreadnode.airt.attack.base import Attack
6+
from dreadnode.airt.compliance import ATLASTechnique, OWASPCategory, SAIFCategory, tag_attack
67
from dreadnode.data_types.message import Message as DnMessage
78
from dreadnode.meta import TrialCandidate
89
from dreadnode.optimization.search.graph import beam_search
@@ -18,6 +19,19 @@
1819
from dreadnode.optimization.trial import Trial
1920

2021

22+
# Compliance framework tags for prompt attack
23+
# Core jailbreak technique tags - specific vulnerability categories (LLM02, LLM07, etc.)
24+
# are added when transforms targeting those categories are used
25+
COMPLIANCE_TAGS = tag_attack(
26+
atlas=[
27+
ATLASTechnique.PROMPT_INJECTION_DIRECT,
28+
ATLASTechnique.LLM_JAILBREAK,
29+
],
30+
owasp=OWASPCategory.LLM01_PROMPT_INJECTION,
31+
saif=SAIFCategory.INPUT_MANIPULATION,
32+
)
33+
34+
2135
def prompt_attack(
2236
goal: str,
2337
target: "Target[DnMessage, DnMessage]",
@@ -117,6 +131,7 @@ async def message_refiner(trials: list["Trial[DnMessage]"]) -> DnMessage:
117131
"prompt_judge": prompt_judge,
118132
},
119133
hooks=hooks or [],
134+
compliance_tags=COMPLIANCE_TAGS,
120135
)
121136

122137
if early_stopping_score is not None:

dreadnode/airt/attack/tap.py

Lines changed: 30 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,13 @@
22

33
from dreadnode.airt.attack import Attack
44
from dreadnode.airt.attack.prompt import prompt_attack
5+
from dreadnode.airt.compliance import (
6+
ATLASTechnique,
7+
NISTAIRMFFunction,
8+
OWASPCategory,
9+
SAIFCategory,
10+
tag_attack,
11+
)
512
from dreadnode.data_types.message import Message as DnMessage
613
from dreadnode.scorers.judge import llm_judge
714

@@ -10,6 +17,21 @@
1017
from dreadnode.eval.hooks.base import EvalHook
1118

1219

20+
# Compliance framework tags for TAP attack
21+
# Core jailbreak technique tags - specific vulnerability categories (LLM02, LLM07, etc.)
22+
# are added when transforms targeting those categories are used
23+
COMPLIANCE_TAGS = tag_attack(
24+
atlas=[
25+
ATLASTechnique.PROMPT_INJECTION_DIRECT,
26+
ATLASTechnique.LLM_JAILBREAK,
27+
],
28+
owasp=OWASPCategory.LLM01_PROMPT_INJECTION,
29+
saif=SAIFCategory.INPUT_MANIPULATION,
30+
nist_function=NISTAIRMFFunction.MEASURE,
31+
nist_subcategory="MS-2.7",
32+
)
33+
34+
1335
def tap_attack(
1436
goal: str,
1537
target: "Target[DnMessage, DnMessage]",
@@ -45,7 +67,7 @@ def tap_attack(
4567

4668
topic_constraint = llm_judge(evaluator_model, ON_TOPIC_RUBRIC.format(goal=goal))
4769

48-
return prompt_attack(
70+
base_attack = prompt_attack(
4971
goal,
5072
target,
5173
attacker_model,
@@ -58,7 +80,13 @@ def tap_attack(
5880
branching_factor=branching_factor,
5981
context_depth=context_depth,
6082
hooks=hooks or [],
61-
).with_(constraints={"on_topic": topic_constraint})
83+
)
84+
85+
# Set compliance tags before cloning
86+
base_attack.compliance_tags = COMPLIANCE_TAGS
87+
88+
# Add constraint and return
89+
return base_attack.with_(constraints={"on_topic": topic_constraint})
6290

6391

6492
REFINE_GUIDANCE = """\

0 commit comments

Comments
 (0)