Skip to content

Commit bd71e82

Browse files
authored
feat: Add Spatial-DISE benchmark task (#1327)
* Add Spatial-DISE benchmark task * Use merged and separate Spatial-DISE images * Split Spatial-DISE image input modes * Update Spatial-DISE paper metadata * Support dynamic Spatial-DISE answer options
1 parent 9c78edc commit bd71e82

3 files changed

Lines changed: 322 additions & 0 deletions

File tree

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# Spatial-DISE task, merged-image variant: rows are prepared by
# utils.spatial_dise_process_docs.
# NOTE(review): nesting below is reconstructed from the key names; confirm
# against the committed file.

# Benchmark rows are read from the hosted CSV through the generic `csv` loader.
dataset_path: csv
dataset_kwargs:
  data_files:
    test: https://huggingface.co/datasets/TACPS-liv/Spatial-DISE/resolve/main/DISE-bench/DISE-benchmark.csv
  # Presumably forwarded to the CSV reader to drop padding after delimiters — TODO confirm.
  skipinitialspace: true
task: spatial_dise
test_split: test
output_type: generate_until

# Per-document hooks, resolved from the `utils` module.
process_docs: !function utils.spatial_dise_process_docs
doc_to_visual: !function utils.spatial_dise_doc_to_visual
doc_to_text: !function utils.spatial_dise_doc_to_text
doc_to_target: "answer"
process_results: !function utils.spatial_dise_process_results

# Single accuracy metric; the aggregation hook also logs per-group breakdowns.
metric_list:
  - metric: spatial_dise_acc
    aggregation: !function utils.spatial_dise_aggregate_results
    higher_is_better: true

# Short generations without sampling: the prompt asks for one option letter.
generation_kwargs:
  max_new_tokens: 16
  temperature: 0
  do_sample: false

# Prompt affixes read by utils.spatial_dise_doc_to_text.
lmms_eval_specific_kwargs:
  default:
    pre_prompt: ""
    post_prompt: "\nPlease select the correct answer and respond with only one option letter."

metadata:
  - version: 0.0
  - dataset: TACPS-liv/Spatial-DISE
  - paper: https://openreview.net/pdf?id=bMINsPQpME
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
# Spatial-DISE task, separate-image variant: rows are prepared by
# utils.spatial_dise_process_docs_separate.
# NOTE(review): nesting below is reconstructed from the key names; confirm
# against the committed file.

# Benchmark rows are read from the hosted CSV through the generic `csv` loader.
dataset_path: csv
dataset_kwargs:
  data_files:
    test: https://huggingface.co/datasets/TACPS-liv/Spatial-DISE/resolve/main/DISE-bench/DISE-benchmark.csv
  # Presumably forwarded to the CSV reader to drop padding after delimiters — TODO confirm.
  skipinitialspace: true
task: spatial_dise_separate
test_split: test
output_type: generate_until

# Per-document hooks, resolved from the `utils` module. Only process_docs
# differs from the merged-image task.
process_docs: !function utils.spatial_dise_process_docs_separate
doc_to_visual: !function utils.spatial_dise_doc_to_visual
doc_to_text: !function utils.spatial_dise_doc_to_text
doc_to_target: "answer"
process_results: !function utils.spatial_dise_process_results

# Single accuracy metric; the aggregation hook also logs per-group breakdowns.
metric_list:
  - metric: spatial_dise_acc
    aggregation: !function utils.spatial_dise_aggregate_results
    higher_is_better: true

# Short generations without sampling: the prompt asks for one option letter.
generation_kwargs:
  max_new_tokens: 16
  temperature: 0
  do_sample: false

# Prompt affixes read by utils.spatial_dise_doc_to_text.
lmms_eval_specific_kwargs:
  default:
    pre_prompt: ""
    post_prompt: "\nPlease select the correct answer and respond with only one option letter."

metadata:
  - version: 0.0
  - dataset: TACPS-liv/Spatial-DISE
  - paper: https://openreview.net/pdf?id=bMINsPQpME
  - image_mode: separate
Lines changed: 255 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,255 @@
1+
import io
2+
import os
3+
import os.path as osp
4+
import re
5+
import string
6+
import tarfile
7+
from collections import defaultdict
8+
from functools import lru_cache
9+
from pathlib import Path
10+
from typing import Any, Dict, List
11+
12+
import datasets
13+
import numpy as np
14+
from huggingface_hub import snapshot_download
15+
from loguru import logger as eval_logger
16+
from PIL import Image
17+
18+
# Hugging Face dataset repository passed to `snapshot_download`.
REPO_ID = "TACPS-liv/Spatial-DISE"

# Both tables hold (CSV column name, role label) pairs; `_image_refs` walks
# them in order to decide which images belong to a sample.
MERGE_IMAGE_COLUMNS = [("image", "merged image")]

# Question/view columns are spelled out; the four option columns follow one
# naming scheme and are generated from the option letters.
SEPARATE_IMAGE_COLUMNS = [
    ("question_image_path", "separate question image"),
    ("question_image_1_path", "separate question image 1"),
    ("question_image_2_path", "separate question image 2"),
] + [
    (f"option_{letter.lower()}_image_path", f"separate option {letter} image")
    for letter in "ABCD"
]
31+
32+
33+
def spatial_dise_process_docs(dataset: datasets.Dataset) -> datasets.Dataset:
    """Prepare the benchmark split for the merged-image task.

    Thin entry point that runs ``_process_docs`` with ``image_mode="merge"``.
    """
    processed = _process_docs(dataset, image_mode="merge")
    return processed
35+
36+
37+
def spatial_dise_process_docs_separate(dataset: datasets.Dataset) -> datasets.Dataset:
    """Prepare the benchmark split for the separate-image task.

    Thin entry point that runs ``_process_docs`` with ``image_mode="separate"``.
    """
    processed = _process_docs(dataset, image_mode="separate")
    return processed
39+
40+
41+
def _process_docs(dataset: datasets.Dataset, image_mode: str) -> datasets.Dataset:
    """Resolve image references and normalise fields for every benchmark row.

    Args:
        dataset: Raw split loaded from the benchmark CSV. Column names and
            string values are whitespace-stripped before use.
        image_mode: ``"separate"`` selects the per-part image columns; any
            other value selects the merged-image column (see ``_image_refs``).

    Returns:
        The result of ``dataset.map`` over the per-row transform below.

    Raises:
        FileNotFoundError: If a row yields no image reference for the
            requested mode, or (from ``_image_refs``) if a referenced image
            is missing from the tar shards.
    """
    dataset_root = _dataset_root()
    tar_index = _tar_index(dataset_root)

    def _process_doc(doc, idx):
        clean_doc = {str(key).strip(): _strip(value) for key, value in doc.items()}
        image_refs = _image_refs(clean_doc, tar_index, image_mode)
        if not image_refs:
            # An empty list means every candidate column was blank; a missing
            # tar member has already raised inside _image_refs. Use .get() so
            # a row without an "image" column cannot turn this into a KeyError.
            raise FileNotFoundError(
                f"Spatial-DISE sample {idx} has no image reference for image_mode={image_mode!r} "
                f"(merged image column: {clean_doc.get('image')!r}; tar shards under {dataset_root})"
            )
        option_letters = _option_letters(clean_doc.get("options", ""))

        return {
            "id": f"benchmark_{idx}",
            "question": clean_doc["question"],
            "answer": clean_doc["answer"].upper(),
            "option_letters": option_letters,
            # First reference kept as scalar fields alongside the full lists.
            "image_path": image_refs[0]["path"],
            "image_shard": image_refs[0]["shard"],
            "image_paths": [ref["path"] for ref in image_refs],
            "image_shards": [ref["shard"] for ref in image_refs],
            "image_roles": [ref["role"] for ref in image_refs],
            "image_mode": image_mode,
            "category": clean_doc.get("category", ""),
            "difficulty": clean_doc.get("difficulty", ""),
            "source": clean_doc.get("source", ""),
            "dise_category": clean_doc.get("dise_category", ""),
        }

    return dataset.map(_process_doc, with_indices=True)
70+
71+
72+
def spatial_dise_doc_to_visual(doc: Dict[str, Any]) -> List[Image.Image]:
    """Load every image referenced by a processed doc, in stored order.

    ``doc["image_paths"]`` and ``doc["image_shards"]`` are read pairwise; each
    pair is opened through ``_open_tar_image``.
    """
    return [
        _open_tar_image(shard, member)
        for member, shard in zip(doc["image_paths"], doc["image_shards"])
    ]
77+
78+
79+
def spatial_dise_doc_to_text(doc: Dict[str, Any], lmms_eval_specific_kwargs=None) -> str:
    """Build the text prompt for one sample.

    The prompt is ``pre_prompt`` + an image-context sentence naming the option
    labels + the stripped question + ``post_prompt``, with outer whitespace
    removed. Missing or empty ``option_letters`` fall back to A-D.
    """
    prompt_kwargs = {} if lmms_eval_specific_kwargs is None else lmms_eval_specific_kwargs
    prefix = prompt_kwargs.get("pre_prompt", "")
    suffix = prompt_kwargs.get("post_prompt", "")

    labels = doc.get("option_letters") or ["A", "B", "C", "D"]
    option_text = ", ".join(labels)

    separate_mode = doc.get("image_mode") == "separate"
    if not separate_mode:
        image_context = f"The image contains answer choices labeled {option_text}.\n"
    else:
        image_context = (
            "Images are provided as separate question/view/option images from the original sample. "
            f"Use all images together. The answer choices are labeled {option_text}.\n"
        )

    question = doc["question"].strip()
    return "{}{}{}{}".format(prefix, image_context, question, suffix).strip()
94+
95+
96+
def spatial_dise_process_results(doc: Dict[str, Any], results: List[str]) -> Dict[str, Dict[str, Any]]:
    """Score one model response against the sample's answer letter.

    Only ``results[0]`` is used. The parsed option letter is compared with the
    upper-cased, stripped target; the returned record carries both the raw and
    the parsed prediction plus the grouping fields used at aggregation time.
    """
    raw_response = results[0]
    gold = doc["answer"].strip().upper()
    parsed = _extract_answer(raw_response, doc.get("option_letters"))

    record = {
        "id": doc["id"],
        "gt": gold,
        "pred": raw_response,
        "pred_parsed": parsed,
        "category": doc["category"],
        "difficulty": doc["difficulty"],
        "dise_category": doc["dise_category"],
        "is_correct": parsed == gold,
    }
    return {"spatial_dise_acc": record}
114+
115+
116+
def spatial_dise_aggregate_results(results: List[Dict[str, Any]]) -> float:
    """Return overall accuracy and log per-group breakdowns.

    An empty result list scores 0.0 without logging. Otherwise accuracy is the
    mean of ``is_correct`` over all samples, and breakdowns are logged for
    ``category``, ``difficulty`` and ``dise_category`` in that order.
    """
    if not results:
        return 0.0

    outcomes = [sample["is_correct"] for sample in results]
    for field in ("category", "difficulty", "dise_category"):
        _log_breakdown(field, results)
    return float(np.mean(outcomes))
125+
126+
127+
def _extract_answer(response: str, choices=None) -> str:
    """Parse the chosen option letter out of a free-form model response.

    Args:
        response: Raw model output; coerced to ``str`` and stripped.
        choices: Allowed option labels. Normalised through
            ``_normalize_choices``, which falls back to A-D.

    Returns:
        The upper-cased option letter, or ``""`` when nothing matches.
    """
    response = str(response).strip()
    choices = _normalize_choices(choices)
    letters = "".join(re.escape(choice) for choice in choices)

    try:
        from lmms_eval.tasks._task_utils.mcq_extract import extract_mcq_answer

        answer = extract_mcq_answer(response, choices=choices)
        if answer:
            return answer.strip().upper()
    except Exception:
        # Deliberate best-effort: the shared extractor may be absent or may
        # reject the input; the local patterns below remain the fallback.
        pass

    patterns = [
        # "answer: B", "the answer is (B)". The separator may be an ASCII or
        # fullwidth colon. The lookahead stops "Answer: Both" parsing as "B".
        (rf"(?:answer|final answer|correct answer)(?:\s+is)?\s*[:\uff1a]?\s*\(?([{letters}])(?!\w)\)?", re.IGNORECASE),
        # Response that starts with the letter: "B", "(b)", "C. because ...".
        (rf"^\s*\(?([{letters}])\)?(?:[\.\):\s]|$)", re.IGNORECASE),
        # A standalone uppercase letter is tried before the case-insensitive
        # scan, so the article "a" does not beat a later "B".
        (rf"\b([{letters}])\b", 0),
        (rf"\b([{letters}])\b", re.IGNORECASE),
    ]
    for pattern, flags in patterns:
        match = re.search(pattern, response, flags=flags)
        if match:
            return match.group(1).upper()
    return ""


def _first_letters(items) -> List[str]:
    """Collect the distinct leading A-Z letters of *items*, in first-seen order."""
    letters: List[str] = []
    for item in items:
        text = str(item).strip().upper()
        if text and text[0] in string.ascii_uppercase and text[0] not in letters:
            letters.append(text[0])
    return letters


def _option_letters(value) -> List[str]:
    """Derive option labels from a comma-separated ``options`` cell.

    ASCII and fullwidth commas are both accepted as separators. A missing,
    blank or ``"nan"`` cell, or one with no usable letter, yields A-D.
    """
    default = list("ABCD")
    if value is None:
        return default
    text = str(value).strip()
    # "nan" is how a missing cell can surface after stringification; the same
    # convention is used for image columns in _image_refs.
    if not text or text.lower() == "nan":
        return default
    parts = text.replace("\uff0c", ",").split(",")
    return _first_letters(parts) or default


def _normalize_choices(choices) -> List[str]:
    """Reduce *choices* to distinct uppercase option letters; default to A-D."""
    if not choices:
        return list("ABCD")
    return _first_letters(choices) or list("ABCD")
172+
173+
174+
def _log_breakdown(key: str, results: List[Dict[str, Any]]) -> None:
    """Log accuracy per distinct value of ``sample[key]``.

    Args:
        key: Field of each result record to group by.
        results: Per-sample records; each must carry *key* and ``is_correct``.
    """
    grouped = defaultdict(list)
    for sample in results:
        grouped[sample[key]].append(sample["is_correct"])

    try:
        names = sorted(grouped)
    except TypeError:
        # Group values can mix None with strings (e.g. blank CSV cells), which
        # do not order against each other; sort by string form instead of
        # failing at the very end of an evaluation run.
        names = sorted(grouped, key=str)

    eval_logger.info(f"Spatial-DISE {key} breakdown:")
    for name in names:
        outcomes = grouped[name]
        score = float(np.mean(outcomes))
        eval_logger.info(f" {name}: {score:.4f} ({sum(outcomes)}/{len(outcomes)})")
183+
184+
185+
def _dataset_root() -> str:
    """Locate the Spatial-DISE files on disk.

    ``SPATIAL_DISE_ROOT`` (after ``$VAR`` and ``~`` expansion) wins when it
    names an existing directory. Otherwise the benchmark CSV and image tar
    shards are fetched with ``snapshot_download`` and its return value is used.
    """
    configured = os.environ.get("SPATIAL_DISE_ROOT")
    if configured:
        candidate = osp.expanduser(osp.expandvars(configured))
        if osp.isdir(candidate):
            return candidate

    download_options = {
        "repo_id": REPO_ID,
        "repo_type": "dataset",
        "revision": "main",
        "allow_patterns": ["DISE-bench/DISE-benchmark.csv", "image/*.tar"],
    }
    return snapshot_download(**download_options)
198+
199+
200+
def _csv_path_to_tar_member(path: str) -> str:
    """Convert an image path from the CSV into a tar member name.

    The value is stringified and stripped, one leading ``images/`` prefix is
    dropped, and any remaining leading slashes or backslashes are removed.
    """
    prefix = "images/"
    member = str(path).strip()
    if member[: len(prefix)] == prefix:
        member = member[len(prefix) :]
    return member.lstrip("/\\")
205+
206+
207+
def _image_refs(doc: Dict[str, Any], tar_index: Dict[str, str], image_mode: str) -> List[Dict[str, str]]:
    """List the tar-backed images a row refers to for the given mode.

    Columns are taken from ``SEPARATE_IMAGE_COLUMNS`` when *image_mode* is
    ``"separate"`` and from ``MERGE_IMAGE_COLUMNS`` otherwise. Blank, ``None``
    and ``"nan"`` cells are skipped, as are repeats of an already-listed member.

    Returns:
        One ``{"role", "path", "shard"}`` dict per distinct image, in column order.

    Raises:
        FileNotFoundError: If a referenced member has no shard in *tar_index*.
    """
    columns = SEPARATE_IMAGE_COLUMNS if image_mode == "separate" else MERGE_IMAGE_COLUMNS
    collected: List[Dict[str, str]] = []
    already_listed = set()

    for column, role in columns:
        raw = doc.get(column, "")
        text = "" if raw is None else str(raw).strip()
        if text == "" or text.lower() == "nan":
            continue

        member = _csv_path_to_tar_member(text)
        if member in already_listed:
            continue

        shard = tar_index.get(member)
        if shard is None:
            raise FileNotFoundError(f"Spatial-DISE image {column}={text} not found in tar shards")

        already_listed.add(member)
        collected.append({"role": role, "path": member, "shard": shard})

    return collected
227+
228+
229+
def _open_tar_image(shard: str, member: str) -> Image.Image:
    """Read *member* from the tar file *shard* and return it as an RGB image.

    The member's bytes are read fully into memory while the archive is open,
    then decoded from that buffer.

    Raises:
        FileNotFoundError: If ``extractfile`` yields no file object for *member*.
    """
    with tarfile.open(shard) as archive:
        handle = archive.extractfile(member)
        if handle is None:
            raise FileNotFoundError(f"{member} not found in {shard}")
        raw_bytes = handle.read()

    decoded = Image.open(io.BytesIO(raw_bytes))
    return decoded.convert("RGB")
236+
237+
238+
@lru_cache(maxsize=4)
def _tar_index(dataset_root: str) -> Dict[str, str]:
    """Map every regular-file member name to the tar shard that holds it.

    Shards are the ``*.tar`` files directly under ``<dataset_root>/image``,
    visited in sorted order; a name found in several shards keeps the last
    one. The result is memoised per *dataset_root*.

    Raises:
        FileNotFoundError: If that directory contains no ``*.tar`` file.
    """
    image_dir = osp.join(dataset_root, "image")
    shards = sorted(Path(image_dir).glob("*.tar"))
    if not shards:
        raise FileNotFoundError(f"No Spatial-DISE tar shards found under {image_dir}")

    index: Dict[str, str] = {}
    for shard in shards:
        shard_name = str(shard)
        with tarfile.open(shard) as archive:
            index.update({entry.name: shard_name for entry in archive.getmembers() if entry.isfile()})
    return index
252+
253+
254+
def _strip(value):
    """Strip surrounding whitespace from strings; pass other values through."""
    if isinstance(value, str):
        return value.strip()
    return value

0 commit comments

Comments
 (0)