|
| 1 | +import io |
| 2 | +import os |
| 3 | +import os.path as osp |
| 4 | +import re |
| 5 | +import string |
| 6 | +import tarfile |
| 7 | +from collections import defaultdict |
| 8 | +from functools import lru_cache |
| 9 | +from pathlib import Path |
| 10 | +from typing import Any, Dict, List |
| 11 | + |
| 12 | +import datasets |
| 13 | +import numpy as np |
| 14 | +from huggingface_hub import snapshot_download |
| 15 | +from loguru import logger as eval_logger |
| 16 | +from PIL import Image |
| 17 | + |
# Hugging Face dataset repository that hosts the Spatial-DISE benchmark.
REPO_ID = "TACPS-liv/Spatial-DISE"
# (CSV column name, human-readable role) pairs for the single merged-image
# layout: one composite image carries the question and all answer choices.
MERGE_IMAGE_COLUMNS = [
    ("image", "merged image"),
]
# (CSV column name, human-readable role) pairs for the separate-image layout:
# individual question/view images plus one image per answer option. Missing
# or blank columns are skipped at resolution time (see _image_refs).
SEPARATE_IMAGE_COLUMNS = [
    ("question_image_path", "separate question image"),
    ("question_image_1_path", "separate question image 1"),
    ("question_image_2_path", "separate question image 2"),
    ("option_a_image_path", "separate option A image"),
    ("option_b_image_path", "separate option B image"),
    ("option_c_image_path", "separate option C image"),
    ("option_d_image_path", "separate option D image"),
]
| 31 | + |
| 32 | + |
def spatial_dise_process_docs(dataset: datasets.Dataset) -> datasets.Dataset:
    """Prepare Spatial-DISE docs using the single merged image per sample."""
    return _process_docs(dataset, image_mode="merge")
| 35 | + |
| 36 | + |
def spatial_dise_process_docs_separate(dataset: datasets.Dataset) -> datasets.Dataset:
    """Prepare Spatial-DISE docs using the separate question/option images."""
    return _process_docs(dataset, image_mode="separate")
| 39 | + |
| 40 | + |
def _process_docs(dataset: datasets.Dataset, image_mode: str) -> datasets.Dataset:
    """Normalize raw benchmark rows into the fields lmms-eval consumes.

    Args:
        dataset: Raw benchmark rows (CSV-backed) to transform.
        image_mode: "merge" resolves the single composite image column;
            "separate" resolves the per-question/per-option image columns.

    Returns:
        The mapped dataset with id/question/answer/option/image fields.

    Raises:
        FileNotFoundError: when a row resolves to no image in the tar shards.
    """
    dataset_root = _dataset_root()
    tar_index = _tar_index(dataset_root)

    def _process_doc(doc, idx):
        # Strip stray whitespace from both column names and string values.
        clean_doc = {str(key).strip(): _strip(value) for key, value in doc.items()}
        image_refs = _image_refs(clean_doc, tar_index, image_mode)
        if len(image_refs) == 0:
            # Use .get(): in "separate" mode a row may lack the "image"
            # column entirely, and a KeyError here would mask the real
            # missing-image failure with an unrelated traceback.
            raise FileNotFoundError(f"Spatial-DISE image {clean_doc.get('image', '<unknown>')} not found in tar shards under {dataset_root}")
        option_letters = _option_letters(clean_doc.get("options", ""))

        return {
            "id": f"benchmark_{idx}",
            "question": clean_doc["question"],
            "answer": clean_doc["answer"].upper(),
            "option_letters": option_letters,
            # The first reference doubles as the primary image for
            # single-image consumers.
            "image_path": image_refs[0]["path"],
            "image_shard": image_refs[0]["shard"],
            "image_paths": [ref["path"] for ref in image_refs],
            "image_shards": [ref["shard"] for ref in image_refs],
            "image_roles": [ref["role"] for ref in image_refs],
            "image_mode": image_mode,
            "category": clean_doc.get("category", ""),
            "difficulty": clean_doc.get("difficulty", ""),
            "source": clean_doc.get("source", ""),
            "dise_category": clean_doc.get("dise_category", ""),
        }

    return dataset.map(_process_doc, with_indices=True)
| 70 | + |
| 71 | + |
def spatial_dise_doc_to_visual(doc: Dict[str, Any]) -> List[Image.Image]:
    """Load every image referenced by the processed doc, in order."""
    return [
        _open_tar_image(shard, path)
        for path, shard in zip(doc["image_paths"], doc["image_shards"])
    ]
