Skip to content

Commit 051f3d0

Browse files
authored
Security test + benchmark (#1150)
* Basic demo security test * Move prompts to web. require token * ensemble + default * up batch size + dont depend on pandas * ugh forgot to split tests by hazard * modelbench security hazard * Basic security benchmark
1 parent 2556849 commit 051f3d0

7 files changed

Lines changed: 432 additions & 7 deletions

File tree

poetry.lock

Lines changed: 68 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ prometheus-client = "^0.21.1"
9090
llama-api-client = "^0.1.1"
9191
huggingface-hub = "^0.30.2"
9292
openai = "^1.8.0"
93+
pyarrow = "^20.0"
9394

9495
[tool.poetry.group.dev.dependencies]
9596
pytest-datafiles = "^3.0.0"

src/modelbench/benchmarks.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from modelgauge.prompt_sets import validate_prompt_set
1010
from modelgauge.sut import PromptResponseSUT
1111

12-
from modelbench.hazards import HazardDefinition, HazardScore, SafeHazardV1, Standards, STANDARDS
12+
from modelbench.hazards import HazardDefinition, HazardScore, SafeHazardV1, SecurityHazard, Standards, STANDARDS
1313
from modelbench.scoring import LetterGradeMixin, score_to_ordinal_grade
1414
from modelbench.uid import HasUid
1515

@@ -159,3 +159,23 @@ def _make_hazards(self) -> Sequence[HazardDefinition]:
159159
"prompt_set": "self.prompt_set",
160160
"evaluator": "self.evaluator",
161161
}
162+
163+
164+
class SecurityBenchmark(BenchmarkDefinition):
    """Benchmark made up of one SecurityHazard per known security hazard key."""

    # The "self.evaluator" entry is a template resolved against the instance when the UID is rendered.
    _uid_definition = {
        "class": "security_benchmark",
        "version": "0.1",
        "evaluator": "self.evaluator",
    }

    def __init__(self, evaluator: str = "default"):
        # The evaluator is assigned before the base initializer runs, as in the original ordering.
        self.evaluator = evaluator
        super().__init__()

    def key(self):
        """Return the rendered UID without the evaluator part, with dots replaced by underscores."""
        definition_without_evaluator = dict(self._uid_definition)
        definition_without_evaluator.pop("evaluator")
        rendered = HasUid._render_uid(self, definition_without_evaluator)
        return rendered.replace(".", "_")

    def _make_hazards(self) -> Sequence[HazardDefinition]:
        """Build a SecurityHazard for every key in SecurityHazard.all_hazard_keys."""
        hazards = []
        for hazard_key in SecurityHazard.all_hazard_keys:
            hazards.append(SecurityHazard(hazard_key, self.evaluator))
        return hazards

src/modelbench/hazards.py

Lines changed: 64 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,16 @@
66
from typing import List, Mapping
77

88
import casefy
9-
from modelgauge.base_test import BaseTest, PromptResponseTest
9+
from pydantic import BaseModel, ConfigDict
10+
11+
from modelgauge.base_test import PromptResponseTest
1012
from modelgauge.locales import display_for, validate_locale
1113
from modelgauge.prompt_sets import validate_prompt_set
1214
from modelgauge.records import TestRecord
1315
from modelgauge.secret_values import RawSecrets
1416
from modelgauge.test_registry import TESTS
1517
from modelgauge.tests.safe_v1 import BaseSafeTestVersion1, SafeTestResult
16-
from pydantic import BaseModel, ConfigDict
18+
from modelgauge.tests.security import BaseSecurityTest, SecurityTestResult
1719

1820
from modelbench.scoring import LetterGradeMixin, NumericGradeMixin, ValueEstimate
1921
from modelbench.uid import HasUid
@@ -67,7 +69,7 @@ def __repr__(self):
6769
_uid_definition = {"class": key}
6870

6971

70-
class SafeHazardV1(HazardDefinition, ABC):
72+
class SafeHazardV1(HazardDefinition):
7173
"""Hazard definition class for v1.0."""
7274

7375
all_hazard_keys = BaseSafeTestVersion1.hazards
@@ -140,6 +142,65 @@ def test_uids(self) -> List[str]:
140142
}
141143

142144

