Skip to content

Commit 1516dba

Browse files
authored
feat: add SiteBench image/video result merge script (#1076)
Add a script to combine site_bench_image and site_bench_video results into unified metrics. Supports auto-detection from logs directory and explicit JSONL paths. Can be invoked as: python -m lmms_eval.tasks.sitebench.merge_results --logs-dir logs/MODEL/
1 parent 5cdb841 commit 1516dba

1 file changed

Lines changed: 391 additions & 0 deletions

File tree

Lines changed: 391 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,391 @@
1+
#!/usr/bin/env python3
2+
"""
3+
Merge SiteBench Image and Video results from lmms-eval output.
4+
5+
This script combines site_bench_image and site_bench_video results to compute
6+
overall metrics matching VLMEvalKit's methodology.
7+
8+
Usage:
9+
python -m lmms_eval.tasks.sitebench.merge_results --logs-dir logs/MODEL_NAME/
10+
python -m lmms_eval.tasks.sitebench.merge_results --image-jsonl path/to/image.jsonl --video-jsonl path/to/video.jsonl
11+
"""
12+
13+
import argparse
14+
import glob
15+
import json
16+
import os
17+
import re
18+
from collections import defaultdict
19+
20+
import pandas as pd
21+
22+
23+
def _empty_stats():
24+
return {
25+
"caa_num": 0.0,
26+
"caa_den": 0.0,
27+
"acc_num": 0.0,
28+
"acc_den": 0.0,
29+
}
30+
31+
32+
def _count_options_from_input(text: str) -> int | None:
33+
"""Count number of options from the input text."""
34+
if not text:
35+
return None
36+
lines = [line.strip() for line in text.splitlines()]
37+
try:
38+
start_idx = next(i for i, line in enumerate(lines) if line.lower().startswith("options"))
39+
except StopIteration:
40+
start_idx = None
41+
if start_idx is None:
42+
return None
43+
count = 0
44+
for line in lines[start_idx + 1 :]:
45+
if not line:
46+
break
47+
lower = line.lower()
48+
if lower.startswith("give me") or "best answer" in lower:
49+
break
50+
if re.match(r"^[A-Z]:", line):
51+
count += 1
52+
else:
53+
if count > 0:
54+
break
55+
return count if count > 0 else None
56+
57+
58+
def _count_options_from_doc(doc: dict) -> int | None:
59+
"""Count number of options from the doc dict."""
60+
if not isinstance(doc, dict):
61+
return None
62+
for key in ("choices", "options", "answer_choices"):
63+
value = doc.get(key)
64+
if isinstance(value, list) and len(value) > 0:
65+
return len(value)
66+
return None
67+
68+
69+
def compute_random_expected_acc(jsonl_path: str) -> tuple[float, int, int]:
70+
"""
71+
Compute the random expected accuracy (1/num_options average).
72+
73+
Returns:
74+
tuple of (avg_random_acc, total_counted, missing_count)
75+
"""
76+
total = 0
77+
sum_expect = 0.0
78+
missing = 0
79+
80+
with open(jsonl_path, "r", encoding="utf-8") as f:
81+
for line in f:
82+
if not line.strip():
83+
continue
84+
item = json.loads(line)
85+
n_opt = _count_options_from_input(item.get("input"))
86+
if n_opt is None:
87+
n_opt = _count_options_from_doc(item.get("doc"))
88+
if n_opt is None or n_opt <= 0:
89+
missing += 1
90+
continue
91+
sum_expect += 1.0 / n_opt
92+
total += 1
93+
94+
avg = sum_expect / total if total > 0 else 0.0
95+
return avg, total, missing
96+
97+
98+
def compute_stats_from_jsonl(jsonl_path: str) -> dict:
99+
"""
100+
Compute aggregated statistics from a samples JSONL file.
101+
102+
Returns:
103+
dict with keys: metric_stats, category_stats, overall
104+
"""
105+
metric_stats = defaultdict(_empty_stats)
106+
category_stats = defaultdict(_empty_stats)
107+
108+
with open(jsonl_path, "r", encoding="utf-8") as f:
109+
for line in f:
110+
if not line.strip():
111+
continue
112+
item = json.loads(line)
113+
114+
# Get accuracy and chance_adjusted_acc dicts
115+
acc = item.get("accuracy", {})
116+
caa = item.get("chance_adjusted_acc", {})
117+
118+
acc_total = acc.get("total", 0.0)
119+
caa_total = caa.get("total", 0.0)
120+
121+
# Update metric stats (by category/dataset keys)
122+
for key, value in acc.items():
123+
if key == "total":
124+
continue
125+
metric_stats[key]["acc_num"] += value
126+
metric_stats[key]["acc_den"] += acc_total
127+
128+
for key, value in caa.items():
129+
if key == "total":
130+
continue
131+
metric_stats[key]["caa_num"] += value
132+
metric_stats[key]["caa_den"] += caa_total
133+
134+
# Extract category from doc if available
135+
doc = item.get("doc")
136+
if isinstance(doc, dict):
137+
category = doc.get("category")
138+
if category:
139+
category_stats[category]["acc_num"] += acc.get("overall", 0.0)
140+
category_stats[category]["acc_den"] += acc_total
141+
category_stats[category]["caa_num"] += caa.get("overall", 0.0)
142+
category_stats[category]["caa_den"] += caa_total
143+
144+
# Compute overall from "overall" key in metric_stats
145+
overall = None
146+
if "overall" in metric_stats:
147+
overall = metric_stats["overall"]
148+
149+
return {
150+
"metric_stats": dict(metric_stats),
151+
"category_stats": dict(category_stats),
152+
"overall": overall,
153+
}
154+
155+
156+
def stats_to_df(stats: dict, label_col: str) -> pd.DataFrame:
157+
"""Convert stats dict to a pandas DataFrame."""
158+
rows = []
159+
for key, val in stats.items():
160+
caa = val["caa_num"] / val["caa_den"] if val["caa_den"] > 0 else 0.0
161+
acc = val["acc_num"] / val["acc_den"] if val["acc_den"] > 0 else 0.0
162+
count = val["acc_den"] if val["acc_den"] > 0 else val["caa_den"]
163+
rows.append((key, caa * 100, acc * 100, int(count)))
164+
165+
df = pd.DataFrame(rows, columns=[label_col, "CAA (%)", "Accuracy (%)", "Count"])
166+
df = df.sort_values(by="CAA (%)", ascending=False, ignore_index=True)
167+
return df
168+
169+
170+
def merge_stats(stats1: dict, stats2: dict) -> dict:
171+
"""Merge two stats dictionaries."""
172+
merged = defaultdict(_empty_stats)
173+
174+
for key, val in stats1.items():
175+
merged[key]["acc_num"] += val["acc_num"]
176+
merged[key]["acc_den"] += val["acc_den"]
177+
merged[key]["caa_num"] += val["caa_num"]
178+
merged[key]["caa_den"] += val["caa_den"]
179+
180+
for key, val in stats2.items():
181+
merged[key]["acc_num"] += val["acc_num"]
182+
merged[key]["acc_den"] += val["acc_den"]
183+
merged[key]["caa_num"] += val["caa_num"]
184+
merged[key]["caa_den"] += val["caa_den"]
185+
186+
return dict(merged)
187+
188+
189+
def find_latest_sitebench_files(logs_dir: str) -> tuple[str | None, str | None]:
190+
"""
191+
Find the latest site_bench_image and site_bench_video JSONL files.
192+
193+
Returns:
194+
tuple of (image_jsonl_path, video_jsonl_path)
195+
"""
196+
# Find all site_bench_image JSONL files
197+
image_files = glob.glob(os.path.join(logs_dir, "*samples_site_bench_image.jsonl"))
198+
# Find all site_bench_video JSONL files (including 32frame_multiimage variants)
199+
video_files = glob.glob(os.path.join(logs_dir, "*samples_site_bench_video*.jsonl"))
200+
201+
# Sort by filename (timestamp) descending to get latest
202+
image_files.sort(reverse=True)
203+
video_files.sort(reverse=True)
204+
205+
image_path = image_files[0] if image_files else None
206+
video_path = video_files[0] if video_files else None
207+
208+
return image_path, video_path
209+
210+
211+
def print_results(name: str, stats: dict, category_stats: dict = None, random_acc: float = None):
212+
"""Print formatted results."""
213+
print(f"\n{'='*60}")
214+
print(f"{name}")
215+
print("=" * 60)
216+
217+
if stats.get("overall"):
218+
overall = stats["overall"]
219+
acc = overall["acc_num"] / overall["acc_den"] if overall["acc_den"] > 0 else 0.0
220+
caa = overall["caa_num"] / overall["caa_den"] if overall["caa_den"] > 0 else 0.0
221+
count = int(overall["acc_den"])
222+
print(f"Overall: Accuracy={acc*100:.2f}%, CAA={caa*100:.2f}%, Count={count}")
223+
if random_acc is not None:
224+
print(f"Random Expected Accuracy: {random_acc*100:.2f}%")
225+
226+
if category_stats:
227+
cat_df = stats_to_df(category_stats, "Category")
228+
print("\nCategory Breakdown:")
229+
print(cat_df.to_string(index=False))
230+
231+
232+
def main():
233+
parser = argparse.ArgumentParser(description="Merge SiteBench Image and Video results from lmms-eval output.")
234+
parser.add_argument(
235+
"--logs-dir",
236+
type=str,
237+
help="Path to the model's logs directory (e.g., logs/MODEL_NAME/). " "Will auto-detect the latest site_bench_image and site_bench_video files.",
238+
)
239+
parser.add_argument(
240+
"--image-jsonl",
241+
type=str,
242+
help="Path to site_bench_image samples JSONL file.",
243+
)
244+
parser.add_argument(
245+
"--video-jsonl",
246+
type=str,
247+
help="Path to site_bench_video samples JSONL file.",
248+
)
249+
parser.add_argument(
250+
"--output",
251+
type=str,
252+
help="Optional output JSON file to save combined results.",
253+
)
254+
255+
args = parser.parse_args()
256+
257+
image_stats = None
258+
video_stats = None
259+
image_path = None
260+
video_path = None
261+
image_random_acc = None
262+
video_random_acc = None
263+
264+
# Auto-detect files from logs directory
265+
if args.logs_dir:
266+
image_path, video_path = find_latest_sitebench_files(args.logs_dir)
267+
268+
if image_path:
269+
print(f"Found image JSONL: {image_path}")
270+
image_stats = compute_stats_from_jsonl(image_path)
271+
image_random_acc, _, _ = compute_random_expected_acc(image_path)
272+
else:
273+
print("Warning: No site_bench_image JSONL found")
274+
275+
if video_path:
276+
print(f"Found video JSONL: {video_path}")
277+
video_stats = compute_stats_from_jsonl(video_path)
278+
video_random_acc, _, _ = compute_random_expected_acc(video_path)
279+
else:
280+
print("Warning: No site_bench_video JSONL found")
281+
282+
# Use explicit file paths if provided (override auto-detected)
283+
if args.image_jsonl:
284+
image_path = args.image_jsonl
285+
print(f"Using image JSONL: {image_path}")
286+
image_stats = compute_stats_from_jsonl(image_path)
287+
image_random_acc, _, _ = compute_random_expected_acc(image_path)
288+
289+
if args.video_jsonl:
290+
video_path = args.video_jsonl
291+
print(f"Using video JSONL: {video_path}")
292+
video_stats = compute_stats_from_jsonl(video_path)
293+
video_random_acc, _, _ = compute_random_expected_acc(video_path)
294+
295+
# Print individual results
296+
if image_stats:
297+
print_results(
298+
"SiteBench Image",
299+
image_stats,
300+
image_stats.get("category_stats"),
301+
image_random_acc,
302+
)
303+
304+
if video_stats:
305+
print_results(
306+
"SiteBench Video",
307+
video_stats,
308+
video_stats.get("category_stats"),
309+
video_random_acc,
310+
)
311+
312+
# Compute and print combined results
313+
if image_stats and video_stats:
314+
combined_metric = merge_stats(
315+
image_stats.get("metric_stats", {}),
316+
video_stats.get("metric_stats", {}),
317+
)
318+
combined_category = merge_stats(
319+
image_stats.get("category_stats", {}),
320+
video_stats.get("category_stats", {}),
321+
)
322+
323+
# Compute combined overall
324+
img_overall = image_stats.get("overall", _empty_stats())
325+
vid_overall = video_stats.get("overall", _empty_stats())
326+
combined_overall = {
327+
"acc_num": img_overall["acc_num"] + vid_overall["acc_num"],
328+
"acc_den": img_overall["acc_den"] + vid_overall["acc_den"],
329+
"caa_num": img_overall["caa_num"] + vid_overall["caa_num"],
330+
"caa_den": img_overall["caa_den"] + vid_overall["caa_den"],
331+
}
332+
333+
combined_stats = {
334+
"metric_stats": combined_metric,
335+
"category_stats": combined_category,
336+
"overall": combined_overall,
337+
}
338+
339+
# Compute combined random expected accuracy (weighted average)
340+
combined_random_acc = None
341+
if image_random_acc is not None and video_random_acc is not None:
342+
img_count = image_stats["overall"]["acc_den"]
343+
vid_count = video_stats["overall"]["acc_den"]
344+
total_count = img_count + vid_count
345+
if total_count > 0:
346+
combined_random_acc = (image_random_acc * img_count + video_random_acc * vid_count) / total_count
347+
348+
print_results(
349+
"SiteBench Combined (Image + Video)",
350+
combined_stats,
351+
combined_category,
352+
combined_random_acc,
353+
)
354+
355+
# Save to output file if requested
356+
if args.output:
357+
output_data = {
358+
"image": {
359+
"file": image_path,
360+
"accuracy": (image_stats["overall"]["acc_num"] / image_stats["overall"]["acc_den"] * 100 if image_stats["overall"]["acc_den"] > 0 else 0),
361+
"caa": (image_stats["overall"]["caa_num"] / image_stats["overall"]["caa_den"] * 100 if image_stats["overall"]["caa_den"] > 0 else 0),
362+
"count": int(image_stats["overall"]["acc_den"]),
363+
},
364+
"video": {
365+
"file": video_path,
366+
"accuracy": (video_stats["overall"]["acc_num"] / video_stats["overall"]["acc_den"] * 100 if video_stats["overall"]["acc_den"] > 0 else 0),
367+
"caa": (video_stats["overall"]["caa_num"] / video_stats["overall"]["caa_den"] * 100 if video_stats["overall"]["caa_den"] > 0 else 0),
368+
"count": int(video_stats["overall"]["acc_den"]),
369+
},
370+
"combined": {
371+
"accuracy": (combined_overall["acc_num"] / combined_overall["acc_den"] * 100 if combined_overall["acc_den"] > 0 else 0),
372+
"caa": (combined_overall["caa_num"] / combined_overall["caa_den"] * 100 if combined_overall["caa_den"] > 0 else 0),
373+
"count": int(combined_overall["acc_den"]),
374+
},
375+
}
376+
with open(args.output, "w") as f:
377+
json.dump(output_data, f, indent=2)
378+
print(f"\nResults saved to: {args.output}")
379+
380+
elif not image_stats and not video_stats:
381+
print("\nError: No SiteBench results found!")
382+
print("Please provide:")
383+
print(" --logs-dir path/to/model/logs/")
384+
print(" OR --image-jsonl and --video-jsonl paths")
385+
return 1
386+
387+
return 0
388+
389+
390+
if __name__ == "__main__":
391+
exit(main())

0 commit comments

Comments
 (0)