| 77 | + |
| 78 | + |
def spatial_dise_doc_to_text(doc: Dict[str, Any], lmms_eval_specific_kwargs=None) -> str:
    """Render the textual prompt for one processed Spatial-DISE doc.

    Wraps the question with a mode-specific hint about how the images carry
    the answer choices, plus optional pre/post prompt text from the config.
    """
    kwargs = lmms_eval_specific_kwargs if lmms_eval_specific_kwargs is not None else {}
    pre_prompt = kwargs.get("pre_prompt", "")
    post_prompt = kwargs.get("post_prompt", "")

    # Fall back to A-D when the doc carries no explicit option letters.
    letters = doc.get("option_letters") or ["A", "B", "C", "D"]
    option_text = ", ".join(letters)

    if doc.get("image_mode") == "separate":
        image_context = (
            "Images are provided as separate question/view/option images from the original sample. "
            f"Use all images together. The answer choices are labeled {option_text}.\n"
        )
    else:
        image_context = f"The image contains answer choices labeled {option_text}.\n"

    prompt = f"{pre_prompt}{image_context}{doc['question'].strip()}{post_prompt}"
    return prompt.strip()
| 94 | + |
| 95 | + |
def spatial_dise_process_results(doc: Dict[str, Any], results: List[str]) -> Dict[str, Dict[str, Any]]:
    """Score one model response against the gold answer letter.

    Returns a single-metric dict keyed "spatial_dise_acc" carrying both the
    raw and parsed prediction plus the grouping fields used at aggregation.
    """
    raw_response = results[0]
    gold = doc["answer"].strip().upper()
    parsed = _extract_answer(raw_response, doc.get("option_letters"))

    record = {
        "id": doc["id"],
        "gt": gold,
        "pred": raw_response,
        "pred_parsed": parsed,
        "category": doc["category"],
        "difficulty": doc["difficulty"],
        "dise_category": doc["dise_category"],
        "is_correct": parsed == gold,
    }
    return {"spatial_dise_acc": record}
| 114 | + |
| 115 | + |
def spatial_dise_aggregate_results(results: List[Dict[str, Any]]) -> float:
    """Return overall accuracy and log per-group breakdowns as a side effect."""
    if not results:
        return 0.0

    for group_key in ("category", "difficulty", "dise_category"):
        _log_breakdown(group_key, results)
    return float(np.mean([sample["is_correct"] for sample in results]))
| 125 | + |
| 126 | + |
def _extract_answer(response: str, choices=None) -> str:
    """Pull a single option letter out of a free-form model response.

    Tries the shared lmms-eval MCQ extractor first, then falls back to regex
    heuristics ordered from most to least specific. Returns "" on no match.
    """
    text = str(response).strip()
    valid = _normalize_choices(choices)
    letter_class = "".join(re.escape(letter) for letter in valid)

    # Prefer the shared extractor when the host package provides one; any
    # failure (missing module, extractor error) falls through to the regexes.
    try:
        from lmms_eval.tasks._task_utils.mcq_extract import extract_mcq_answer

        extracted = extract_mcq_answer(text, choices=valid)
        if extracted:
            return extracted.strip().upper()
    except Exception:
        pass

    fallback_patterns = (
        rf"(?:answer|final answer|correct answer)\s*[::]?\s*\(?([{letter_class}])\)?",
        rf"^\s*\(?([{letter_class}])\)?(?:[\.\):\s]|$)",
        rf"\b([{letter_class}])\b",
    )
    for pattern in fallback_patterns:
        found = re.search(pattern, text, flags=re.IGNORECASE)
        if found:
            return found.group(1).upper()
    return ""
| 150 | + |
| 151 | + |
| 152 | +def _option_letters(value) -> List[str]: |
| 153 | + if value is None: |
| 154 | + return list("ABCD") |
| 155 | + letters = [] |
| 156 | + for option in str(value).replace(",", ",").split(","): |
| 157 | + option = option.strip().upper() |
| 158 | + if option and option[0] in string.ascii_uppercase and option[0] not in letters: |
| 159 | + letters.append(option[0]) |
| 160 | + return letters or list("ABCD") |
| 161 | + |
| 162 | + |
| 163 | +def _normalize_choices(choices) -> List[str]: |
| 164 | + if not choices: |
| 165 | + return list("ABCD") |
| 166 | + normalized = [] |
| 167 | + for choice in choices: |
| 168 | + choice = str(choice).strip().upper() |
| 169 | + if choice and choice[0] in string.ascii_uppercase and choice[0] not in normalized: |
| 170 | + normalized.append(choice[0]) |
| 171 | + return normalized or list("ABCD") |
| 172 | + |
| 173 | + |
def _log_breakdown(key: str, results: List[Dict[str, Any]]) -> None:
    """Log per-group accuracy over *results* grouped by the doc field *key*."""
    buckets: Dict[Any, List[Any]] = defaultdict(list)
    for sample in results:
        buckets[sample[key]].append(sample["is_correct"])

    eval_logger.info(f"Spatial-DISE {key} breakdown:")
    for group in sorted(buckets):
        flags = buckets[group]
        accuracy = float(np.mean(flags))
        eval_logger.info(f"  {group}: {accuracy:.4f} ({sum(flags)}/{len(flags)})")
