|
| 1 | +from opencompass.datasets import HFDataset |
| 2 | +from opencompass.openicl.icl_evaluator import AccEvaluator |
| 3 | +from opencompass.openicl.icl_inferencer import GenInferencer |
| 4 | +from opencompass.openicl.icl_raw_prompt_template import RawPromptTemplate |
| 5 | +from opencompass.openicl.icl_retriever import ZeroRetriever |
| 6 | + |
# (abbr_name, hf_name) pairs, one per domain. `_build_datasets` puts the first
# element into the dataset abbreviation and passes the second to
# `_dataset_path`. The two spellings differ only for 'q_fin'/'q-fin' and
# 'q_bio'/'q-bio'.
_DOMAINS = [
    ('cs', 'cs'),
    ('q_fin', 'q-fin'),
    ('math', 'math'),
    ('physics', 'physics'),
    ('stat', 'stat'),
    ('q_bio', 'q-bio'),
    ('econ', 'econ'),
    ('eess', 'eess'),
]
# Release tags iterated by `_build_datasets`. `_dataset_path` uses a separate
# path scheme for '2024b' and a shared one for every other tag.
_RELEASES = ['2024b', '2025a', '2026a']
# Task-type codes. Each code selects a prompt (`_infer_cfg`), a set of reader
# columns (`_reader_cfg`) and a prediction postprocessor (`_eval_cfg`).
_TASK_TYPES = ['s', 'c', 'p']
| 19 | + |
# Prompt templates, one per task-type code (see the lookup in `_infer_cfg`).
# The `{...}` placeholders are named after the input columns declared in
# `_reader_cfg`.
# NOTE(review): some template lines start with a space (the instruction
# sentence in each prompt, and "## Choice:" in `_C_PROMPT`); confirm that
# this whitespace is intentional.

# Task 's': pick the correct ordering of a shuffled text. The model is asked
# to answer with "Selection 1" .. "Selection 4".
_S_PROMPT = """## Instruction:
 Given a **shuffled text** composed of sentences A, B, and C, your task is to select the correct order from four available selections. Avoid providing any additional information (such as explanations of your choice) or restating the sentences in your answer. Simply provide your selection: Selection 1, Selection 2, Selection 3, or Selection 4.
## Shuffled text:
{shuffled_text}
## Choice:
**Selection 1** {A}
**Selection 2** {B}
**Selection 3** {C}
**Selection 4** {D}
Answer:"""

# Task 'c': fill masked sentences in a paragraph from candidate sentences.
# The model is asked to answer with "Selection 1" .. "Selection 4".
_C_PROMPT = """## Instruction:
 Given a **masked paragraph** with three masked sentences marked as '<|MaskedSentence|>' and candidate sentences labeled A, B, and C, your task is to fill in the correct sentences to the masked positions by selecting the appropriate answers from four provided selections. Avoid providing any additional information (such as explanations of your choice) or restating the sentences in your answer. Simply provide your selection: Selection 1, Selection 2, Selection 3, or Selection 4.
## Masked paragraph:
{text_with_holes}
## {text_candidates}
 ## Choice:
**Selection 1** {A}
**Selection 2** {B}
**Selection 3** {C}
**Selection 4** {D}
Answer:"""

