Skip to content

Commit 2a8740e

Browse files
liangzidssiq
andauthored
[Dataset] Add ArxivRollBench (#2458)
* Add ArxivRollBench dataset configs * Fix ArxivRollBench config prompts --------- Co-authored-by: ssiq <liweiwuhome@hotmail.com>
1 parent 5589a37 commit 2a8740e

4 files changed

Lines changed: 171 additions & 0 deletions

File tree

dataset-index.yml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -359,6 +359,12 @@
359359
paper: https://arcprize.org/guide#private
360360
configpath: opencompass/configs/datasets/ARC_Prize_Public_Evaluation/arc_prize_public_evaluation_gen.py
361361
configpath_llmjudge: ''
362+
- arxivrollbench:
363+
name: ArxivRollBench
364+
category: Reasoning / Robustness
365+
paper: https://ojs.aaai.org/index.php/AAAI/article/view/41098
366+
configpath: opencompass/configs/datasets/arxivrollbench/arxivrollbench_gen.py
367+
configpath_llmjudge: ''
362368
- ax:
363369
name: SuperGLUE / AX
364370
category: Reasoning
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
from opencompass.datasets import HFDataset
2+
from opencompass.openicl.icl_evaluator import AccEvaluator
3+
from opencompass.openicl.icl_inferencer import GenInferencer
4+
from opencompass.openicl.icl_raw_prompt_template import RawPromptTemplate
5+
from opencompass.openicl.icl_retriever import ZeroRetriever
6+
7+
_DOMAINS = [
8+
('cs', 'cs'),
9+
('q_fin', 'q-fin'),
10+
('math', 'math'),
11+
('physics', 'physics'),
12+
('stat', 'stat'),
13+
('q_bio', 'q-bio'),
14+
('econ', 'econ'),
15+
('eess', 'eess'),
16+
]
17+
_RELEASES = ['2024b', '2025a', '2026a']
18+
_TASK_TYPES = ['s', 'c', 'p']
19+
20+
_S_PROMPT = """## Instruction:
21+
Given a **shuffled text** composed of sentences A, B, and C, your task is to select the correct order from four available selections. Avoid providing any additional information (such as explanations of your choice) or restating the sentences in your answer. Simply provide your selection: Selection 1, Selection 2, Selection 3, or Selection 4.
22+
## Shuffled text:
23+
{shuffled_text}
24+
## Choice:
25+
**Selection 1** {A}
26+
**Selection 2** {B}
27+
**Selection 3** {C}
28+
**Selection 4** {D}
29+
Answer:"""
30+
31+
_C_PROMPT = """## Instruction:
32+
Given a **masked paragraph** with three masked sentences marked as '<|MaskedSentence|>' and candidate sentences labeled A, B, and C, your task is to fill in the correct sentences to the masked positions by selecting the appropriate answers from four provided selections. Avoid providing any additional information (such as explanations of your choice) or restating the sentences in your answer. Simply provide your selection: Selection 1, Selection 2, Selection 3, or Selection 4.
33+
## Masked paragraph:
34+
{text_with_holes}
35+
## {text_candidates}
36+
## Choice:
37+
**Selection 1** {A}
38+
**Selection 2** {B}
39+
**Selection 3** {C}
40+
**Selection 4** {D}
41+
Answer:"""
42+
43+
_P_PROMPT = """## Instruction:
44+
Given a context, and four choices marked as A, B, C, and D, your task is to select the correct text which is the next sequence of the provided context. Avoid providing any additional information (such as explanations of your choice) or restating the choice in your answer. Simply provide one of the four letters: A, B, C, or D.
45+
## Context:
46+
{context}
47+
## Choice:
48+
**A** {A}
49+
**B** {B}
50+
**C** {C}
51+
**D** {D}
52+
Answer:"""
53+
54+
55+
def _dataset_path(release, hf_domain, task_type, compact=True):
56+
suffix = '-50' if compact else ''
57+
if release == '2024b':
58+
return f'liangzid/robench2024b_all_set{hf_domain}SCP-{task_type}{suffix}'
59+
return (f'liangzid/robench{release}_test_all_category_set'
60+
f'{hf_domain}SCP-{task_type}{suffix}')
61+
62+
63+
def _reader_cfg(task_type):
64+
if task_type == 's':
65+
return dict(
66+
input_columns=['shuffled_text', 'A', 'B', 'C', 'D'],
67+
output_column='label',
68+
train_split='train',
69+
test_split='train',
70+
)
71+
if task_type == 'c':
72+
return dict(
73+
input_columns=['text_with_holes', 'text_candidates', 'A', 'B',
74+
'C', 'D'],
75+
output_column='label',
76+
train_split='train',
77+
test_split='train',
78+
)
79+
return dict(
80+
input_columns=['context', 'A', 'B', 'C', 'D'],
81+
output_column='label',
82+
train_split='train',
83+
test_split='train',
84+
)
85+
86+
87+
def _infer_cfg(task_type):
88+
prompt = dict(s=_S_PROMPT, c=_C_PROMPT, p=_P_PROMPT)[task_type]
89+
return dict(
90+
prompt_template=dict(
91+
type=RawPromptTemplate,
92+
messages=[
93+
dict(role='user', content=prompt),
94+
],
95+
),
96+
retriever=dict(type=ZeroRetriever),
97+
inferencer=dict(type=GenInferencer),
98+
)
99+
100+
101+
def _eval_cfg(task_type):
102+
postprocessor = ('arxivrollbench_selection_postprocess' if task_type
103+
in ['s', 'c'] else 'arxivrollbench_choice_postprocess')
104+
return dict(
105+
evaluator=dict(type=AccEvaluator),
106+
pred_postprocessor=dict(type=postprocessor),
107+
)
108+
109+
110+
def _build_datasets(compact=True):
111+
datasets = []
112+
for release in _RELEASES:
113+
for domain, hf_domain in _DOMAINS:
114+
for task_type in _TASK_TYPES:
115+
suffix = '' if compact else '-full'
116+
datasets.append(
117+
dict(
118+
abbr=(f'arxivrollbench-{release}-{domain}-'
119+
f'{task_type}{suffix}'),
120+
type=HFDataset,
121+
path=_dataset_path(release, hf_domain, task_type,
122+
compact),
123+
reader_cfg=_reader_cfg(task_type),
124+
infer_cfg=_infer_cfg(task_type),
125+
eval_cfg=_eval_cfg(task_type),
126+
))
127+
return datasets
128+
129+
130+
# Default ArxivRollBench configuration: compact 50-sample splits.
131+
arxivrollbench_datasets = _build_datasets(compact=True)
132+
133+
# Full public splits are also provided for complete benchmark runs.
134+
arxivrollbench_full_datasets = _build_datasets(compact=False)

opencompass/datasets/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from .apps import * # noqa: F401, F403
1010
from .arc import * # noqa: F401, F403
1111
from .arc_prize_public_evaluation import * # noqa: F401, F403
12+
from .arxivrollbench import * # noqa: F401, F403
1213
from .ax import * # noqa: F401, F403
1314
from .babilong import * # noqa: F401, F403
1415
from .bbeh import * # noqa: F401, F403
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import re
2+
3+
from opencompass.registry import TEXT_POSTPROCESSORS
4+
5+
6+
@TEXT_POSTPROCESSORS.register_module()
7+
def arxivrollbench_selection_postprocess(text: str) -> str:
8+
"""Extract an ArxivRollBench S/C answer in ``Selection N`` format."""
9+
if text is None:
10+
return ''
11+
text = str(text).strip()
12+
match = re.search(r'\bselection\s*([1-4])\b', text, re.IGNORECASE)
13+
if match:
14+
return f'Selection {match.group(1)}'
15+
match = re.search(r'\b([1-4])\b', text)
16+
if match:
17+
return f'Selection {match.group(1)}'
18+
return text
19+
20+
21+
@TEXT_POSTPROCESSORS.register_module()
22+
def arxivrollbench_choice_postprocess(text: str) -> str:
23+
"""Extract an ArxivRollBench P-task answer in A/B/C/D format."""
24+
if text is None:
25+
return ''
26+
text = str(text).strip()
27+
match = re.search(r'\b([ABCD])\b', text, re.IGNORECASE)
28+
if match:
29+
return match.group(1).upper()
30+
return text

0 commit comments

Comments
 (0)