|
1 | 1 | import io |
| 2 | +import json |
2 | 3 | import re |
| 4 | +from collections import Counter |
3 | 5 | from typing import Any |
4 | 6 |
|
5 | 7 | import Levenshtein |
@@ -57,6 +59,60 @@ def _to_rgb(image_obj: Any): |
57 | 59 | return None |
58 | 60 |
|
59 | 61 |
|
| 62 | +def _detect_document_language(doc) -> str: |
| 63 | + """Detect whether a document is 'en', 'cn', or 'mixed' from its answer JSON. |
| 64 | +
|
| 65 | + Inspects ``page_info.page_attribute.language`` first (most reliable). |
| 66 | + Falls back to majority vote over ``layout_dets[*].attribute.text_language`` |
| 67 | + and ``layout_dets[*].attribute.language`` (for tables). |
| 68 | + """ |
| 69 | + answer_raw = doc.get("answer") or doc.get("answers") or doc.get("target") |
| 70 | + if not answer_raw: |
| 71 | + return "mixed" |
| 72 | + |
| 73 | + if isinstance(answer_raw, list): |
| 74 | + answer_raw = answer_raw[0] |
| 75 | + if not isinstance(answer_raw, str): |
| 76 | + return "mixed" |
| 77 | + |
| 78 | + try: |
| 79 | + answer = json.loads(answer_raw) |
| 80 | + except (json.JSONDecodeError, TypeError): |
| 81 | + return "mixed" |
| 82 | + |
| 83 | + # Try page-level language first |
| 84 | + page_lang = None |
| 85 | + page_info = answer.get("page_info") or {} |
| 86 | + page_attr = page_info.get("page_attribute") or {} |
| 87 | + page_lang_raw = page_attr.get("language", "") |
| 88 | + if "chinese" in page_lang_raw: |
| 89 | + page_lang = "cn" |
| 90 | + elif "english" in page_lang_raw: |
| 91 | + page_lang = "en" |
| 92 | + |
| 93 | + if page_lang: |
| 94 | + return page_lang |
| 95 | + |
| 96 | + # Fall back to element-level majority vote |
| 97 | + lang_counts = Counter() |
| 98 | + for det in answer.get("layout_dets", []): |
| 99 | + attr = det.get("attribute") or {} |
| 100 | + # Text elements use "text_language", tables use "language" |
| 101 | + lang_val = attr.get("text_language", "") or attr.get("language", "") |
| 102 | + if "chinese" in lang_val: |
| 103 | + lang_counts["cn"] += 1 |
| 104 | + elif "english" in lang_val or lang_val.endswith("_en"): |
| 105 | + lang_counts["en"] += 1 |
| 106 | + elif "mixed" in lang_val: |
| 107 | + lang_counts["mixed"] += 1 |
| 108 | + |
| 109 | + if not lang_counts: |
| 110 | + return "mixed" |
| 111 | + |
| 112 | + top = lang_counts.most_common(1)[0][0] |
| 113 | + return top |
| 114 | + |
| 115 | + |
60 | 116 | def omnidocbench_doc_to_visual(doc): |
61 | 117 | visuals = [] |
62 | 118 |
|
@@ -114,22 +170,69 @@ def _normalized_levenshtein_score(pred: str, ref: str) -> float: |
114 | 170 | def omnidocbench_process_results(doc, results): |
115 | 171 | prediction = _normalize_text(results[0]) |
116 | 172 | answers = _extract_answers(doc) |
| 173 | + lang = _detect_document_language(doc) |
| 174 | + |
117 | 175 | if not answers: |
118 | | - return {"omnidocbench_exact_match": 0.0, "omnidocbench_nld_score": 0.0} |
| 176 | + em_score = 0.0 |
| 177 | + nld_score = 0.0 |
| 178 | + else: |
| 179 | + # Exact match |
| 180 | + answer_set = {_normalize_text(answer) for answer in answers} |
| 181 | + em_score = float(prediction in answer_set) |
119 | 182 |
|
120 | | - # Exact match |
121 | | - answer_set = {_normalize_text(answer) for answer in answers} |
122 | | - em_score = float(prediction in answer_set) |
| 183 | + options = _extract_options(doc) |
| 184 | + if options: |
| 185 | + pred_letter = _extract_option_letter(str(results[0])) |
| 186 | + if pred_letter: |
| 187 | + for answer in answers: |
| 188 | + if pred_letter == answer.strip().upper()[:1]: |
| 189 | + em_score = max(em_score, 1.0) |
123 | 190 |
|
124 | | - options = _extract_options(doc) |
125 | | - if options: |
126 | | - pred_letter = _extract_option_letter(str(results[0])) |
127 | | - if pred_letter: |
128 | | - for answer in answers: |
129 | | - if pred_letter == answer.strip().upper()[:1]: |
130 | | - em_score = max(em_score, 1.0) |
| 191 | + # Normalized Levenshtein score: (1 - NLD) * 100, take best across answers |
| 192 | + nld_score = max(_normalized_levenshtein_score(prediction, _normalize_text(answer)) for answer in answers) |
| 193 | + |
| 194 | + lang_payload_em = {"score": em_score, "lang": lang} |
| 195 | + lang_payload_nld = {"score": nld_score, "lang": lang} |
| 196 | + |
| 197 | + return { |
| 198 | + "omnidocbench_exact_match": em_score, |
| 199 | + "omnidocbench_nld_score": nld_score, |
| 200 | + "omnidocbench_exact_match_en": lang_payload_em, |
| 201 | + "omnidocbench_exact_match_cn": lang_payload_em, |
| 202 | + "omnidocbench_exact_match_mixed": lang_payload_em, |
| 203 | + "omnidocbench_nld_score_en": lang_payload_nld, |
| 204 | + "omnidocbench_nld_score_cn": lang_payload_nld, |
| 205 | + "omnidocbench_nld_score_mixed": lang_payload_nld, |
| 206 | + } |
| 207 | + |
| 208 | + |
| 209 | +def _aggregate_by_lang(results, target_langs): |
| 210 | + """Filter results by language and compute mean score.""" |
| 211 | + filtered = [r["score"] for r in results if r["lang"] in target_langs] |
| 212 | + if not filtered: |
| 213 | + return 0.0 |
| 214 | + return sum(filtered) / len(filtered) |
| 215 | + |
| 216 | + |
| 217 | +def omnidocbench_aggregate_exact_match_en(results, args): |
| 218 | + return _aggregate_by_lang(results, {"en"}) |
| 219 | + |
| 220 | + |
| 221 | +def omnidocbench_aggregate_exact_match_cn(results, args): |
| 222 | + return _aggregate_by_lang(results, {"cn"}) |
| 223 | + |
| 224 | + |
| 225 | +def omnidocbench_aggregate_exact_match_mixed(results, args): |
| 226 | + return _aggregate_by_lang(results, {"mixed"}) |
| 227 | + |
| 228 | + |
| 229 | +def omnidocbench_aggregate_nld_score_en(results, args): |
| 230 | + return _aggregate_by_lang(results, {"en"}) |
| 231 | + |
| 232 | + |
| 233 | +def omnidocbench_aggregate_nld_score_cn(results, args): |
| 234 | + return _aggregate_by_lang(results, {"cn"}) |
131 | 235 |
|
132 | | - # Normalized Levenshtein score: (1 - NLD) * 100, take best across answers |
133 | | - nld_score = max(_normalized_levenshtein_score(prediction, _normalize_text(answer)) for answer in answers) |
134 | 236 |
|
135 | | - return {"omnidocbench_exact_match": em_score, "omnidocbench_nld_score": nld_score} |
| 237 | +def omnidocbench_aggregate_nld_score_mixed(results, args): |
| 238 | + return _aggregate_by_lang(results, {"mixed"}) |
0 commit comments