
Commit e586366

feat(text-metrics): split oneig_reasoning into dedicated branch
Adds oneig_reasoning metric, lazy registry loading, and focused benchmark/test wiring as the final text-metric stack branch.

Made-with: Cursor
1 parent 3cdc2bb commit e586366

4 files changed

Lines changed: 470 additions & 2 deletions

File tree

src/pruna/evaluation/benchmarks.py

Lines changed: 7 additions & 0 deletions
@@ -285,6 +285,13 @@ def list(cls, task_type: str | None = None) -> list[str]:
                 task_type="text_to_image",
                 reference="https://arxiv.org/abs/2506.07977",
             ),
+            Benchmark(
+                name="OneIG Knowledge Reasoning",
+                description="OneIG subset: knowledge- and reasoning-heavy prompts.",
+                metrics=["oneig_reasoning"],
+                task_type="text_to_image",
+                reference="https://arxiv.org/abs/2506.07977",
+            ),
             Benchmark(
                 name="OneIG Multilingualism",
                 description="OneIG subset: multilingual prompts (incl. Chinese splits).",
Lines changed: 357 additions & 0 deletions
@@ -0,0 +1,357 @@
# Copyright 2025 - Pruna AI GmbH. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""OneIG reasoning score via LLM2CLIP text-image similarity.

Llama-derived checkpoints may require ``HF_TOKEN`` and ``huggingface-cli login``.

Hugging Face download tuning (optional):

- ``PRUNA_ONEIG_HF_VERBOSE=1`` or ``HF_DEBUG=1`` — hub **debug** logging and tqdm
  progress bars (helps when stderr is piped; pair with ``python -u`` or
  ``PYTHONUNBUFFERED=1`` for line-buffered output).
- ``PRUNA_ONEIG_HF_FAST_DOWNLOAD=1`` — enable **hf_transfer** multi-part downloads
  (requires ``pruna[evaluation]``, which lists ``hf_transfer``). Alternatively, set
  ``HF_HUB_ENABLE_HF_TRANSFER=1`` **before** starting Python so the hub picks it up
  at import time.

``transformers`` is pinned to ``<5`` in ``pyproject.toml``. The LLM2CLIP loading path
(``CLIPImageProcessor``, ``AutoModel``, ``LlamaEncoderModel``) is exercised on **4.x**
releases in CI and manual smoke runs. ``transformers`` 5.x has had reports of
``from_pretrained`` not fully initializing some non-persistent buffers (for example
``position_ids``) for certain architectures; the pin avoids that class of failures
until those issues are clearly resolved upstream.
"""

from __future__ import annotations

import os
from typing import Any

import torch

from pruna.evaluation.metrics.metric_stateful import StatefulMetric
from pruna.evaluation.metrics.registry import MetricRegistry
from pruna.evaluation.metrics.result import MetricResult
from pruna.evaluation.metrics.utils import (
    SINGLE,
    get_call_type_for_single_metric,
    metric_data_processor,
)
from pruna.evaluation.metrics.vlm_utils import _process_images
from pruna.logging.logger import pruna_logger


def _env_truthy(raw: str | None) -> bool:
    if raw is None:
        return False
    return raw.strip().upper() in {"1", "ON", "YES", "TRUE"}


def _prepare_huggingface_hub_for_oneig_downloads() -> None:
    """
    Apply Hugging Face Hub verbosity and optional fast downloads before checkpoints load.

    ``HF_HUB_ENABLE_HF_TRANSFER`` is read when ``huggingface_hub`` loads; if it was
    false, we flip the in-module flag after importing ``hf_transfer`` when
    ``PRUNA_ONEIG_HF_FAST_DOWNLOAD=1``.
    """
    if _env_truthy(os.environ.get("PRUNA_ONEIG_HF_VERBOSE")) or _env_truthy(os.environ.get("HF_DEBUG")):
        from huggingface_hub.utils import enable_progress_bars
        from huggingface_hub.utils.logging import set_verbosity_debug

        set_verbosity_debug()
        enable_progress_bars()

    if not _env_truthy(os.environ.get("PRUNA_ONEIG_HF_FAST_DOWNLOAD")):
        return

    import hf_transfer  # noqa: F401  # type: ignore[import-not-found]
    import huggingface_hub.constants as hf_constants

    hf_constants.HF_HUB_ENABLE_HF_TRANSFER = True
    pruna_logger.info("oneig_reasoning: enabled hf_transfer downloads (PRUNA_ONEIG_HF_FAST_DOWNLOAD=1).")


def _to_pil_list(images: list) -> list:
    """Convert images to a list of RGB ``PIL.Image`` objects."""
    import numpy as np
    from PIL import Image

    out: list = []
    for img in images:
        if isinstance(img, Image.Image):
            out.append(img.convert("RGB"))
        elif isinstance(img, torch.Tensor):
            if img.ndim == 4:  # batched input: take the first image
                img = img[0]
            if img.max() > 1:  # assume a 0-255 range; rescale to [0, 1]
                img = img / 255.0
            np_img = (img.cpu().numpy() * 255).astype("uint8")
            if np_img.shape[0] == 3:  # CHW -> HWC
                np_img = np_img.transpose(1, 2, 0)
            out.append(Image.fromarray(np_img))
        elif hasattr(img, "__array__"):
            out.append(Image.fromarray(np.asarray(img)).convert("RGB"))
        else:
            out.append(img)  # best effort: pass unknown types through
    return out
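
# Illustrative inputs (not from the original file) — the converter accepts mixed types:
#
#   _to_pil_list([pil_img])                          # PIL passthrough, converted to RGB
#   _to_pil_list([torch.rand(3, 64, 64)])            # CHW float in [0, 1]
#   _to_pil_list([torch.rand(1, 3, 64, 64) * 255])   # batched, 0-255 range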


class _LLM2CLIPScorer:
    """
    Thin wrapper around LLM2CLIP text-image similarity.

    Accepts PIL images and a single answer string; returns per-image scores.
    Best-effort alignment with OneIG-Benchmark scripts (CUDA + bfloat16).
    """

    def __init__(
        self,
        processor_model: str = "openai/clip-vit-large-patch14-336",
        model_name: str = "microsoft/LLM2CLIP-Openai-L-14-336",
        llm_model_name: str = "microsoft/LLM2CLIP-Llama-3-8B-Instruct-CC-Finetuned",
        device: str = "cuda",
    ) -> None:
        self.processor_model = processor_model
        self.model_name = model_name
        self.llm_model_name = llm_model_name
        self.device = device
        self._processor = None
        self._clip_model = None
        self._l2v = None

    def _load_models(self) -> None:
        if self._clip_model is not None:
            return
        _prepare_huggingface_hub_for_oneig_downloads()
        from transformers import AutoConfig, AutoModel, AutoTokenizer, CLIPImageProcessor

        from pruna.evaluation.metrics.vendor.oneig_llm2vec import LLM2Vec
        from pruna.evaluation.metrics.vendor.oneig_llm2vec.modeling_llama_encoder import LlamaEncoderModel

        pruna_logger.info(
            "oneig_reasoning: downloading or loading LLM2CLIP checkpoints "
            "(%s, %s). First run can take many minutes and several gigabytes; "
            "Hugging Face download progress may look idle when logs are piped.",
            self.model_name,
            self.llm_model_name,
        )
        # Treat both "cuda" and "cuda:N" device strings as CUDA for dtype and attention.
        dev_str = str(self.device)
        on_cuda = dev_str == "cuda" or dev_str.startswith("cuda:")
        dtype = torch.bfloat16 if on_cuda else torch.float32
        self._processor = CLIPImageProcessor.from_pretrained(self.processor_model)
        self._clip_model = AutoModel.from_pretrained(
            self.model_name,
            dtype=dtype,
            trust_remote_code=True,
        ).to(self.device)
        self._clip_model.eval()

        config = AutoConfig.from_pretrained(self.llm_model_name, trust_remote_code=True)
        attn_impl = "sdpa" if on_cuda else "eager"
        config.attn_implementation = attn_impl
        if hasattr(config, "_attn_implementation"):
            config._attn_implementation = attn_impl
        llm_model = LlamaEncoderModel.from_pretrained(
            self.llm_model_name,
            dtype=dtype,
            config=config,
            trust_remote_code=True,
        )
        llm_model.config._name_or_path = "meta-llama/Meta-Llama-3-8B-Instruct"
        tokenizer = AutoTokenizer.from_pretrained(self.llm_model_name)
        self._l2v = LLM2Vec(llm_model, tokenizer, pooling_mode="mean", max_length=512, doc_max_length=512)

    def score(self, images: list, text_prompt: str) -> list[float] | None:
        """
        Compute similarity scores between images and text.

        Parameters
        ----------
        images : list
            List of PIL.Image.Image.
        text_prompt : str
            Reference text (e.g. the ground-truth answer).

        Returns
        -------
        list[float] | None
            Per-image scores, or None on failure.
        """
        self._load_models()
        pil_images = _to_pil_list(images)
        if not pil_images:
            return None
        input_pixels = self._processor(images=pil_images, return_tensors="pt").pixel_values.to(self.device)
        captions = [text_prompt]

        with torch.no_grad():  # inference only; avoid building an autograd graph
            text_features = self._l2v.encode(captions, convert_to_tensor=True, device=self.device).to(self.device)
            text_features = self._clip_model.get_text_features(text_features)

            if str(self.device).startswith("cuda"):
                with torch.amp.autocast(device_type="cuda"):
                    image_features = self._clip_model.get_image_features(input_pixels)
            else:
                image_features = self._clip_model.get_image_features(input_pixels.float())

        image_features = image_features.float()
        text_features = text_features.float()
        image_features /= image_features.norm(dim=-1, keepdim=True)
        text_features /= text_features.norm(dim=-1, keepdim=True)

        # One cosine similarity per image against the single caption.
        text_probs = (image_features @ text_features.T).cpu().tolist()
        return [p[0] for p in text_probs]
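
# Minimal scorer sketch (illustrative, not from the original file; assumes a CUDA
# machine and Hugging Face auth for the Llama-derived checkpoint; the score value
# shown is made up):
#
#   scorer = _LLM2CLIPScorer()
#   scores = scorer.score([pil_image], "The tower on the left is taller")
#   # -> e.g. [0.31]: one cosine similarity per input image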


@MetricRegistry.register("oneig_reasoning")
class OneIGReasoningMetric(StatefulMetric):
    """
    OneIG reasoning score: LLM2CLIP similarity between GT answer text and generated image.

    Uses ``reasoning_gt_answer`` from aux (populated by the OneIG Knowledge_Reasoning
    loader; language is chosen at dataset load via ``reasoning_language``). MVP: 1×1
    grid (whole image as single cell). Llama-derived checkpoints may require
    ``HF_TOKEN`` and ``huggingface-cli login``.

    Parameters
    ----------
    processor_model : str, optional
        CLIP processor model ID.
    model_name : str, optional
        LLM2CLIP model ID.
    llm_model_name : str, optional
        LLM2Vec model ID.
    device : str | torch.device | None, optional
        Device for inference.
    scorer : _LLM2CLIPScorer | None, optional
        Optional scorer instance for testing (injected mock).
    call_type : str, optional
        Call type for the metric.
    **kwargs : Any
        Additional keyword arguments for :class:`StatefulMetric`.

    Notes
    -----
    Prompt benchmarks yield ``(prompts, aux_list)``. With the default ``call_type``
    ``y_gt``, ``aux_list`` is the list (or tensor coerced to a list) of per-sample
    dicts parallel to the generated images. Each dict must include a non-empty
    ``reasoning_gt_answer`` for Knowledge/Reasoning samples. Missing GT, scorer
    failures, or :meth:`compute` with no scored samples raise ``ValueError`` or
    ``RuntimeError`` instead of returning a placeholder score.
    """

    metric_name: str = "oneig_reasoning"
    default_call_type: str = "y_gt"
    higher_is_better: bool = True
    runs_on: list[str] = ["cuda", "cpu"]

    def __init__(
        self,
        processor_model: str = "openai/clip-vit-large-patch14-336",
        model_name: str = "microsoft/LLM2CLIP-Openai-L-14-336",
        llm_model_name: str = "microsoft/LLM2CLIP-Llama-3-8B-Instruct-CC-Finetuned",
        device: str | torch.device | None = None,
        scorer: _LLM2CLIPScorer | None = None,
        call_type: str | None = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(device=device, **kwargs)
        self.call_type = get_call_type_for_single_metric(
            call_type if call_type is not None else SINGLE, self.default_call_type
        )
        self.processor_model = processor_model
        self.model_name = model_name
        self.llm_model_name = llm_model_name
        self._scorer = scorer
        self.add_state("scores", default=[])

    def _get_scorer(self) -> _LLM2CLIPScorer:
        if self._scorer is not None:
            return self._scorer
        return _LLM2CLIPScorer(
            processor_model=self.processor_model,
            model_name=self.model_name,
            llm_model_name=self.llm_model_name,
            device=self.device,
        )

    def _get_gt_text(self, aux: dict) -> str:
        val = aux.get("reasoning_gt_answer")
        if val is None or (isinstance(val, str) and not val.strip()):
            raise ValueError(
                "oneig_reasoning requires 'reasoning_gt_answer' in aux for Knowledge_Reasoning rows. "
                f"Got keys: {list(aux.keys())}."
            )
        return str(val).strip()

    def update(self, x: list[Any] | torch.Tensor, gt: torch.Tensor, outputs: torch.Tensor) -> None:
        """
        Score each image against its GT answer text via LLM2CLIP similarity.

        Parameters
        ----------
        x : list[Any] | torch.Tensor
            Unused batch metadata.
        gt : torch.Tensor
            Ground-truth slot with per-sample aux dicts containing ``reasoning_gt_answer``.
        outputs : torch.Tensor
            Model outputs (generated images).

        Raises
        ------
        ValueError
            If a per-sample aux entry is not a dict or lacks a non-empty
            ``reasoning_gt_answer``.
        RuntimeError
            If the LLM2CLIP scorer returns no scores for a sample.
        """
        inputs = metric_data_processor(x, gt, outputs, self.call_type)
        images = _process_images(inputs[0])
        aux_list = inputs[1] if len(inputs) > 1 else []
        if isinstance(aux_list, torch.Tensor):
            aux_list = aux_list.tolist()

        scorer = self._get_scorer()

        for i, image in enumerate(images):
            aux = aux_list[i] if i < len(aux_list) else {}
            if not isinstance(aux, dict):
                raise ValueError(f"oneig_reasoning requires aux[{i}] to be a dict. Got: {type(aux)}.")
            text = self._get_gt_text(aux)
            result = scorer.score([image], text)
            if result is None or len(result) == 0:
                raise RuntimeError(f"oneig_reasoning: LLM2CLIP scorer returned no scores for sample {i}.")
            self.scores.append(float(sum(result) / len(result)))

    def compute(self) -> MetricResult:
        """
        Compute the mean reasoning score across all samples.

        Returns
        -------
        MetricResult
            Mean LLM2CLIP similarity.

        Raises
        ------
        RuntimeError
            If :meth:`update` was not called or scored no samples.
        """
        if not self.scores:
            raise RuntimeError(
                "oneig_reasoning: no samples were scored; call update() with valid "
                "batches and non-empty reasoning_gt_answer before compute()."
            )
        mean_score = sum(self.scores) / len(self.scores)
        return MetricResult(self.metric_name, self.__dict__, float(mean_score))
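
For context, a hedged end-to-end sketch of the metric as wired above (`aux_batch` and `generated_images` are hypothetical stand-ins; only `OneIGReasoningMetric`, `update`, and `compute` come from this diff):

    metric = OneIGReasoningMetric(device="cuda")
    # each aux dict must carry a non-empty "reasoning_gt_answer"
    metric.update(x=[], gt=aux_batch, outputs=generated_images)
    result = metric.compute()  # MetricResult with the mean LLM2CLIP similarity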
