Skip to content

Commit 3cdc2bb

Browse files
feat(text-metrics): split text_score pair into dedicated branch
Adds text_score and oneig_text_score metrics together with shared OCR text utilities and benchmark wiring for Long Text Bench and OneIG Text Rendering. Made-with: Cursor
1 parent 2627d78 commit 3cdc2bb

3 files changed

Lines changed: 654 additions & 1 deletion

File tree

src/pruna/evaluation/benchmarks.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,7 @@ def list(cls, task_type: str | None = None) -> list[str]:
256256
"Text-to-image benchmark for long, detailed prompts. Evaluates model ability to "
257257
"handle complex multi-clause descriptions and maintain coherence across long instructions."
258258
),
259-
metrics=[], # Paper uses text_score/TIT-Score; not in Pruna
259+
metrics=["text_score"],
260260
task_type="text_to_image",
261261
reference="https://arxiv.org/abs/2507.22058",
262262
),
@@ -299,6 +299,13 @@ def list(cls, task_type: str | None = None) -> list[str]:
299299
task_type="text_to_image",
300300
reference="https://arxiv.org/abs/2506.07977",
301301
),
302+
Benchmark(
303+
name="OneIG Text Rendering",
304+
description="OneIG subset: text and graphics painted into the image.",
305+
metrics=["oneig_text_score"],
306+
task_type="text_to_image",
307+
reference="https://arxiv.org/abs/2506.07977",
308+
),
302309
Benchmark(
303310
name="DPG",
304311
description=(
Lines changed: 372 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,372 @@
1+
# Copyright 2025 - Pruna AI GmbH. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Text rendering via OCR: mean Levenshtein (``text_score`` / ``ocr_levenshtein``).
16+
17+
OneIG composite: ``oneig_text_score`` / ``ocr_text_score``.
18+
"""
19+
20+
from __future__ import annotations
21+
22+
from abc import abstractmethod
23+
from typing import Any, Literal
24+
25+
import numpy as np
26+
import torch
27+
28+
from pruna.engine.utils import set_to_best_available_device
29+
from pruna.evaluation.metrics.metric_stateful import StatefulMetric
30+
from pruna.evaluation.metrics.metric_text_score_utils import (
31+
levenshtein,
32+
normalize_text_simple,
33+
oneig_mean_text_score,
34+
oneig_per_sample_contributions,
35+
)
36+
from pruna.evaluation.metrics.registry import MetricRegistry
37+
from pruna.evaluation.metrics.result import MetricResult
38+
from pruna.evaluation.metrics.utils import (
39+
SINGLE,
40+
get_call_type_for_single_metric,
41+
metric_data_processor,
42+
)
43+
from pruna.evaluation.metrics.vlm_base import BaseVLM, get_vlm
44+
from pruna.evaluation.metrics.vlm_utils import TextOutput, _process_images, get_text_from_response
45+
46+
# Instruction passed to the VLM alongside each rendered image in ``update``.
# It asks for a verbatim transcription (no preamble, no markdown), explicitly forbids
# correcting misspellings, and defines the literal reply "No text recognized" for
# images in which no text is found.
OCR_PROMPT = (
    "Extract all text visible in this image. Include logos, stylized fonts, handwritten text, "
    "and non-standard typography. Return only the extracted text, exactly as it appears—no preamble, "
    "explanation, or markdown. Preserve words, numbers, punctuation, and spacing. "
    "IMPORTANT: Do NOT correct spelling errors or typos. If a word is misspelled in the image "
    "(e.g. 'Teclhology' instead of 'Technology'), reproduce it exactly as it appears, including the misspelling. "
    "If no text is recognized, reply with exactly: No text recognized"
)
54+
55+
56+
class _BaseVLMOCRTextMetric(StatefulMetric):
    """
    Shared VLM OCR over rendered images with ground truth in ``text_content``.

    The base class owns the VLM, runs OCR on every image passed to :meth:`update`, and
    resolves the matching ground-truth string. Subclasses implement how each
    (ground truth, OCR) pair is scored and how the scores are aggregated.

    Parameters
    ----------
    *args : Any
        Additional positional arguments (unused; registry compatibility).
    vlm : BaseVLM | None, optional
        Custom VLM instance. If provided, ``vlm_type`` and ``model_name`` are ignored.
    vlm_type : {'litellm', 'transformers'}, optional
        VLM backend. Default is ``'litellm'``.
    model_name : str | None, optional
        Litellm model id or HuggingFace checkpoint id. **Required** when ``vlm`` is not
        provided (e.g. ``openai/gpt-4o``).
    vlm_kwargs : dict, optional
        Forwarded by ``get_vlm`` to ``LitellmVLM`` or ``TransformersVLM``. For local models,
        set ``model_load_kwargs`` for ``from_pretrained``; for litellm, pass extra API options.
    structured_output : bool, optional
        Use structured generation (litellm pydantic; transformers outlines when applicable).
        Default is True.
    device : str | torch.device | None, optional
        Device for transformers VLM.
    api_key : str | None, optional
        API key for litellm.
    call_type : str, optional
        Call type for the metric.
    **kwargs : Any
        Additional keyword arguments (unused; accepted for registry compatibility).

    Examples
    --------
    OCR metrics call ``get_vlm`` directly (not ``StatefulVLMMeanScoresMetric``). Same
    ``hosted`` / ``local`` pattern as :func:`~pruna.evaluation.metrics.vlm_base.get_vlm`:

    .. code-block:: python

        import torch

        from pruna.evaluation.metrics import TextScoreMetric

        hosted = TextScoreMetric(vlm_type="litellm", model_name="openai/gpt-4o")
        local = TextScoreMetric(
            vlm_type="transformers",
            model_name="HuggingFaceTB/SmolVLM-256M-Instruct",
            device="cpu",
            vlm_kwargs={"model_load_kwargs": {"torch_dtype": torch.float32}},
        )

    Use ``OneIGTextScoreMetric`` the same way for ``oneig_text_score`` / ``ocr_text_score``.
    """

    default_call_type: str = "y_gt"

    def __init__(
        self,
        *args: Any,
        vlm: BaseVLM | None = None,
        vlm_type: Literal["litellm", "transformers"] = "litellm",
        model_name: str | None = None,
        vlm_kwargs: dict | None = None,
        structured_output: bool = True,
        device: str | torch.device | None = None,
        api_key: str | None = None,
        call_type: str = SINGLE,
        **kwargs: Any,
    ) -> None:
        super().__init__(device=device)
        self.device = set_to_best_available_device(device)

        # NOTE(review): ``get_vlm`` receives the caller's raw ``device`` rather than the
        # resolved ``self.device`` -- confirm this is intended.
        self.vlm = get_vlm(
            vlm=vlm,
            vlm_type=vlm_type,
            model_name=model_name,
            device=device,
            api_key=api_key,
            structured_output=structured_output,
            **(vlm_kwargs or {}),
        )
        # Structured generation requests a ``TextOutput`` response; otherwise free-form text.
        self.response_format = TextOutput if structured_output else None

        self.call_type = get_call_type_for_single_metric(call_type, self.default_call_type)
        # Mirror the subclass-level flag onto the instance so it is part of ``__dict__``,
        # which ``compute`` hands to ``MetricResult``.
        self.higher_is_better = type(self).higher_is_better

    @abstractmethod
    def _accumulate_sample(self, text_gt: str, ocr_text: str) -> None:
        """Update metric state from one ground-truth / OCR pair."""

    @abstractmethod
    def _compute_result_value(self) -> float:
        """Return the scalar reported as ``MetricResult.result``."""

    @staticmethod
    def _extract_ground_truth(aux: Any) -> Any:
        """
        Resolve the ground-truth text for one auxiliary entry.

        Parameters
        ----------
        aux : Any
            A dict (its ``'text_content'`` value is used), a plain string (used as is),
            or anything else (treated as missing).

        Returns
        -------
        Any
            The ground-truth text; a list value is joined with newlines. ``None`` when
            no ground truth is available.
        """
        if isinstance(aux, dict):
            text_gt = aux.get("text_content")
        elif isinstance(aux, str):
            text_gt = aux
        else:
            text_gt = None
        if isinstance(text_gt, list):
            text_gt = "\n".join(str(part) for part in text_gt)
        return text_gt

    def update(
        self,
        x: list[Any] | torch.Tensor,
        gt: list[dict[str, Any]] | list[str],
        outputs: torch.Tensor,
    ) -> None:
        """
        Run OCR on outputs and score against ``text_content`` (or string list) auxiliaries.

        Parameters
        ----------
        x : List[Any] | torch.Tensor
            Batch prompts or metadata.
        gt : list of dict or list of str
            Auxiliaries with ``'text_content'`` as a string, a list of strings (joined with
            newlines), or plain strings per batch item.
        outputs : torch.Tensor
            Rendered images.

        Raises
        ------
        ValueError
            If a sample has no ground-truth text. The check happens before the VLM is
            queried for that sample, so no OCR call is spent on it.
        """
        inputs = metric_data_processor(x, gt, outputs, self.call_type)
        images = _process_images(inputs[0])
        has_auxiliaries = len(inputs) > 1 and isinstance(inputs[1], (list, tuple))
        auxiliaries = inputs[1] if has_auxiliaries else [{}] * len(images)
        for i, image in enumerate(images):
            aux = auxiliaries[i] if i < len(auxiliaries) else {}
            text_gt = self._extract_ground_truth(aux)
            # Validate first: the VLM call below is the expensive step (remote API or
            # local inference) and is pointless without a reference string.
            if text_gt is None:
                raise ValueError(
                    f"{self.metric_name} requires 'text_content' in auxiliaries. "
                    "Use a benchmark that provides it (e.g. LongTextBench, OneIG)."
                )
            responses = self.vlm.generate([image], [OCR_PROMPT], response_format=self.response_format)
            raw = responses[0] if responses else ""
            ocr_text = get_text_from_response(raw)
            self._accumulate_sample(text_gt, ocr_text)

    def compute(self) -> MetricResult:
        """
        Aggregate batched contributions into a single metric value.

        Returns
        -------
        MetricResult
            Named result with ``higher_is_better`` taken from the class.
        """
        value = self._compute_result_value()
        # Pass a shallow copy so the result does not alias the live attribute mapping
        # of this metric instance.
        return MetricResult(self.metric_name, self.__dict__.copy(), float(value))
193+
194+
195+
@MetricRegistry.register("ocr_levenshtein")
@MetricRegistry.register("text_score")
class TextScoreMetric(_BaseVLMOCRTextMetric):
    """
    Mean normalized character accuracy of OCR output, in [0, 1] (higher is better).

    Registered as ``ocr_levenshtein`` (descriptive) and ``text_score`` (legacy).

    Each sample scores ``1 - min(1, edit_distance / max(len(ground_truth), 1))`` after
    both strings go through ``normalize_text_simple``. This is light normalization only
    (not the full OneIG preprocess); see :class:`OneIGTextScoreMetric` for the OneIG
    composite ``ocr_text_score``.

    Parameters
    ----------
    *args : Any
        Additional positional arguments (unused; registry compatibility).
    vlm : BaseVLM | None, optional
        Custom VLM instance. If provided, ``vlm_type`` and ``model_name`` are ignored.
    vlm_type : {'litellm', 'transformers'}, optional
        VLM backend. Default is ``'litellm'``.
    model_name : str | None, optional
        Litellm model id or HuggingFace checkpoint id. **Required** when ``vlm`` is not
        provided (e.g. ``openai/gpt-4o``).
    vlm_kwargs : dict, optional
        Forwarded by ``get_vlm`` to ``LitellmVLM`` or ``TransformersVLM``. For local models,
        set ``model_load_kwargs`` for ``from_pretrained``; for litellm, pass extra API options.
    structured_output : bool, optional
        Use structured generation (litellm pydantic; transformers outlines when applicable).
        Default is True.
    device : str | torch.device | None, optional
        Device for transformers VLM.
    api_key : str | None, optional
        API key for litellm.
    call_type : str, optional
        Call type for the metric.
    **kwargs : Any
        Additional keyword arguments forwarded to :class:`_BaseVLMOCRTextMetric`.
    """

    scores: list[float]
    higher_is_better: bool = True
    metric_name: str = "text_score"

    def __init__(
        self,
        *args: Any,
        vlm: BaseVLM | None = None,
        vlm_type: Literal["litellm", "transformers"] = "litellm",
        model_name: str | None = None,
        vlm_kwargs: dict[str, Any] | None = None,
        structured_output: bool = True,
        device: str | torch.device | None = None,
        api_key: str | None = None,
        call_type: str = SINGLE,
        **kwargs: Any,
    ) -> None:
        base_options: dict[str, Any] = {
            "vlm": vlm,
            "vlm_type": vlm_type,
            "model_name": model_name,
            "vlm_kwargs": vlm_kwargs,
            "structured_output": structured_output,
            "device": device,
            "api_key": api_key,
            "call_type": call_type,
        }
        super().__init__(*args, **base_options, **kwargs)
        # One accuracy value per scored image.
        self.add_state("scores", [])

    def _accumulate_sample(self, text_gt: str, ocr_text: str) -> None:
        """Append the normalized character accuracy of one OCR result to ``scores``."""
        reference = normalize_text_simple(text_gt)
        hypothesis = normalize_text_simple(ocr_text)
        distance = levenshtein(hypothesis, reference)
        # Normalize by the reference length; the floor of 1 avoids dividing by zero
        # when the normalized ground truth is empty.
        error_rate = distance / max(float(len(reference)), 1.0)
        # Clamp so a very long OCR output cannot push the accuracy below zero.
        self.scores.append(1.0 - min(1.0, error_rate))

    def _compute_result_value(self) -> float:
        """Return the mean of the accumulated scores, or ``0.0`` when there are none."""
        return float(np.mean(self.scores)) if self.scores else 0.0
275+
276+
277+
@MetricRegistry.register("ocr_text_score")
@MetricRegistry.register("oneig_text_score")
class OneIGTextScoreMetric(_BaseVLMOCRTextMetric):
    """
    OneIG-style composite text score computed from OCR output (higher is better).

    Registered as ``ocr_text_score`` (descriptive) and ``oneig_text_score`` (protocol).

    Per-sample edit distance, completion ratio, and match / ground-truth counts are
    collected and then combined by ``oneig_mean_text_score``, like
    ``OneIG-Benchmark/scripts/text/text_score.py``.

    Parameters
    ----------
    *args : Any
        Additional positional arguments (forwarded to :class:`_BaseVLMOCRTextMetric`).
    language_mode : {'EN', 'ZH'}, optional
        Selects ``MAX_EDIT_DISTANCE`` (100 vs 50) for the composite.
    vlm : BaseVLM | None, optional
        Custom VLM instance. If provided, ``vlm_type`` and ``model_name`` are ignored.
    vlm_type : {'litellm', 'transformers'}, optional
        VLM backend. Default is ``'litellm'``.
    model_name : str | None, optional
        Litellm model id or HuggingFace checkpoint id. **Required** when ``vlm`` is not
        provided (e.g. ``openai/gpt-4o``).
    vlm_kwargs : dict, optional
        Forwarded by ``get_vlm`` to ``LitellmVLM`` or ``TransformersVLM``. For local models,
        set ``model_load_kwargs`` for ``from_pretrained``; for litellm, pass extra API options.
    structured_output : bool, optional
        Use structured generation (litellm pydantic; transformers outlines when applicable).
        Default is True.
    device : str | torch.device | None, optional
        Device for transformers VLM.
    api_key : str | None, optional
        API key for litellm.
    call_type : str, optional
        Call type for the metric.
    **kwargs : Any
        Additional keyword arguments forwarded to :class:`_BaseVLMOCRTextMetric`.
    """

    edit_distances: list[float]
    completion_ratios: list[float]
    match_counts: list[int]
    gt_totals: list[int]

    higher_is_better: bool = True
    metric_name: str = "oneig_text_score"

    def __init__(
        self,
        *args: Any,
        language_mode: Literal["EN", "ZH"] = "EN",
        vlm: BaseVLM | None = None,
        vlm_type: Literal["litellm", "transformers"] = "litellm",
        model_name: str | None = None,
        vlm_kwargs: dict[str, Any] | None = None,
        structured_output: bool = True,
        device: str | torch.device | None = None,
        api_key: str | None = None,
        call_type: str = SINGLE,
        **kwargs: Any,
    ) -> None:
        base_options: dict[str, Any] = {
            "vlm": vlm,
            "vlm_type": vlm_type,
            "model_name": model_name,
            "vlm_kwargs": vlm_kwargs,
            "structured_output": structured_output,
            "device": device,
            "api_key": api_key,
            "call_type": call_type,
        }
        super().__init__(*args, **base_options, **kwargs)
        self.language_mode = language_mode
        # Four parallel per-sample lists; each state gets its own fresh empty list.
        for state_name in ("edit_distances", "completion_ratios", "match_counts", "gt_totals"):
            self.add_state(state_name, [])

    def _accumulate_sample(self, text_gt: str, ocr_text: str) -> None:
        """Record the four OneIG per-sample contributions for one ground-truth / OCR pair."""
        edit_distance, completion, matched, gt_total = oneig_per_sample_contributions(text_gt, ocr_text)
        recorded = (
            (self.edit_distances, edit_distance),
            (self.completion_ratios, completion),
            (self.match_counts, matched),
            (self.gt_totals, gt_total),
        )
        for bucket, value in recorded:
            bucket.append(value)

    def _compute_result_value(self) -> float:
        """Return the composite score, i.e. the last value yielded by ``oneig_mean_text_score``."""
        aggregates = oneig_mean_text_score(
            self.edit_distances,
            self.completion_ratios,
            self.match_counts,
            self.gt_totals,
            self.language_mode,
        )
        # Only the final element is reported; the preceding ones are ignored here.
        *_unused, composite = aggregates
        return composite

0 commit comments

Comments
 (0)