# Task 'p': choose the text that continues a given context. Unlike the two
# prompts above, the model is asked to answer with a single letter A-D.
_P_PROMPT = """## Instruction:
 Given a context, and four choices marked as A, B, C, and D, your task is to select the correct text which is the next sequence of the provided context. Avoid providing any additional information (such as explanations of your choice) or restating the choice in your answer. Simply provide one of the four letters: A, B, C, or D.
## Context:
{context}
## Choice:
**A** {A}
**B** {B}
**C** {C}
**D** {D}
Answer:"""
| 53 | + |
| 54 | + |
| 55 | +def _dataset_path(release, hf_domain, task_type, compact=True): |
| 56 | + suffix = '-50' if compact else '' |
| 57 | + if release == '2024b': |
| 58 | + return f'liangzid/robench2024b_all_set{hf_domain}SCP-{task_type}{suffix}' |
| 59 | + return (f'liangzid/robench{release}_test_all_category_set' |
| 60 | + f'{hf_domain}SCP-{task_type}{suffix}') |
| 61 | + |
| 62 | + |
| 63 | +def _reader_cfg(task_type): |
| 64 | + if task_type == 's': |
| 65 | + return dict( |
| 66 | + input_columns=['shuffled_text', 'A', 'B', 'C', 'D'], |
| 67 | + output_column='label', |
| 68 | + train_split='train', |
| 69 | + test_split='train', |
| 70 | + ) |
| 71 | + if task_type == 'c': |
| 72 | + return dict( |
| 73 | + input_columns=['text_with_holes', 'text_candidates', 'A', 'B', |
| 74 | + 'C', 'D'], |
| 75 | + output_column='label', |
| 76 | + train_split='train', |
| 77 | + test_split='train', |
| 78 | + ) |
| 79 | + return dict( |
| 80 | + input_columns=['context', 'A', 'B', 'C', 'D'], |
| 81 | + output_column='label', |
| 82 | + train_split='train', |
| 83 | + test_split='train', |
| 84 | + ) |
| 85 | + |
| 86 | + |
def _infer_cfg(task_type):
    """Return the inference configuration for a task type.

    The prompt for the task type is sent as a single ``user`` message
    through ``RawPromptTemplate``, paired with ``ZeroRetriever`` and
    ``GenInferencer``.

    Args:
        task_type: One of ``'s'``, ``'c'`` or ``'p'``.

    Returns:
        A dict with ``prompt_template``, ``retriever`` and ``inferencer``.

    Raises:
        KeyError: If ``task_type`` has no prompt.
    """
    prompts_by_task = {'s': _S_PROMPT, 'c': _C_PROMPT, 'p': _P_PROMPT}
    user_message = dict(role='user', content=prompts_by_task[task_type])
    template_cfg = dict(type=RawPromptTemplate, messages=[user_message])
    return dict(
        prompt_template=template_cfg,
        retriever=dict(type=ZeroRetriever),
        inferencer=dict(type=GenInferencer),
    )
| 99 | + |
| 100 | + |
def _eval_cfg(task_type):
    """Return the evaluation configuration for a task type.

    Predictions are scored with ``AccEvaluator``. Task types ``'s'`` and
    ``'c'`` use the selection postprocessor; every other task type uses the
    choice postprocessor.

    Args:
        task_type: Task-type code.

    Returns:
        A dict with ``evaluator`` and ``pred_postprocessor``.
    """
    # The postprocessor is referenced by name only; presumably it is
    # registered elsewhere in the project -- confirm before renaming.
    if task_type in ('s', 'c'):
        postprocess_name = 'arxivrollbench_selection_postprocess'
    else:
        postprocess_name = 'arxivrollbench_choice_postprocess'
    return dict(
        evaluator=dict(type=AccEvaluator),
        pred_postprocessor=dict(type=postprocess_name),
    )
| 108 | + |
| 109 | + |
def _build_datasets(compact=True):
    """Build one dataset config per (release, domain, task type).

    Entries are produced in nested order: releases outermost, then domains,
    then task types.

    Args:
        compact: Forwarded to ``_dataset_path``. When false, each
            abbreviation additionally ends in ``'-full'``.

    Returns:
        A list of dataset config dicts with ``abbr``, ``type``, ``path``,
        ``reader_cfg``, ``infer_cfg`` and ``eval_cfg``.
    """
    # Depends only on `compact`, so compute it once.
    abbr_tail = '' if compact else '-full'
    configs = []
    for release in _RELEASES:
        for domain, hf_domain in _DOMAINS:
            for task_type in _TASK_TYPES:
                abbr = (f'arxivrollbench-{release}-{domain}-'
                        f'{task_type}{abbr_tail}')
                hf_path = _dataset_path(release, hf_domain, task_type,
                                        compact)
                entry = dict(
                    abbr=abbr,
                    type=HFDataset,
                    path=hf_path,
                    reader_cfg=_reader_cfg(task_type),
                    infer_cfg=_infer_cfg(task_type),
                    eval_cfg=_eval_cfg(task_type),
                )
                configs.append(entry)
    return configs
| 128 | + |
| 129 | + |
# Default ArxivRollBench configuration: compact 50-sample splits (dataset
# paths ending in '-50').
arxivrollbench_datasets = _build_datasets(compact=True)

# Full public splits are also provided for complete benchmark runs; their
# abbreviations carry a '-full' suffix to keep them distinct from the
# compact entries.
arxivrollbench_full_datasets = _build_datasets(compact=False)
0 commit comments