# Copyright 2025 - Pruna AI GmbH. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""OneIG reasoning score via LLM2CLIP text-image similarity.

Llama-derived checkpoints may require ``HF_TOKEN`` and ``huggingface-cli login``.

Hugging Face download tuning (optional):

- ``PRUNA_ONEIG_HF_VERBOSE=1`` or ``HF_DEBUG=1`` — hub **debug** logging and tqdm
  progress bars (helps when stderr is piped; pair with ``python -u`` or
  ``PYTHONUNBUFFERED=1`` for line-buffered output).
- ``PRUNA_ONEIG_HF_FAST_DOWNLOAD=1`` — enable **hf_transfer** multi-part downloads
  (requires ``pruna[evaluation]``, which lists ``hf_transfer``). Alternatively, set
  ``HF_HUB_ENABLE_HF_TRANSFER=1`` **before** starting Python so the hub picks it up at
  import time.

``transformers`` is pinned to ``<5`` in ``pyproject.toml``. The LLM2CLIP loading path
(``CLIPImageProcessor``, ``AutoModel``, ``LlamaEncoderModel``) is exercised on **4.x**
releases in CI and manual smoke runs. ``transformers`` 5.x has had reports of
``from_pretrained`` not fully initializing some non-persistent buffers (for example
``position_ids``) for certain architectures; the pin avoids that class of failures
until those issues are clearly resolved upstream.
"""
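
# Opt-in download tuning, illustrated (a sketch, not required usage): set the variables
# below in the shell, or in-process before the first LLM2CLIP checkpoint download is
# triggered, e.g.
#
#   import os
#   os.environ["PRUNA_ONEIG_HF_VERBOSE"] = "1"        # hub debug logs + progress bars
#   os.environ["PRUNA_ONEIG_HF_FAST_DOWNLOAD"] = "1"  # hf_transfer multi-part downloads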

from __future__ import annotations

import os
from typing import Any

import torch

from pruna.evaluation.metrics.metric_stateful import StatefulMetric
from pruna.evaluation.metrics.registry import MetricRegistry
from pruna.evaluation.metrics.result import MetricResult
from pruna.evaluation.metrics.utils import (
    SINGLE,
    get_call_type_for_single_metric,
    metric_data_processor,
)
from pruna.evaluation.metrics.vlm_utils import _process_images
from pruna.logging.logger import pruna_logger


def _env_truthy(raw: str | None) -> bool:
    if raw is None:
        return False
    return raw.strip().upper() in {"1", "ON", "YES", "TRUE"}
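# For example: _env_truthy("1"), _env_truthy("YES"), and _env_truthy(" true ") are True,
# while _env_truthy("0"), _env_truthy("off"), and _env_truthy(None) are False.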


def _prepare_huggingface_hub_for_oneig_downloads() -> None:
    """
    Apply Hugging Face Hub verbosity and optional fast downloads before checkpoints load.

    ``HF_HUB_ENABLE_HF_TRANSFER`` is read when ``huggingface_hub`` loads; if it was
    false, we flip the in-module flag after importing ``hf_transfer`` when
    ``PRUNA_ONEIG_HF_FAST_DOWNLOAD=1``.
    """
    if _env_truthy(os.environ.get("PRUNA_ONEIG_HF_VERBOSE")) or _env_truthy(os.environ.get("HF_DEBUG")):
        from huggingface_hub.utils import enable_progress_bars
        from huggingface_hub.utils.logging import set_verbosity_debug

        set_verbosity_debug()
        enable_progress_bars()

    if not _env_truthy(os.environ.get("PRUNA_ONEIG_HF_FAST_DOWNLOAD")):
        return

    import hf_transfer  # noqa: F401  # type: ignore[import-not-found]
    import huggingface_hub.constants as hf_constants

    hf_constants.HF_HUB_ENABLE_HF_TRANSFER = True
    pruna_logger.info("oneig_reasoning: enabled hf_transfer downloads (PRUNA_ONEIG_HF_FAST_DOWNLOAD=1).")


def _to_pil_list(images: list) -> list:
    """Convert images to list of PIL.Image (RGB)."""
    import numpy as np
    from PIL import Image

    out: list = []
    for img in images:
        if isinstance(img, Image.Image):
            out.append(img.convert("RGB"))
        elif isinstance(img, torch.Tensor):
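            # Assumed tensor layout (not enforced here): (C, H, W) or (1, C, H, W) with
            # float values in [0, 1] or [0, 255]; rescaled to uint8 and transposed to
            # HWC before handing off to PIL.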
            if img.ndim == 4:
                img = img[0]
            if img.max() > 1:
                img = img / 255.0
            np_img = (img.cpu().numpy() * 255).astype("uint8")
            if np_img.shape[0] == 3:
                np_img = np_img.transpose(1, 2, 0)
            out.append(Image.fromarray(np_img))
        elif hasattr(img, "__array__"):
            out.append(Image.fromarray(np.asarray(img)).convert("RGB"))
        else:
            out.append(img)
    return out


class _LLM2CLIPScorer:
    """
    Thin wrapper around LLM2CLIP text-image similarity.

    Accepts PIL images and a single answer string; returns per-image scores.
    Best-effort alignment with OneIG-Benchmark scripts (CUDA + bfloat16).
    """

    def __init__(
        self,
        processor_model: str = "openai/clip-vit-large-patch14-336",
        model_name: str = "microsoft/LLM2CLIP-Openai-L-14-336",
        llm_model_name: str = "microsoft/LLM2CLIP-Llama-3-8B-Instruct-CC-Finetuned",
        device: str = "cuda",
    ) -> None:
        self.processor_model = processor_model
        self.model_name = model_name
        self.llm_model_name = llm_model_name
        self.device = device
        self._processor = None
        self._clip_model = None
        self._l2v = None

    def _load_models(self) -> None:
        if self._clip_model is not None:
            return
        _prepare_huggingface_hub_for_oneig_downloads()
        from transformers import AutoConfig, AutoModel, AutoTokenizer, CLIPImageProcessor

        from pruna.evaluation.metrics.vendor.oneig_llm2vec import LLM2Vec
        from pruna.evaluation.metrics.vendor.oneig_llm2vec.modeling_llama_encoder import LlamaEncoderModel

        pruna_logger.info(
            "oneig_reasoning: downloading or loading LLM2CLIP checkpoints "
            "(%s, %s). First run can take many minutes and several gigabytes; "
            "Hugging Face download progress may look idle when logs are piped.",
            self.model_name,
            self.llm_model_name,
        )
        dtype = torch.bfloat16 if self.device == "cuda" else torch.float32
        self._processor = CLIPImageProcessor.from_pretrained(self.processor_model)
        self._clip_model = AutoModel.from_pretrained(
            self.model_name,
            dtype=dtype,
            trust_remote_code=True,
        ).to(self.device)
        self._clip_model.eval()

        config = AutoConfig.from_pretrained(self.llm_model_name, trust_remote_code=True)
        dev_str = str(self.device)
        attn_impl = "sdpa" if dev_str == "cuda" or dev_str.startswith("cuda:") else "eager"
        config.attn_implementation = attn_impl
        if hasattr(config, "_attn_implementation"):
            config._attn_implementation = attn_impl
        llm_model = LlamaEncoderModel.from_pretrained(
            self.llm_model_name,
            dtype=dtype,
            config=config,
            trust_remote_code=True,
        )
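        # The override below appears in the upstream LLM2CLIP usage example; our reading
        # is that pointing the config back at the base Llama 3 Instruct checkpoint lets
        # LLM2Vec resolve its prompt/pooling setup from the base model name.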
        llm_model.config._name_or_path = "meta-llama/Meta-Llama-3-8B-Instruct"
        tokenizer = AutoTokenizer.from_pretrained(self.llm_model_name)
        self._l2v = LLM2Vec(llm_model, tokenizer, pooling_mode="mean", max_length=512, doc_max_length=512)

    def score(self, images: list, text_prompt: str) -> list[float] | None:
        """
        Compute similarity scores between images and text.

        Parameters
        ----------
        images : list
            List of PIL.Image.Image.
        text_prompt : str
            Reference text (e.g. ground-truth answer).

        Returns
        -------
        list[float] | None
            Per-image scores, or ``None`` if no images could be converted.
        """
        self._load_models()
        pil_images = _to_pil_list(images)
        if not pil_images:
            return None
        input_pixels = self._processor(images=pil_images, return_tensors="pt").pixel_values.to(self.device)
        captions = [text_prompt]
        text_features = self._l2v.encode(captions, convert_to_tensor=True, device=self.device).to(self.device)
        text_features = self._clip_model.get_text_features(text_features)

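        # Image encoding below follows the scorer's stated OneIG/LLM2CLIP alignment:
        # no-grad inference, mixed-precision autocast on CUDA, full float32 on CPU.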
        with torch.no_grad():
            if self.device == "cuda":
                with torch.amp.autocast(device_type="cuda"):
                    image_features = self._clip_model.get_image_features(input_pixels)
            else:
                image_features = self._clip_model.get_image_features(input_pixels.float())

            image_features = image_features.float()
            text_features = text_features.float()
            image_features /= image_features.norm(dim=-1, keepdim=True)
            text_features /= text_features.norm(dim=-1, keepdim=True)

            text_probs = (image_features @ text_features.T).cpu().tolist()
        return [p[0] for p in text_probs]

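# Minimal standalone sketch of _LLM2CLIPScorer (illustrative only; the image path is a
# placeholder, and downloaded checkpoints plus a CUDA device are assumed):
#
#   from PIL import Image
#   scorer = _LLM2CLIPScorer(device="cuda")
#   per_image = scorer.score([Image.open("generated.png")], "The correct answer is 42")
#
# `per_image` then holds one cosine-similarity value per input image.
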
@MetricRegistry.register("oneig_reasoning")
class OneIGReasoningMetric(StatefulMetric):
    """
    OneIG reasoning score: LLM2CLIP similarity between GT answer text and generated image.

    Uses ``reasoning_gt_answer`` from aux (populated by OneIG Knowledge_Reasoning loader;
    language is chosen at dataset load via ``reasoning_language``). MVP: 1×1 grid (whole
    image as single cell). Llama-derived checkpoints may require
    ``HF_TOKEN`` and ``huggingface-cli login``.

    Parameters
    ----------
    processor_model : str, optional
        CLIP processor model ID.
    model_name : str, optional
        LLM2CLIP model ID.
    llm_model_name : str, optional
        LLM2Vec model ID.
    device : str | torch.device | None, optional
        Device for inference.
    scorer : _LLM2CLIPScorer | None, optional
        Optional scorer instance for testing (injected mock).
    call_type : str, optional
        Call type for the metric.
    **kwargs : Any
        Additional keyword arguments for :class:`StatefulMetric`.

    Notes
    -----
    Prompt benchmarks yield ``(prompts, aux_list)``. With default ``call_type``
    ``y_gt``, ``aux_list`` is the list (or tensor coerced to a list) of per-sample
    dicts parallel to generated images. Each dict must include a non-empty
    ``reasoning_gt_answer`` for Knowledge/Reasoning samples. Missing GT, scorer
    failures, or :meth:`compute` with no scored samples raise ``ValueError`` or
    ``RuntimeError`` instead of returning a placeholder score.
    """

    metric_name: str = "oneig_reasoning"
    default_call_type: str = "y_gt"
    higher_is_better: bool = True
    runs_on: list[str] = ["cuda", "cpu"]

    def __init__(
        self,
        processor_model: str = "openai/clip-vit-large-patch14-336",
        model_name: str = "microsoft/LLM2CLIP-Openai-L-14-336",
        llm_model_name: str = "microsoft/LLM2CLIP-Llama-3-8B-Instruct-CC-Finetuned",
        device: str | torch.device | None = None,
        scorer: _LLM2CLIPScorer | None = None,
        call_type: str | None = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(device=device, **kwargs)
        self.call_type = get_call_type_for_single_metric(
            call_type if call_type is not None else SINGLE, self.default_call_type
        )
        self.processor_model = processor_model
        self.model_name = model_name
        self.llm_model_name = llm_model_name
        self._scorer = scorer
        self.add_state("scores", default=[])

    def _get_scorer(self) -> _LLM2CLIPScorer:
        # Cache the lazily built scorer so repeated update() calls reuse the loaded
        # checkpoints instead of re-instantiating (and re-loading) them every batch.
        if self._scorer is None:
            self._scorer = _LLM2CLIPScorer(
                processor_model=self.processor_model,
                model_name=self.model_name,
                llm_model_name=self.llm_model_name,
                device=self.device,
            )
        return self._scorer

    def _get_gt_text(self, aux: dict) -> str:
        val = aux.get("reasoning_gt_answer")
        if val is None or (isinstance(val, str) and not val.strip()):
            raise ValueError(
                "oneig_reasoning requires 'reasoning_gt_answer' in aux for Knowledge_Reasoning rows. "
                f"Got keys: {list(aux.keys())}."
            )
        return str(val).strip()

    def update(self, x: list[Any] | torch.Tensor, gt: torch.Tensor, outputs: torch.Tensor) -> None:
        """
        Score each image against its GT answer text via LLM2CLIP similarity.

        Parameters
        ----------
        x : list[Any] | torch.Tensor
            Unused batch metadata.
        gt : torch.Tensor
            Ground-truth slot with per-sample aux dicts containing ``reasoning_gt_answer``.
        outputs : torch.Tensor
            Model outputs (generated images).

        Raises
        ------
        ValueError
            If a per-sample aux entry is not a dict or lacks a non-empty
            ``reasoning_gt_answer``.
        RuntimeError
            If the LLM2CLIP scorer returns no scores for a sample.
        """
        inputs = metric_data_processor(x, gt, outputs, self.call_type)
        images = _process_images(inputs[0])
        aux_list = inputs[1] if len(inputs) > 1 else []
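        # Some pipelines hand the aux slot over as a tensor (see the class Notes);
        # coerce it to a plain list so per-sample dicts index cleanly.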
        if isinstance(aux_list, torch.Tensor):
            aux_list = aux_list.tolist()

        scorer = self._get_scorer()

        for i, image in enumerate(images):
            aux = aux_list[i] if i < len(aux_list) else {}
            if not isinstance(aux, dict):
                raise ValueError(f"oneig_reasoning requires aux[{i}] to be a dict. Got: {type(aux)}.")
            text = self._get_gt_text(aux)
            result = scorer.score([image], text)
            if result is None or len(result) == 0:
                raise RuntimeError(f"oneig_reasoning: LLM2CLIP scorer returned no scores for sample {i}.")
            self.scores.append(float(sum(result) / len(result)))

    def compute(self) -> MetricResult:
        """
        Compute the mean reasoning score across all samples.

        Returns
        -------
        MetricResult
            Mean LLM2CLIP similarity.

        Raises
        ------
        RuntimeError
            If :meth:`update` was not called or scored no samples.
        """
        if not self.scores:
            raise RuntimeError(
                "oneig_reasoning: no samples were scored; call update() with valid "
                "batches and non-empty reasoning_gt_answer before compute()."
            )
        mean_score = sum(self.scores) / len(self.scores)
        return MetricResult(self.metric_name, self.__dict__, float(mean_score))
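

# Illustrative end-to-end sketch (assumptions: `prompts`, `aux`, and `images` stand in
# for whatever the evaluation loop actually passes for this call_type; only the call
# pattern is shown):
#
#   metric = OneIGReasoningMetric(device="cuda")
#   metric.update(prompts, aux, images)  # each aux[i] needs a non-empty "reasoning_gt_answer"
#   result = metric.compute()            # mean LLM2CLIP similarity over all scored samples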