| 183 | + |
| 184 | + |
def _dataset_root() -> str:
    """Resolve the dataset directory: local override first, else HF snapshot.

    Honors the SPATIAL_DISE_ROOT environment variable (with ~ and $VAR
    expansion) when it points at an existing directory; otherwise downloads
    just the benchmark CSV and image tar shards from the Hub.
    """
    override = os.environ.get("SPATIAL_DISE_ROOT")
    if override:
        override = osp.expanduser(osp.expandvars(override))
        if osp.isdir(override):
            return override

    # Restrict the snapshot to the files this task actually reads.
    return snapshot_download(
        repo_id=REPO_ID,
        repo_type="dataset",
        revision="main",
        allow_patterns=["DISE-bench/DISE-benchmark.csv", "image/*.tar"],
    )
| 198 | + |
| 199 | + |
| 200 | +def _csv_path_to_tar_member(path: str) -> str: |
| 201 | + path = str(path).strip() |
| 202 | + if path.startswith("images/"): |
| 203 | + path = path[len("images/") :] |
| 204 | + return path.lstrip("/\\") |
| 205 | + |
| 206 | + |
def _image_refs(doc: Dict[str, Any], tar_index: Dict[str, str], image_mode: str) -> List[Dict[str, str]]:
    """Resolve the doc's image columns into role/path/shard records.

    Blank cells, None, and pandas-style "nan" placeholders are skipped;
    duplicate tar members are emitted only once. Raises FileNotFoundError
    when a referenced member is absent from every shard.
    """
    column_roles = SEPARATE_IMAGE_COLUMNS if image_mode == "separate" else MERGE_IMAGE_COLUMNS
    resolved: List[Dict[str, str]] = []
    used_members = set()

    for column, role in column_roles:
        raw = doc.get(column, "")
        if raw is None:
            continue
        cell = str(raw).strip()
        if not cell or cell.lower() == "nan":
            continue
        member = _csv_path_to_tar_member(cell)
        if member in used_members:
            continue
        shard = tar_index.get(member)
        if shard is None:
            raise FileNotFoundError(f"Spatial-DISE image {column}={cell} not found in tar shards")
        resolved.append({"role": role, "path": member, "shard": shard})
        used_members.add(member)

    return resolved
| 227 | + |
| 228 | + |
def _open_tar_image(shard: str, member: str) -> Image.Image:
    """Read *member* from the tar *shard* and decode it as an RGB image."""
    with tarfile.open(shard) as archive:
        handle = archive.extractfile(member)
        if handle is None:
            raise FileNotFoundError(f"{member} not found in {shard}")
        payload = handle.read()
    # Decode after closing the archive; the bytes are already in memory.
    return Image.open(io.BytesIO(payload)).convert("RGB")
| 236 | + |
| 237 | + |
| 238 | +@lru_cache(maxsize=4) |
| 239 | +def _tar_index(dataset_root: str) -> Dict[str, str]: |
| 240 | + image_dir = osp.join(dataset_root, "image") |
| 241 | + tar_paths = sorted(Path(image_dir).glob("*.tar")) |
| 242 | + if not tar_paths: |
| 243 | + raise FileNotFoundError(f"No Spatial-DISE tar shards found under {image_dir}") |
| 244 | + |
| 245 | + tar_index = {} |
| 246 | + for tar_path in tar_paths: |
| 247 | + with tarfile.open(tar_path) as tf: |
| 248 | + for member in tf.getmembers(): |
| 249 | + if member.isfile(): |
| 250 | + tar_index[member.name] = str(tar_path) |
| 251 | + return tar_index |
| 252 | + |
| 253 | + |
| 254 | +def _strip(value): |
| 255 | + return value.strip() if isinstance(value, str) else value |
0 commit comments