|
| 1 | +# Copyright 2025 - Pruna AI GmbH. All rights reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +"""Text rendering via OCR: mean Levenshtein (``text_score`` / ``ocr_levenshtein``). |
| 16 | +
|
| 17 | +OneIG composite: ``oneig_text_score`` / ``ocr_text_score``. |
| 18 | +""" |
| 19 | + |
| 20 | +from __future__ import annotations |
| 21 | + |
| 22 | +from abc import abstractmethod |
| 23 | +from typing import Any, Literal |
| 24 | + |
| 25 | +import numpy as np |
| 26 | +import torch |
| 27 | + |
| 28 | +from pruna.engine.utils import set_to_best_available_device |
| 29 | +from pruna.evaluation.metrics.metric_stateful import StatefulMetric |
| 30 | +from pruna.evaluation.metrics.metric_text_score_utils import ( |
| 31 | + levenshtein, |
| 32 | + normalize_text_simple, |
| 33 | + oneig_mean_text_score, |
| 34 | + oneig_per_sample_contributions, |
| 35 | +) |
| 36 | +from pruna.evaluation.metrics.registry import MetricRegistry |
| 37 | +from pruna.evaluation.metrics.result import MetricResult |
| 38 | +from pruna.evaluation.metrics.utils import ( |
| 39 | + SINGLE, |
| 40 | + get_call_type_for_single_metric, |
| 41 | + metric_data_processor, |
| 42 | +) |
| 43 | +from pruna.evaluation.metrics.vlm_base import BaseVLM, get_vlm |
| 44 | +from pruna.evaluation.metrics.vlm_utils import TextOutput, _process_images, get_text_from_response |
| 45 | + |
| 46 | +OCR_PROMPT = ( |
| 47 | + "Extract all text visible in this image. Include logos, stylized fonts, handwritten text, " |
| 48 | + "and non-standard typography. Return only the extracted text, exactly as it appears—no preamble, " |
| 49 | + "explanation, or markdown. Preserve words, numbers, punctuation, and spacing. " |
| 50 | + "IMPORTANT: Do NOT correct spelling errors or typos. If a word is misspelled in the image " |
| 51 | + "(e.g. 'Teclhology' instead of 'Technology'), reproduce it exactly as it appears, including the misspelling. " |
| 52 | + "If no text is recognized, reply with exactly: No text recognized" |
| 53 | +) |
| 54 | + |
| 55 | + |
| 56 | +class _BaseVLMOCRTextMetric(StatefulMetric): |
| 57 | + """ |
| 58 | + Shared VLM OCR over rendered images with ground truth in ``text_content``. |
| 59 | +
|
| 60 | + Subclasses implement how OCR and GT strings are scored and aggregated. |
| 61 | +
|
| 62 | + Parameters |
| 63 | + ---------- |
| 64 | + *args : Any |
| 65 | + Additional positional arguments (unused; registry compatibility). |
| 66 | + vlm : BaseVLM | None, optional |
| 67 | + Custom VLM instance. If provided, ``vlm_type`` and ``model_name`` are ignored. |
| 68 | + vlm_type : {'litellm', 'transformers'}, optional |
| 69 | + VLM backend. Default is ``'litellm'``. |
| 70 | + model_name : str | None, optional |
| 71 | + Litellm model id or HuggingFace checkpoint id. **Required** when ``vlm`` is not |
| 72 | + provided (e.g. ``openai/gpt-4o``). |
| 73 | + vlm_kwargs : dict, optional |
| 74 | + Forwarded by ``get_vlm`` to ``LitellmVLM`` or ``TransformersVLM``. For local models, |
| 75 | + set ``model_load_kwargs`` for ``from_pretrained``; for litellm, pass extra API options. |
| 76 | + structured_output : bool, optional |
| 77 | + Use structured generation (litellm pydantic; transformers outlines when applicable). |
| 78 | + Default is True. |
| 79 | + device : str | torch.device | None, optional |
| 80 | + Device for transformers VLM. |
| 81 | + api_key : str | None, optional |
| 82 | + API key for litellm. |
| 83 | + call_type : str, optional |
| 84 | + Call type for the metric. |
| 85 | + **kwargs : Any |
| 86 | + Additional arguments. |
| 87 | +
|
| 88 | + Examples |
| 89 | + -------- |
| 90 | + OCR metrics call ``get_vlm`` directly (not ``StatefulVLMMeanScoresMetric``). Same |
| 91 | + ``hosted`` / ``local`` pattern as :func:`~pruna.evaluation.metrics.vlm_base.get_vlm`: |
| 92 | +
|
| 93 | + .. code-block:: python |
| 94 | +
|
| 95 | + import torch |
| 96 | +
|
| 97 | + from pruna.evaluation.metrics import TextScoreMetric |
| 98 | +
|
| 99 | + hosted = TextScoreMetric(vlm_type="litellm", model_name="openai/gpt-4o") |
| 100 | + local = TextScoreMetric( |
| 101 | + vlm_type="transformers", |
| 102 | + model_name="HuggingFaceTB/SmolVLM-256M-Instruct", |
| 103 | + device="cpu", |
| 104 | + vlm_kwargs={"model_load_kwargs": {"torch_dtype": torch.float32}}, |
| 105 | + ) |
| 106 | +
|
| 107 | + Use ``OneIGTextScoreMetric`` the same way for ``oneig_text_score`` / ``ocr_text_score``. |
| 108 | + """ |
| 109 | + |
| 110 | + default_call_type: str = "y_gt" |
| 111 | + |
| 112 | + def __init__( |
| 113 | + self, |
| 114 | + *args: Any, |
| 115 | + vlm: BaseVLM | None = None, |
| 116 | + vlm_type: Literal["litellm", "transformers"] = "litellm", |
| 117 | + model_name: str | None = None, |
| 118 | + vlm_kwargs: dict | None = None, |
| 119 | + structured_output: bool = True, |
| 120 | + device: str | torch.device | None = None, |
| 121 | + api_key: str | None = None, |
| 122 | + call_type: str = SINGLE, |
| 123 | + **kwargs: Any, |
| 124 | + ) -> None: |
| 125 | + super().__init__(device=device) |
| 126 | + self.device = set_to_best_available_device(device) |
| 127 | + |
| 128 | + self.vlm = get_vlm( |
| 129 | + vlm=vlm, |
| 130 | + vlm_type=vlm_type, |
| 131 | + model_name=model_name, |
| 132 | + device=device, |
| 133 | + api_key=api_key, |
| 134 | + structured_output=structured_output, |
| 135 | + **(vlm_kwargs or {}), |
| 136 | + ) |
| 137 | + self.response_format = TextOutput if structured_output else None |
| 138 | + |
| 139 | + self.call_type = get_call_type_for_single_metric(call_type, self.default_call_type) |
| 140 | + self.higher_is_better = type(self).higher_is_better |
| 141 | + |
| 142 | + @abstractmethod |
| 143 | + def _accumulate_sample(self, text_gt: str, ocr_text: str) -> None: |
| 144 | + """Update metric state from one ground-truth / OCR pair.""" |
| 145 | + |
| 146 | + @abstractmethod |
| 147 | + def _compute_result_value(self) -> float: |
| 148 | + """Return the scalar reported as ``MetricResult.result``.""" |
| 149 | + |
| 150 | + def update(self, x: list[Any] | torch.Tensor, gt: list[str], outputs: torch.Tensor) -> None: |
| 151 | + """ |
| 152 | + Run OCR on outputs and score against ``text_content`` (or string list) auxiliaries. |
| 153 | +
|
| 154 | + Parameters |
| 155 | + ---------- |
| 156 | + x : List[Any] | torch.Tensor |
| 157 | + Batch prompts or metadata. |
| 158 | + gt : list of dict or list of str |
| 159 | + Auxiliaries with ``'text_content'`` as a string, a list of strings (joined with |
| 160 | + newlines), or plain strings per batch item. |
| 161 | + outputs : torch.Tensor |
| 162 | + Rendered images. |
| 163 | + """ |
| 164 | + inputs = metric_data_processor(x, gt, outputs, self.call_type) |
| 165 | + images = _process_images(inputs[0]) |
| 166 | + auxiliaries = inputs[1] if len(inputs) > 1 and isinstance(inputs[1], (list, tuple)) else [{}] * len(images) |
| 167 | + for i, image in enumerate(images): |
| 168 | + responses = self.vlm.generate([image], [OCR_PROMPT], response_format=self.response_format) |
| 169 | + raw = responses[0] if responses else "" |
| 170 | + ocr_text = get_text_from_response(raw) |
| 171 | + aux = auxiliaries[i] if i < len(auxiliaries) else {} |
| 172 | + text_gt = aux.get("text_content") if isinstance(aux, dict) else (aux if isinstance(aux, str) else None) |
| 173 | + if isinstance(text_gt, list): |
| 174 | + text_gt = "\n".join(str(x) for x in text_gt) |
| 175 | + if text_gt is None: |
| 176 | + raise ValueError( |
| 177 | + f"{self.metric_name} requires 'text_content' in auxiliaries. " |
| 178 | + "Use a benchmark that provides it (e.g. LongTextBench, OneIG)." |
| 179 | + ) |
| 180 | + self._accumulate_sample(text_gt, ocr_text) |
| 181 | + |
| 182 | + def compute(self) -> MetricResult: |
| 183 | + """ |
| 184 | + Aggregate batched contributions into a single metric value. |
| 185 | +
|
| 186 | + Returns |
| 187 | + ------- |
| 188 | + MetricResult |
| 189 | + Named result with ``higher_is_better`` taken from the class. |
| 190 | + """ |
| 191 | + value = self._compute_result_value() |
| 192 | + return MetricResult(self.metric_name, self.__dict__, float(value)) |
| 193 | + |
| 194 | + |
| 195 | +@MetricRegistry.register("ocr_levenshtein") |
| 196 | +@MetricRegistry.register("text_score") |
| 197 | +class TextScoreMetric(_BaseVLMOCRTextMetric): |
| 198 | + """ |
| 199 | + OCR then mean normalized character accuracy in [0, 1] (higher is better). |
| 200 | +
|
| 201 | + Registry: ``ocr_levenshtein`` (descriptive) and ``text_score`` (legacy). |
| 202 | +
|
| 203 | + Uses light normalization only (not the full OneIG preprocess). See |
| 204 | + :class:`OneIGTextScoreMetric` for the OneIG composite ``ocr_text_score``. |
| 205 | +
|
| 206 | + Parameters |
| 207 | + ---------- |
| 208 | + *args : Any |
| 209 | + Additional positional arguments (unused; registry compatibility). |
| 210 | + vlm : BaseVLM | None, optional |
| 211 | + Custom VLM instance. If provided, ``vlm_type`` and ``model_name`` are ignored. |
| 212 | + vlm_type : {'litellm', 'transformers'}, optional |
| 213 | + VLM backend. Default is ``'litellm'``. |
| 214 | + model_name : str | None, optional |
| 215 | + Litellm model id or HuggingFace checkpoint id. **Required** when ``vlm`` is not |
| 216 | + provided (e.g. ``openai/gpt-4o``). |
| 217 | + vlm_kwargs : dict, optional |
| 218 | + Forwarded by ``get_vlm`` to ``LitellmVLM`` or ``TransformersVLM``. For local models, |
| 219 | + set ``model_load_kwargs`` for ``from_pretrained``; for litellm, pass extra API options. |
| 220 | + structured_output : bool, optional |
| 221 | + Use structured generation (litellm pydantic; transformers outlines when applicable). |
| 222 | + Default is True. |
| 223 | + device : str | torch.device | None, optional |
| 224 | + Device for transformers VLM. |
| 225 | + api_key : str | None, optional |
| 226 | + API key for litellm. |
| 227 | + call_type : str, optional |
| 228 | + Call type for the metric. |
| 229 | + **kwargs : Any |
| 230 | + Additional keyword arguments forwarded to :class:`_BaseVLMOCRTextMetric`. |
| 231 | + """ |
| 232 | + |
| 233 | + scores: list[float] |
| 234 | + higher_is_better: bool = True |
| 235 | + metric_name: str = "text_score" |
| 236 | + |
| 237 | + def __init__( |
| 238 | + self, |
| 239 | + *args: Any, |
| 240 | + vlm: BaseVLM | None = None, |
| 241 | + vlm_type: Literal["litellm", "transformers"] = "litellm", |
| 242 | + model_name: str | None = None, |
| 243 | + vlm_kwargs: dict[str, Any] | None = None, |
| 244 | + structured_output: bool = True, |
| 245 | + device: str | torch.device | None = None, |
| 246 | + api_key: str | None = None, |
| 247 | + call_type: str = SINGLE, |
| 248 | + **kwargs: Any, |
| 249 | + ) -> None: |
| 250 | + super().__init__( |
| 251 | + *args, |
| 252 | + vlm=vlm, |
| 253 | + vlm_type=vlm_type, |
| 254 | + model_name=model_name, |
| 255 | + vlm_kwargs=vlm_kwargs, |
| 256 | + structured_output=structured_output, |
| 257 | + device=device, |
| 258 | + api_key=api_key, |
| 259 | + call_type=call_type, |
| 260 | + **kwargs, |
| 261 | + ) |
| 262 | + self.add_state("scores", []) |
| 263 | + |
| 264 | + def _accumulate_sample(self, text_gt: str, ocr_text: str) -> None: |
| 265 | + norm_gt = normalize_text_simple(text_gt) |
| 266 | + norm_ocr = normalize_text_simple(ocr_text) |
| 267 | + ed = levenshtein(norm_ocr, norm_gt) |
| 268 | + denom = max(float(len(norm_gt)), 1.0) |
| 269 | + self.scores.append(1.0 - min(1.0, ed / denom)) |
| 270 | + |
| 271 | + def _compute_result_value(self) -> float: |
| 272 | + if not self.scores: |
| 273 | + return 0.0 |
| 274 | + return float(np.mean(self.scores)) |
| 275 | + |
| 276 | + |
| 277 | +@MetricRegistry.register("ocr_text_score") |
| 278 | +@MetricRegistry.register("oneig_text_score") |
| 279 | +class OneIGTextScoreMetric(_BaseVLMOCRTextMetric): |
| 280 | + """ |
| 281 | + OCR then OneIG-style composite text score (higher is better). |
| 282 | +
|
| 283 | + Registry: ``ocr_text_score`` (descriptive) and ``oneig_text_score`` (protocol). |
| 284 | +
|
| 285 | + Aggregates edit distance, completion rate, and word/char accuracy like |
| 286 | + ``OneIG-Benchmark/scripts/text/text_score.py``. |
| 287 | +
|
| 288 | + Parameters |
| 289 | + ---------- |
| 290 | + *args : Any |
| 291 | + Additional positional arguments (forwarded to :class:`_BaseVLMOCRTextMetric`). |
| 292 | + language_mode : {'EN', 'ZH'}, optional |
| 293 | + Selects ``MAX_EDIT_DISTANCE`` (100 vs 50) for the composite. |
| 294 | + vlm : BaseVLM | None, optional |
| 295 | + Custom VLM instance. If provided, ``vlm_type`` and ``model_name`` are ignored. |
| 296 | + vlm_type : {'litellm', 'transformers'}, optional |
| 297 | + VLM backend. Default is ``'litellm'``. |
| 298 | + model_name : str | None, optional |
| 299 | + Litellm model id or HuggingFace checkpoint id. **Required** when ``vlm`` is not |
| 300 | + provided (e.g. ``openai/gpt-4o``). |
| 301 | + vlm_kwargs : dict, optional |
| 302 | + Forwarded by ``get_vlm`` to ``LitellmVLM`` or ``TransformersVLM``. For local models, |
| 303 | + set ``model_load_kwargs`` for ``from_pretrained``; for litellm, pass extra API options. |
| 304 | + structured_output : bool, optional |
| 305 | + Use structured generation (litellm pydantic; transformers outlines when applicable). |
| 306 | + Default is True. |
| 307 | + device : str | torch.device | None, optional |
| 308 | + Device for transformers VLM. |
| 309 | + api_key : str | None, optional |
| 310 | + API key for litellm. |
| 311 | + call_type : str, optional |
| 312 | + Call type for the metric. |
| 313 | + **kwargs : Any |
| 314 | + Additional keyword arguments forwarded to :class:`_BaseVLMOCRTextMetric`. |
| 315 | + """ |
| 316 | + |
| 317 | + edit_distances: list[float] |
| 318 | + completion_ratios: list[float] |
| 319 | + match_counts: list[int] |
| 320 | + gt_totals: list[int] |
| 321 | + |
| 322 | + higher_is_better: bool = True |
| 323 | + metric_name: str = "oneig_text_score" |
| 324 | + |
| 325 | + def __init__( |
| 326 | + self, |
| 327 | + *args: Any, |
| 328 | + language_mode: Literal["EN", "ZH"] = "EN", |
| 329 | + vlm: BaseVLM | None = None, |
| 330 | + vlm_type: Literal["litellm", "transformers"] = "litellm", |
| 331 | + model_name: str | None = None, |
| 332 | + vlm_kwargs: dict[str, Any] | None = None, |
| 333 | + structured_output: bool = True, |
| 334 | + device: str | torch.device | None = None, |
| 335 | + api_key: str | None = None, |
| 336 | + call_type: str = SINGLE, |
| 337 | + **kwargs: Any, |
| 338 | + ) -> None: |
| 339 | + super().__init__( |
| 340 | + *args, |
| 341 | + vlm=vlm, |
| 342 | + vlm_type=vlm_type, |
| 343 | + model_name=model_name, |
| 344 | + vlm_kwargs=vlm_kwargs, |
| 345 | + structured_output=structured_output, |
| 346 | + device=device, |
| 347 | + api_key=api_key, |
| 348 | + call_type=call_type, |
| 349 | + **kwargs, |
| 350 | + ) |
| 351 | + self.language_mode = language_mode |
| 352 | + self.add_state("edit_distances", []) |
| 353 | + self.add_state("completion_ratios", []) |
| 354 | + self.add_state("match_counts", []) |
| 355 | + self.add_state("gt_totals", []) |
| 356 | + |
| 357 | + def _accumulate_sample(self, text_gt: str, ocr_text: str) -> None: |
| 358 | + ed, cr, mcount, gtot = oneig_per_sample_contributions(text_gt, ocr_text) |
| 359 | + self.edit_distances.append(ed) |
| 360 | + self.completion_ratios.append(cr) |
| 361 | + self.match_counts.append(mcount) |
| 362 | + self.gt_totals.append(gtot) |
| 363 | + |
| 364 | + def _compute_result_value(self) -> float: |
| 365 | + *_, text_score = oneig_mean_text_score( |
| 366 | + self.edit_distances, |
| 367 | + self.completion_ratios, |
| 368 | + self.match_counts, |
| 369 | + self.gt_totals, |
| 370 | + self.language_mode, |
| 371 | + ) |
| 372 | + return text_score |
0 commit comments