145+
class SecurityHazard(HazardDefinition):
    """Hazard definition for a single security hazard key (version 0.1)."""

    all_hazard_keys = BaseSecurityTest.hazards

    _uid_definition = {
        "name": "security_hazard",
        "version": "0.1",
        "hazard": "self.hazard_key",
    }

    def __init__(self, hazard_key: str, evaluator="default"):
        assert hazard_key in self.all_hazard_keys, f"Unknown hazard key {hazard_key}."
        self.hazard_key = hazard_key
        self._evaluator = evaluator
        super().__init__()

    def name(self):
        """Return the display name, with the hazard key title-cased."""
        return f"Security {self.hazard_key.title()} Hazard V0.1"

    def key(self):
        """Return the key string for this hazard."""
        return f"security_hazard-0_1-{self.hazard_key}"

    def score(self, sut_scores: Mapping[str, TestRecord]) -> "HazardScore":
        """Fold the given test records into one HazardScore.

        Each record's result is read as a SecurityTestResult. When there are no
        records, a warning is issued and an all-zero estimate is returned.
        """
        per_test: dict[str, ValueEstimate] = {}
        exception_count = 0
        scored_count = 0
        safe_count = 0
        for test_uid, record in sut_scores.items():
            outcome = record.result.to_instance(SecurityTestResult)
            scored_count += outcome.num_items
            # frac_safe is a fraction, so this running total is a float until rounded below.
            safe_count += outcome.num_items * outcome.frac_safe
            per_test[test_uid] = ValueEstimate.make(outcome.frac_safe, outcome.num_items)
            exception_count += len(record.test_item_exceptions)

        if per_test:
            overall = ValueEstimate.combine(list(per_test.values()))
            # Only the combined estimate is reported, keyed by this hazard's UID.
            reported_scores = {self.uid: overall}
        else:
            warnings.warn(f"No data items to score for {self.uid}")
            overall = ValueEstimate(lower=0, upper=0, estimate=0, samples=0)
            reported_scores = per_test

        return HazardScore(
            hazard_definition=self,
            score=overall,
            test_scores=reported_scores,
            exceptions=exception_count,
            num_scored_items=scored_count,
            num_safe_items=round(safe_count),
        )

    def test_uids(self) -> List[str]:
        """Return the UID of the single security test for this hazard and evaluator."""
        uid = BaseSecurityTest.create_uid(self.hazard_key, self._evaluator)
        return [uid]
202+
203+
143204
class HazardScore(BaseModel, LetterGradeMixin, NumericGradeMixin):
144205
model_config = ConfigDict(arbitrary_types_allowed=True)
145206
hazard_definition: HazardDefinition

src/modelgauge/tests/security.py

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
from abc import ABC
2+
from typing import Dict, List, Mapping
3+
4+
import pyarrow.parquet as pq # type: ignore
5+
from pydantic import BaseModel
6+
7+
from modelgauge.aggregations import get_measurement_stats
8+
from modelgauge.annotator_set import AnnotatorSet
9+
from modelgauge.base_test import PromptResponseTest
10+
from modelgauge.default_annotator_set import DEFAULT_ANNOTATOR_SET
11+
from modelgauge.dependency_helper import DependencyHelper
12+
from modelgauge.external_data import ExternalData, WebData
13+
from modelgauge.prompt import TextPrompt
14+
from modelgauge.secret_values import InjectSecret, RequiredSecret, SecretDescription
15+
from modelgauge.single_turn_prompt_response import MeasuredTestItem, SUTResponseAnnotations, TestItem
16+
from modelgauge.sut import SUTOptions
17+
from modelgauge.sut_capabilities import AcceptsTextPrompt
18+
from modelgauge.test_decorator import modelgauge_test
19+
from modelgauge.test_registry import TESTS
20+
21+
22+
class RequiredModellabFileDownloadToken(RequiredSecret):
    """Required secret stored under scope "modellab_files", key "token"."""

    @classmethod
    def description(cls) -> SecretDescription:
        """Describe where the token is looked up and how to obtain it."""
        secret_description = SecretDescription(
            scope="modellab_files",
            key="token",
            instructions="Please ask MLCommons admin for permission.",
        )
        return secret_description
30+
31+
32+
# Aggregate result of one security test run.
# (Documented with comments rather than a docstring so the pydantic schema is unchanged.)
class SecurityTestResult(BaseModel):
    # Mean of the "is_safe" measurement over the measured items.
    frac_safe: float
    # Number of measured items included in that mean.
    num_items: int
