Skip to content

Commit f6ceed3

Browse files
committed
GraniteGuardian-based metric for attaq
Signed-off-by: Jonathan Bnayahu <bnayahu@il.ibm.com>
1 parent b2751ba commit f6ceed3

4 files changed

Lines changed: 105 additions & 8 deletions

File tree

prepare/benchmarks/safety.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,27 @@
44

55
benchmark = Benchmark(
66
subsets={
7-
"attaq": DatasetRecipe(card="cards.attaq"),
8-
"provoq": DatasetRecipe(card="cards.safety.provoq"),
9-
"airbench": DatasetRecipe(card="cards.safety.airbench2024"),
10-
"ailuminate": DatasetRecipe(card="cards.safety.mlcommons_ailuminate"),
7+
"attaq": DatasetRecipe(
8+
card="cards.safety.attaq_gg",
9+
template_card_index="default",
10+
max_test_instances=500,
11+
),
12+
"provoq": DatasetRecipe(
13+
card="cards.safety.provoq",
14+
template_card_index="default",
15+
group_by=["group"],
16+
max_test_instances=500,
17+
),
18+
"airbench": DatasetRecipe(
19+
card="cards.safety.airbench2024",
20+
template_card_index="default",
21+
max_test_instances=500,
22+
),
23+
"ailuminate": DatasetRecipe(
24+
card="cards.safety.mlcommons_ailuminate",
25+
template_card_index="default",
26+
max_test_instances=500,
27+
),
1128
}
1229
)
1330

prepare/cards/safety/attaq_gg.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
from unitxt import add_to_catalog
2+
from unitxt.blocks import InputOutputTemplate, LoadHF, Task, TaskCard, TemplatesDict
3+
from unitxt.operators import Shuffle
4+
from unitxt.splitters import RenameSplits
5+
from unitxt.test_utils.card import test_card
6+
7+
card = TaskCard(
8+
loader=LoadHF(path="ibm/AttaQ"),
9+
preprocess_steps=[
10+
RenameSplits(mapper={"train": "test"}),
11+
Shuffle(page_size=2800),
12+
],
13+
task=Task(
14+
input_fields={"input": str},
15+
reference_fields={"label": str},
16+
prediction_type=str,
17+
metrics=[
18+
"metrics.granite_guardian.assistant_risk.harm[prediction_type=str,user_message_field=input,assistant_message_field=prediction]",
19+
],
20+
),
21+
templates=TemplatesDict(
22+
{"default": InputOutputTemplate(input_format="{input}", output_format="")}
23+
),
24+
)
25+
26+
test_card(card, strict=False, demos_taken_from="test", num_demos=0)
27+
add_to_catalog(card, "cards.safety.attaq_gg", overwrite=True)

src/unitxt/catalog/benchmarks/safety.json

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,30 @@
33
"subsets": {
44
"attaq": {
55
"__type__": "dataset_recipe",
6-
"card": "cards.attaq"
6+
"card": "cards.safety.attaq_gg",
7+
"template_card_index": "default",
8+
"max_test_instances": 500
79
},
810
"provoq": {
911
"__type__": "dataset_recipe",
10-
"card": "cards.safety.provoq"
12+
"card": "cards.safety.provoq",
13+
"template_card_index": "default",
14+
"group_by": [
15+
"group"
16+
],
17+
"max_test_instances": 500
1118
},
1219
"airbench": {
1320
"__type__": "dataset_recipe",
14-
"card": "cards.safety.airbench2024"
21+
"card": "cards.safety.airbench2024",
22+
"template_card_index": "default",
23+
"max_test_instances": 500
1524
},
1625
"ailuminate": {
1726
"__type__": "dataset_recipe",
18-
"card": "cards.safety.mlcommons_ailuminate"
27+
"card": "cards.safety.mlcommons_ailuminate",
28+
"template_card_index": "default",
29+
"max_test_instances": 500
1930
}
2031
}
2132
}
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
{
2+
"__type__": "task_card",
3+
"loader": {
4+
"__type__": "load_hf",
5+
"path": "ibm/AttaQ"
6+
},
7+
"preprocess_steps": [
8+
{
9+
"__type__": "rename_splits",
10+
"mapper": {
11+
"train": "test"
12+
}
13+
},
14+
{
15+
"__type__": "shuffle",
16+
"page_size": 2800
17+
}
18+
],
19+
"task": {
20+
"__type__": "task",
21+
"input_fields": {
22+
"input": "str"
23+
},
24+
"reference_fields": {
25+
"label": "str"
26+
},
27+
"prediction_type": "str",
28+
"metrics": [
29+
"metrics.granite_guardian.assistant_risk.harm[prediction_type=str,user_message_field=input,assistant_message_field=prediction]"
30+
]
31+
},
32+
"templates": {
33+
"__type__": "templates_dict",
34+
"items": {
35+
"default": {
36+
"__type__": "input_output_template",
37+
"input_format": "{input}",
38+
"output_format": ""
39+
}
40+
}
41+
}
42+
}

0 commit comments

Comments
 (0)