
Commit 9656947

[Benchmark] Add support for MaRVL, xGQA and ALM-Bench.
Register the three multilingual benchmarks and resolve generic config entries to their dataset-specific implementations so they can be evaluated through the standard VLMEvalKit flow.

Made-with: Cursor
1 parent a3656d4 commit 9656947

5 files changed

Lines changed: 509 additions & 0 deletions

run.py

Lines changed: 12 additions & 0 deletions
@@ -94,6 +94,18 @@ def build_dataset_from_config(cfg, dataset_name):
     cls = getattr(vlmeval.dataset, cls_name)
     sig = inspect.signature(cls.__init__)
     valid_params = {k: v for k, v in config.items() if k in sig.parameters}
+    dataset_id = valid_params.get('dataset')
+    generic_dataset_classes = {
+        'ImageMCQDataset',
+        'ImageVQADataset',
+        'ImageYORNDataset',
+        'OCRBench',
+    }
+    if dataset_id is not None and cls_name in generic_dataset_classes:
+        dataset_kwargs = {k: v for k, v in valid_params.items() if k != 'dataset'}
+        resolved = build_dataset(dataset_id, **dataset_kwargs)
+        if resolved is not None:
+            return resolved
     if cls.MODALITY == 'VIDEO':
         if valid_params.get('fps', 0) > 0 and valid_params.get('nframe', 0) > 0:
             raise ValueError('fps and nframe should not be set at the same time')
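With this branch in place, a user config that names one of the generic classes but supplies a registered dataset ID gets swapped for the dataset-specific implementation. A minimal sketch of a config entry that would take the new path (the entry name and overall config shape here are illustrative, assuming the usual VLMEvalKit dataset-config layout):

    cfg = {
        'xgqa_en_custom': {
            'class': 'ImageVQADataset',  # generic class named in the config
            'dataset': 'xGQA_en',        # registered ID, resolved via build_dataset
        },
    }
    dataset = build_dataset_from_config(cfg, 'xgqa_en_custom')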

vlmeval/dataset/__init__.py

Lines changed: 6 additions & 0 deletions
@@ -6,6 +6,7 @@
 import pandas as pd
 
 from vlmeval.smp import LMUDataRoot, dump, get_intermediate_file_path, load, localize_df, toliststr
+from .almbench import ALMBenchDataset
 from .asclepius import Asclepius
 from .av_speakerbench import AVSpeakerBench
 from .CGAVCounting.cg_av_counting import CGAVCounting
@@ -67,6 +68,7 @@
 from .m3oralbench import M3oralBenchDataset
 from .m4bench import M4Bench
 from .macbench import MaCBench
+from .marvl import MaRVL, MaRVL_id, MaRVL_sw, MaRVL_ta, MaRVL_tr, MaRVL_zh
 from .matbench import MATBench
 from .medqbench_caption import MedqbenchCaptionDataset
 from .medqbench_mcq import MedqbenchMCQDataset
@@ -152,6 +154,7 @@
 from .wildvision import WildVision
 from .worldsense import WorldSense
 from .worldvqa import WorldVQA
+from .xgqa import xGQA, xGQA_bn, xGQA_de, xGQA_en, xGQA_id, xGQA_ko, xGQA_pt, xGQA_ru, xGQA_zh
 from .xstest import XSTestDataset
 
 from .video_dataset_config import supported_video_datasets  # isort: skip
@@ -282,6 +285,9 @@ def evaluate(self, eval_file, **judge_kwargs):
     ZEROBench, SCAM, Omni3DBench, TallyQA, _3DSRBench, BMMR, AffordanceDataset,
     MMEReasoning, GOBenchDataset, SFE, ChartMimic, MMVMBench, XLRSBench,
     OmniEarthMCQBench, VisFactor, OSTDataset, OCRBench_v2, TreeBench, CVQA, M4Bench,
+    MaRVL, MaRVL_id, MaRVL_sw, MaRVL_ta, MaRVL_tr, MaRVL_zh,
+    xGQA, xGQA_bn, xGQA_de, xGQA_en, xGQA_id, xGQA_ko, xGQA_pt, xGQA_ru, xGQA_zh,
+    ALMBenchDataset,
     AyaVisionBench, TopViewRS, VLMBias, MMHELIX, MedqbenchMCQDataset, MathCanvas, MMReason,
     MedqbenchPairedDescriptionDataset, MedqbenchCaptionDataset, ChartMuseum, ChartQAPro, ReasonMap_Plus,  # noqa: E501
     olmOCRBench, OceanOCRBench, MATBench, VLRMBench, RefCOCODataset, RefSpatialDataset,
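Once imported and added to the registry here, the new dataset names should resolve by ID; a minimal sketch (assuming the corresponding TSV files are available under the data root):

    from vlmeval.dataset import build_dataset

    marvl_zh = build_dataset('MaRVL_zh')    # per-language MaRVL split
    xgqa_en = build_dataset('xGQA_en')      # per-language xGQA split
    almbench = build_dataset('ALMBench')    # full ALM-Bench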

vlmeval/dataset/almbench.py

Lines changed: 212 additions & 0 deletions
@@ -0,0 +1,212 @@
"""
VLMEvalKit dataset class for ALM-Bench.
"""

import re
import string

import pandas as pd

from ..smp import load
from .image_base import ImageBaseDataset

LANGUAGES = [
    'Afrikaans', 'Albanian', 'Amharic', 'Armenian', 'Assamese', 'Azerbaijani',
    'Basque', 'Belarusian', 'Bengali', 'Bhojpuri', 'Bosnian', 'Bulgarian',
    'Catalan', 'Cebuano', 'Chinese_Simplified', 'Chinese_Traditional', 'Croatian',
    'Czech', 'Danish', 'Dutch', 'Egyptian_Arabic', 'Emirati_Arabic', 'English',
    'Estonian', 'Filipino', 'Finnish', 'French', 'Galician', 'Georgian',
    'German', 'Greek', 'Gujarati', 'Hausa', 'Hawaiian', 'Hebrew', 'Hindi',
    'Hungarian', 'Icelandic', 'Igbo', 'Indonesian', 'Irish', 'Italian',
    'Japanese', 'Javanese', 'Kannada', 'Kazakh', 'Kinyarwanda', 'Korean',
    'Kurdish', 'Kyrgyz', 'Lao', 'Latin', 'Latvian', 'Lithuanian',
    'Luxembourgish', 'Macedonian', 'Malagasy', 'Malay', 'Malayalam', 'Maltese',
    'Marathi', 'Mongolian', 'Myanmar_Burmese', 'Nepali', 'Norwegian',
    'Odia_Oriya', 'Pashto', 'Persian', 'Polish', 'Portuguese', 'Punjabi',
    'Romanian', 'Russian', 'Sanskrit', 'Saudi_Arabic', 'Scots_Gaelic',
    'Serbian', 'Shona', 'Sindhi', 'Sinhala', 'Slovak', 'Slovenian', 'Somali',
    'Spanish', 'Sundanese', 'Swahili', 'Swedish', 'Tajik', 'Tamil', 'Telugu',
    'Thai', 'Turkish', 'Ukrainian', 'Urdu', 'Uyghur', 'Uzbek', 'Vietnamese',
    'Welsh', 'Yiddish', 'Yoruba',
]


def _make_url_dicts():
    names = ['ALMBench'] + [f'ALMBench_{lang}' for lang in LANGUAGES]
    return {name: '' for name in names}, {name: None for name in names}


DATASET_URL, DATASET_MD5 = _make_url_dicts()
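# Note: every URL above is an empty string and every MD5 is None, so the
# ALMBench TSVs are presumably expected to already be present locally
# (e.g. under the LMUData root) rather than downloaded; the commit does
# not state where the data comes from.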

def _normalise(text: str) -> str:
    """Lowercase, strip punctuation and extra whitespace."""
    text = str(text).lower().strip()
    text = text.translate(str.maketrans('', '', string.punctuation))
    text = re.sub(r'\s+', ' ', text).strip()
    return text


def _question_family(question_type: str) -> str:
    qtype = _normalise(question_type)
    if qtype in ('t/f', 'true/false', 'tf', 'true false question'):
        return 'tf'
    if qtype in ('mcqs', 'mcq', 'multiple choice', 'multiple choice questions'):
        return 'mcq'
    if qtype in ('svqas', 'svqa', 'short questions', 'short'):
        return 'short'
    if qtype in ('lvqas', 'lvqa', 'long question', 'long questions', 'long'):
        return 'long'
    return 'open'


def _extract_tf(text: str):
    """Extract True/False from a model prediction."""
    norm = _normalise(text)
    if re.search(r'\btrue\b|\byes\b|\bcorrect\b', norm):
        return 'true'
    if re.search(r'\bfalse\b|\bno\b|\bincorrect\b', norm):
        return 'false'
    return None


def _extract_mcq_answer(answer: str) -> str:
    text = str(answer).strip()
    for delimiter in (' (', '\n('):
        if delimiter in text:
            return text.split(delimiter, 1)[0].strip()
    return text


def _soft_exact_match(prediction: str, answer: str) -> bool:
    return _normalise(prediction) == _normalise(answer)


def _tf_match(prediction: str, answer: str, english_answer: str = '') -> bool:
    pred_label = _extract_tf(prediction)
    ans_label = _extract_tf(english_answer) if str(english_answer).strip() else None
    if ans_label is None:
        ans_label = _extract_tf(answer)
    if pred_label is None or ans_label is None:
        if english_answer and _soft_exact_match(prediction, english_answer):
            return True
        return _soft_exact_match(prediction, answer)
    return pred_label == ans_label


def _accuracy(df: pd.DataFrame) -> float:
    if len(df) == 0:
        return 0.0
    return round(df['correct'].sum() / len(df) * 100, 2)


def _evaluate_row(row) -> bool:
    qtype = _question_family(str(row.get('question_type', '')))
    prediction = str(row['prediction'])
    answer = str(row['answer'])
    english_answer = str(row.get('english_answer', ''))

    if qtype == 'tf':
        return _tf_match(prediction, answer, english_answer)
    if qtype == 'mcq':
        return _soft_exact_match(prediction, _extract_mcq_answer(answer))
    return _soft_exact_match(prediction, answer)
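# Illustrative behaviour of the matchers above (hand-traced from the code,
# not taken from the commit):
#   _normalise('  True. ')                      -> 'true'
#   _extract_mcq_answer('Paris (the capital)')  -> 'Paris'
#   _tf_match('Yes, that is right.', 'Vrai', english_answer='True') -> True
#       (the prediction normalises to a 'true' label via \byes\b, and the
#        English reference answer also maps to 'true')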

class ALMBenchDataset(ImageBaseDataset):
    TYPE = 'VQA'
    MODALITY = 'IMAGE'
    DATASET_URL = DATASET_URL
    DATASET_MD5 = DATASET_MD5

    def build_prompt(self, line):
        if isinstance(line, int):
            line = self.data.iloc[line]

        img_paths = self.dump_image(line)
        if not isinstance(img_paths, list):
            img_paths = [img_paths]

        question = str(line['question'])
        family = _question_family(str(line.get('question_type', '')).strip().lower())

        if family == 'tf':
            instruction = 'Answer with True or False only.'
        elif family == 'mcq':
            instruction = 'Answer using only the text of the correct option.'
        elif family == 'short':
            instruction = 'Answer the question using a single word or short phrase.'
        else:
            instruction = 'Answer the question as accurately as possible.'

        prompt = f'{question}\n{instruction}'
        msgs = [dict(type='image', value=p) for p in img_paths]
        msgs.append(dict(type='text', value=prompt))
        return msgs
148+
data = load(eval_file)
149+
data['correct'] = data.apply(_evaluate_row, axis=1)
150+
151+
rows = []
152+
153+
def add_rows(col_name, split_label):
154+
if col_name not in data.columns:
155+
return
156+
for value in sorted(data[col_name].dropna().unique()):
157+
sub = data[data[col_name] == value]
158+
rows.append({
159+
'dataset': self.dataset_name,
160+
'split_by': split_label,
161+
'value': value,
162+
'total': len(sub),
163+
'correct': int(sub['correct'].sum()),
164+
'accuracy (%)': _accuracy(sub),
165+
})
166+
167+
add_rows('language', 'language')
168+
add_rows('category', 'category')
169+
add_rows('question_type', 'question_type')
170+
rows.append({
171+
'dataset': self.dataset_name,
172+
'split_by': 'overall',
173+
'value': 'all',
174+
'total': len(data),
175+
'correct': int(data['correct'].sum()),
176+
'accuracy (%)': _accuracy(data),
177+
})
178+
179+
result_df = pd.DataFrame(rows)
180+
result_path = eval_file.replace('.xlsx', '_ALMBench_results.csv')
181+
if result_path == eval_file:
182+
result_path = eval_file + '_ALMBench_results.csv'
183+
result_df.to_csv(result_path, index=False)
184+
print(f'\nALM-Bench results -> {result_path}')
185+
print(result_df.to_string(index=False))
186+
return result_df
187+
188+
189+
def _make_lang_class(lang: str):
190+
name = f'ALMBench_{lang}'
191+
return type(
192+
name,
193+
(ALMBenchDataset,),
194+
{
195+
'__doc__': f'ALM-Bench - language: {lang}',
196+
'DATASET_URL': {name: DATASET_URL.get(name, '')},
197+
'DATASET_MD5': {name: DATASET_MD5.get(name)},
198+
},
199+
)
200+
201+
202+
for _lang in LANGUAGES:
203+
globals()[f'ALMBench_{_lang}'] = _make_lang_class(_lang)
204+
205+
206+
class ALMBench(ALMBenchDataset):
207+
DATASET_URL = {'ALMBench': DATASET_URL.get('ALMBench', '')}
208+
DATASET_MD5 = {'ALMBench': DATASET_MD5.get('ALMBench')}
209+
210+
211+
ALM_LANGUAGES = list(LANGUAGES)
212+
ALM_DATASETS = ['ALMBench'] + [f'ALMBench_{lang}' for lang in LANGUAGES]
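With the classes registered, the new benchmarks should be runnable through the standard VLMEvalKit entry point; a minimal sketch (the model name is a placeholder):

    python run.py --data ALMBench_French xGQA_de MaRVL_zh --model GPT4o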