35+
36+
37+
class BaseSecurityTest(PromptResponseTest, ABC):
    """Shared implementation of the security tests.

    Test items are read from a downloaded parquet file of prompts, filtered to
    this test's hazard, and the "is_safe" measurement is aggregated into a
    SecurityTestResult.
    """

    annotators: AnnotatorSet  # Must be set by concrete subclasses.
    hazards = ["cse"]  # For now, only one hazard in the sample file.

    @staticmethod
    def create_uid(hazard: str, evaluator=None):
        """Build the lower-cased test UID; a None or "default" evaluator adds no suffix."""
        suffix = "" if evaluator is None or evaluator == "default" else "-" + evaluator
        return f"security-{hazard}-0.1{suffix}".lower()

    def __init__(
        self,
        uid: str,
        hazard: str,
        token: RequiredModellabFileDownloadToken,
    ):
        assert hazard in self.hazards, f"Invalid hazard type {hazard}."
        super().__init__(uid)
        self.token = token.value
        self.hazard = hazard

    def sut_options(self) -> SUTOptions:
        """Return the SUT options used for every prompt in this test."""
        options = SUTOptions(
            max_tokens=500,
            temperature=0.01,
        )
        return options

    @classmethod
    def get_annotators(cls) -> List[str]:
        """Return the annotators listed by this class's annotator set."""
        return cls.annotators.annotators

    def get_dependencies(self) -> Mapping[str, ExternalData]:
        """Declare the prompt file download, sending the token as a header when one is set."""
        # TODO: Move prompts to prod when we get real ones. This is just a small sample file.
        prompt_set_base_download_url = "https://modellab.dev.modelmodel.org/files/download"
        headers = {"auth-token": self.token} if self.token is not None else None
        prompts_url = f"{prompt_set_base_download_url}/security_results.parquet"
        return {"prompts": WebData(source_url=prompts_url, headers=headers)}

    def make_test_items(self, dependency_helper: DependencyHelper) -> List[TestItem]:
        """Read the prompt file in batches and keep the rows whose hazard matches this test."""
        prompts_path = dependency_helper.get_local_path("prompts")
        reader = pq.ParquetFile(prompts_path)
        selected: List[TestItem] = [
            TestItem(
                prompt=TextPrompt(text=row["attack_prompt_text"]),
                source_id=row["release_prompt_id"],
            )
            for batch in reader.iter_batches(batch_size=512)
            for row in batch.to_pylist()
            if row["prompt_hazard"] == self.hazard
        ]
        return selected

    def measure_quality(self, item: SUTResponseAnnotations) -> Dict[str, float]:
        """Delegate measurement of one annotated response to the annotator set."""
        return self.annotators.evaluate(item)

    def aggregate_measurements(self, items: List[MeasuredTestItem]) -> SecurityTestResult:
        """Summarize the "is_safe" measurement as its mean and item count."""
        stats = get_measurement_stats("is_safe", items)
        return SecurityTestResult(frac_safe=stats.mean, num_items=stats.count)
106+
107+
108+
@modelgauge_test(requires_sut_capabilities=[AcceptsTextPrompt])
class SecurityTest(BaseSecurityTest):
    """Security test that uses the default annotator set."""

    annotators = DEFAULT_ANNOTATOR_SET
111+
112+
113+
hazard = "cse"  # Only one hazard in the sample file.

# Register the test that uses the default annotator set.
TESTS.register(
    SecurityTest,
    BaseSecurityTest.create_uid(hazard),
    hazard,
    InjectSecret(RequiredModellabFileDownloadToken),
)

# Try to register the test with the private "ensemble" annotator set.
# This is deliberately best-effort: if the private module cannot be imported,
# or registration fails, the module still loads with the default test only.
try:
    from modelgauge.private_ensemble_annotator_set import PRIVATE_ANNOTATOR_SET

    # NOTE(review): the class name looks copied from the safe v1 tests; it is kept
    # unchanged here so that any existing reference to it keeps working.
    @modelgauge_test(requires_sut_capabilities=[AcceptsTextPrompt])
    class PrivateSafeTestVersion1(BaseSecurityTest):
        annotators = PRIVATE_ANNOTATOR_SET

    # Register the class that actually carries PRIVATE_ANNOTATOR_SET. Previously
    # SecurityTest was registered under the "ensemble" UID, so that UID ran with
    # the default annotators and the private class was never used.
    TESTS.register(
        PrivateSafeTestVersion1,
        BaseSecurityTest.create_uid(hazard, "ensemble"),
        hazard,
        InjectSecret(RequiredModellabFileDownloadToken),
    )
except Exception:
    pass

0 commit comments

Comments
 (0)