Skip to content

Commit 540724a

Browse files
authored
Feat/ocrbench omnidocbench lang subset (#1250)
* feat(metrics): add Chinese/English subset scores for OCRBench_v2 and OmniDocBench * fix(ocrbench_v2): guard spotting_evaluation against missing hmean key
1 parent 4650095 commit 540724a

5 files changed

Lines changed: 185 additions & 25 deletions

File tree

lmms_eval/tasks/ocrbench_v2/ocrbench_v2.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,5 +20,11 @@ metric_list:
2020
- metric: ocrbench_v2_accuracy
2121
aggregation: !function utils.ocrbench_v2_aggregate_accuracy
2222
higher_is_better: true
23+
- metric: ocrbench_v2_accuracy_en
24+
aggregation: !function utils.ocrbench_v2_aggregate_accuracy_en
25+
higher_is_better: true
26+
- metric: ocrbench_v2_accuracy_cn
27+
aggregation: !function utils.ocrbench_v2_aggregate_accuracy_cn
28+
higher_is_better: true
2329
metadata:
2430
- version: 0.0

lmms_eval/tasks/ocrbench_v2/spotting_metric.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,5 +173,9 @@ def spotting_evaluation(prediction_list, img_metas):
173173
command = {"g": gt_zip_path, "s": submit_zip_path, "o": temp_dir, "p": '{"IOU_CONSTRAINT":0.5}'}
174174

175175
result = rrc_evaluation_funcs.main_evaluation(command, default_evaluation_params, validate_data, evaluate_method)
176-
score = result["method"]["hmean"]
176+
method = result.get("method", {})
177+
if isinstance(method, dict):
178+
score = method.get("hmean", 0)
179+
else:
180+
score = 0
177181
return score

lmms_eval/tasks/ocrbench_v2/utils.py

Lines changed: 39 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -368,16 +368,30 @@ def ocrbench_v2_process_results(doc, results):
368368
else:
369369
score = spotting_evaluation(predict_bbox, doc)
370370

371+
payload = {"question_type": data_type, "score": score, "prediction": pred, "ground_truth": gt_ans}
371372
return {
372-
"ocrbench_v2_accuracy": {"question_type": data_type, "score": score, "prediction": pred, "ground_truth": gt_ans},
373+
"ocrbench_v2_accuracy": payload,
374+
"ocrbench_v2_accuracy_en": payload,
375+
"ocrbench_v2_accuracy_cn": payload,
373376
}
374377

375378

376379
def calculate_average_score(categories, score_buckets):
    """Return the sample-weighted mean score over ``categories``.

    Every score stored under the given categories is pooled, so a category
    contributes in proportion to its number of samples.

    Args:
        categories: Iterable of category names to include. It is consumed
            exactly once, so one-shot iterables (generators, iterators) work.
        score_buckets: Mapping from category name to a sequence of scores.

    Returns:
        The pooled mean, or ``0.0`` when the selected buckets hold no samples.

    Raises:
        KeyError: If a category is missing from ``score_buckets``.
    """
    # Resolve the buckets up front: walking ``categories`` twice would
    # exhaust an iterator on the first pass and report an empty subset.
    buckets = [score_buckets[cat] for cat in categories]
    total_count = sum(len(bucket) for bucket in buckets)
    if total_count == 0:
        return 0.0
    total_score = sum(sum(bucket) for bucket in buckets)
    return total_score / total_count
378386

379387

380-
def ocrbench_v2_aggregate_accuracy(results, args):
388+
# Score-bucket names that are averaged into the English subset score.
ENGLISH_TASKS = [
    "text_recognition_en",
    "text_detection_en",
    "text_spotting_en",
    "relationship_extraction_en",
    "element_parsing_en",
    "mathematical_calculation_en",
    "visual_text_understanding_en",
    "knowledge_reasoning_en",
]

# Score-bucket names that are averaged into the Chinese subset score.
CHINESE_TASKS = [
    "text_recognition_cn",
    "relationship_extraction_cn",
    "element_parsing_cn",
    "visual_text_understanding_cn",
    "knowledge_reasoning_cn",
]
391+
392+
393+
def _fill_score_buckets(results):
394+
"""Shared logic: populate score buckets and per-question-type scores from results."""
381395
question_type_scores = {}
382396
score_buckets = _make_score_buckets()
383397

@@ -435,14 +449,19 @@ def ocrbench_v2_aggregate_accuracy(results, args):
435449
print("No such task!")
436450
raise TypeError
437451

438-
english_tasks = ["text_recognition_en", "text_detection_en", "text_spotting_en", "relationship_extraction_en", "element_parsing_en", "mathematical_calculation_en", "visual_text_understanding_en", "knowledge_reasoning_en"]
452+
return question_type_scores, score_buckets
453+
439454

440-
chinese_tasks = ["text_recognition_cn", "relationship_extraction_cn", "element_parsing_cn", "visual_text_understanding_cn", "knowledge_reasoning_cn"]
455+
def ocrbench_v2_aggregate_accuracy(results, args):
456+
question_type_scores, score_buckets = _fill_score_buckets(results)
441457

442-
OCRBench_v2_English_subset_score = calculate_average_score(english_tasks, score_buckets)
443-
OCRBench_v2_Chinese_subset_score = calculate_average_score(chinese_tasks, score_buckets)
458+
OCRBench_v2_English_subset_score = calculate_average_score(ENGLISH_TASKS, score_buckets)
459+
OCRBench_v2_Chinese_subset_score = calculate_average_score(CHINESE_TASKS, score_buckets)
444460

445-
Final_score = (OCRBench_v2_English_subset_score + OCRBench_v2_Chinese_subset_score) / 2
461+
en_count = sum(len(score_buckets[t]) for t in ENGLISH_TASKS)
462+
cn_count = sum(len(score_buckets[t]) for t in CHINESE_TASKS)
463+
total = en_count + cn_count
464+
Final_score = (OCRBench_v2_English_subset_score * en_count + OCRBench_v2_Chinese_subset_score * cn_count) / total if total > 0 else 0.0
446465
file_name = generate_submission_file("ocrbench_v2_results.txt", args, subpath="results")
447466
with open(file_name, "w") as f:
448467
print("######################### OCRBench v2 ##########################", file=f)
@@ -451,13 +470,13 @@ def ocrbench_v2_aggregate_accuracy(results, args):
451470
avg_score = sum(scores) / len(scores) if len(scores) > 0 else 0
452471
print(f"{q_type} (sample number: {len(scores)}): {avg_score:.2f}", file=f)
453472
print("######################### English Subsets ######################", file=f)
454-
for task in english_tasks:
473+
for task in ENGLISH_TASKS:
455474
num_samples = len(score_buckets[task])
456475
avg_score = sum(score_buckets[task]) / num_samples if num_samples > 0 else 0
457476
print(f"{task.replace('_', ' ').title()} (Total {num_samples}): {avg_score:.2f}", file=f)
458477
print(f"Overall English Score: {OCRBench_v2_English_subset_score:.2f}", file=f)
459478
print("######################### Chinese Subsets ######################", file=f)
460-
for task in chinese_tasks:
479+
for task in CHINESE_TASKS:
461480
num_samples = len(score_buckets[task])
462481
avg_score = sum(score_buckets[task]) / num_samples if num_samples > 0 else 0
463482
print(f"{task.replace('_', ' ').title()} (Total {num_samples}): {avg_score:.2f}", file=f)
@@ -467,3 +486,13 @@ def ocrbench_v2_aggregate_accuracy(results, args):
467486
logger.info(f"OCRBench v2 results saved to {file_name}")
468487

469488
return Final_score # return the final score as accuracy
489+
490+
491+
def ocrbench_v2_aggregate_accuracy_en(results, args):
    """Aggregate OCRBench v2 results into the English-subset score.

    Rebuilds the score buckets from ``results`` and averages only the
    buckets listed in ``ENGLISH_TASKS``. ``args`` is part of the signature
    but is not read here.
    """
    buckets_and_scores = _fill_score_buckets(results)
    _question_type_scores, english_buckets = buckets_and_scores
    return calculate_average_score(ENGLISH_TASKS, english_buckets)
494+
495+
496+
def ocrbench_v2_aggregate_accuracy_cn(results, args):
    """Aggregate OCRBench v2 results into the Chinese-subset score.

    Rebuilds the score buckets from ``results`` and averages only the
    buckets listed in ``CHINESE_TASKS``. ``args`` is part of the signature
    but is not read here.
    """
    buckets_and_scores = _fill_score_buckets(results)
    _question_type_scores, chinese_buckets = buckets_and_scores
    return calculate_average_score(CHINESE_TASKS, chinese_buckets)

lmms_eval/tasks/omnidocbench/omnidocbench.yaml

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,24 @@ metric_list:
2323
- metric: omnidocbench_nld_score
2424
aggregation: mean
2525
higher_is_better: true
26+
- metric: omnidocbench_exact_match_en
27+
aggregation: !function utils.omnidocbench_aggregate_exact_match_en
28+
higher_is_better: true
29+
- metric: omnidocbench_exact_match_cn
30+
aggregation: !function utils.omnidocbench_aggregate_exact_match_cn
31+
higher_is_better: true
32+
- metric: omnidocbench_nld_score_en
33+
aggregation: !function utils.omnidocbench_aggregate_nld_score_en
34+
higher_is_better: true
35+
- metric: omnidocbench_nld_score_cn
36+
aggregation: !function utils.omnidocbench_aggregate_nld_score_cn
37+
higher_is_better: true
38+
- metric: omnidocbench_exact_match_mixed
39+
aggregation: !function utils.omnidocbench_aggregate_exact_match_mixed
40+
higher_is_better: true
41+
- metric: omnidocbench_nld_score_mixed
42+
aggregation: !function utils.omnidocbench_aggregate_nld_score_mixed
43+
higher_is_better: true
2644
lmms_eval_specific_kwargs:
2745
default:
2846
pre_prompt: ""

lmms_eval/tasks/omnidocbench/utils.py

Lines changed: 117 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import io
2+
import json
23
import re
4+
from collections import Counter
35
from typing import Any
46

57
import Levenshtein
@@ -57,6 +59,60 @@ def _to_rgb(image_obj: Any):
5759
return None
5860

5961

62+
def _detect_document_language(doc) -> str:
63+
"""Detect whether a document is 'en', 'cn', or 'mixed' from its answer JSON.
64+
65+
Inspects ``page_info.page_attribute.language`` first (most reliable).
66+
Falls back to majority vote over ``layout_dets[*].attribute.text_language``
67+
and ``layout_dets[*].attribute.language`` (for tables).
68+
"""
69+
answer_raw = doc.get("answer") or doc.get("answers") or doc.get("target")
70+
if not answer_raw:
71+
return "mixed"
72+
73+
if isinstance(answer_raw, list):
74+
answer_raw = answer_raw[0]
75+
if not isinstance(answer_raw, str):
76+
return "mixed"
77+
78+
try:
79+
answer = json.loads(answer_raw)
80+
except (json.JSONDecodeError, TypeError):
81+
return "mixed"
82+
83+
# Try page-level language first
84+
page_lang = None
85+
page_info = answer.get("page_info") or {}
86+
page_attr = page_info.get("page_attribute") or {}
87+
page_lang_raw = page_attr.get("language", "")
88+
if "chinese" in page_lang_raw:
89+
page_lang = "cn"
90+
elif "english" in page_lang_raw:
91+
page_lang = "en"
92+
93+
if page_lang:
94+
return page_lang
95+
96+
# Fall back to element-level majority vote
97+
lang_counts = Counter()
98+
for det in answer.get("layout_dets", []):
99+
attr = det.get("attribute") or {}
100+
# Text elements use "text_language", tables use "language"
101+
lang_val = attr.get("text_language", "") or attr.get("language", "")
102+
if "chinese" in lang_val:
103+
lang_counts["cn"] += 1
104+
elif "english" in lang_val or lang_val.endswith("_en"):
105+
lang_counts["en"] += 1
106+
elif "mixed" in lang_val:
107+
lang_counts["mixed"] += 1
108+
109+
if not lang_counts:
110+
return "mixed"
111+
112+
top = lang_counts.most_common(1)[0][0]
113+
return top
114+
115+
60116
def omnidocbench_doc_to_visual(doc):
61117
visuals = []
62118

@@ -114,22 +170,69 @@ def _normalized_levenshtein_score(pred: str, ref: str) -> float:
114170
def omnidocbench_process_results(doc, results):
    """Score one prediction and build the per-metric entries for it.

    The overall metrics receive plain floats. Each language-specific metric
    receives a ``{"score", "lang"}`` payload; the same payload is attached to
    the en/cn/mixed variants so the aggregation functions can filter on
    ``lang``.

    Args:
        doc: The dataset record; answers, options and language are read from it.
        results: Model outputs; only ``results[0]`` is used.

    Returns:
        Dict keyed by metric name. Both scores are ``0.0`` when the document
        has no answers.
    """
    raw_prediction = results[0]
    prediction = _normalize_text(raw_prediction)
    answers = _extract_answers(doc)
    lang = _detect_document_language(doc)

    em_score = 0.0
    nld_score = 0.0
    if answers:
        # Exact match against the normalized reference answers.
        normalized_answers = {_normalize_text(answer) for answer in answers}
        em_score = float(prediction in normalized_answers)

        # When the doc has options, a predicted option letter equal to the
        # first character of any answer also counts as an exact match.
        if _extract_options(doc):
            pred_letter = _extract_option_letter(str(raw_prediction))
            if pred_letter:
                answer_letters = [answer.strip().upper()[:1] for answer in answers]
                if pred_letter in answer_letters:
                    em_score = max(em_score, 1.0)

        # Normalized Levenshtein score: keep the best one across answers.
        nld_score = max(_normalized_levenshtein_score(prediction, _normalize_text(answer)) for answer in answers)

    lang_payload_em = {"score": em_score, "lang": lang}
    lang_payload_nld = {"score": nld_score, "lang": lang}

    metrics = {
        "omnidocbench_exact_match": em_score,
        "omnidocbench_nld_score": nld_score,
    }
    for suffix in ("en", "cn", "mixed"):
        metrics[f"omnidocbench_exact_match_{suffix}"] = lang_payload_em
    for suffix in ("en", "cn", "mixed"):
        metrics[f"omnidocbench_nld_score_{suffix}"] = lang_payload_nld
    return metrics
207+
208+
209+
def _aggregate_by_lang(results, target_langs):
210+
"""Filter results by language and compute mean score."""
211+
filtered = [r["score"] for r in results if r["lang"] in target_langs]
212+
if not filtered:
213+
return 0.0
214+
return sum(filtered) / len(filtered)
215+
216+
217+
def omnidocbench_aggregate_exact_match_en(results, args):
    """Mean exact-match score over results tagged ``"en"``; ``args`` is unused."""
    return _aggregate_by_lang(results, target_langs={"en"})
219+
220+
221+
def omnidocbench_aggregate_exact_match_cn(results, args):
    """Mean exact-match score over results tagged ``"cn"``; ``args`` is unused."""
    return _aggregate_by_lang(results, target_langs={"cn"})
223+
224+
225+
def omnidocbench_aggregate_exact_match_mixed(results, args):
    """Mean exact-match score over results tagged ``"mixed"``; ``args`` is unused."""
    return _aggregate_by_lang(results, target_langs={"mixed"})
227+
228+
229+
def omnidocbench_aggregate_nld_score_en(results, args):
    """Mean NLD score over results tagged ``"en"``; ``args`` is unused."""
    return _aggregate_by_lang(results, target_langs={"en"})
231+
232+
233+
def omnidocbench_aggregate_nld_score_cn(results, args):
    """Mean NLD score over results tagged ``"cn"``; ``args`` is unused."""
    return _aggregate_by_lang(results, target_langs={"cn"})
131235

132-
# Normalized Levenshtein score: (1 - NLD) * 100, take best across answers
133-
nld_score = max(_normalized_levenshtein_score(prediction, _normalize_text(answer)) for answer in answers)
134236

135-
return {"omnidocbench_exact_match": em_score, "omnidocbench_nld_score": nld_score}
237+
def omnidocbench_aggregate_nld_score_mixed(results, args):
    """Mean NLD score over results tagged ``"mixed"``; ``args`` is unused."""
    return _aggregate_by_lang(results, target_langs={"mixed"})

0 commit comments

Comments
 (0)