# Copyright 2025 - Pruna AI GmbH. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

| 15 | +"""QA Accuracy metric using VLM for image understanding evaluation.""" |

from __future__ import annotations

from typing import Any, Literal

import numpy as np
import torch

from pruna.evaluation.metrics.registry import MetricRegistry
from pruna.evaluation.metrics.result import MetricResult
from pruna.evaluation.metrics.utils import (
    SINGLE,
    metric_data_processor,
)
from pruna.evaluation.metrics.vlm_base import BaseVLM, StatefulVLMMeanScoresMetric
from pruna.evaluation.metrics.vlm_utils import VQAnswer, _process_images


@MetricRegistry.register("qa_accuracy")
class QAAccuracyMetric(StatefulVLMMeanScoresMetric):
    """
    QA Accuracy metric.

    Uses a VLM to score yes/no alignment between each question and the generated image.
    Higher scores indicate better image understanding.

    **Multiple questions** come from each auxiliary dict's ``questions`` mapping (e.g. GenEval
    atomic probes, OneIG items). Each question is scored independently via :meth:`BaseVLM.score`
    with expected answer ``"Yes"``.

    **Aggregation** (``aggregation`` kwarg):

    - ``mean`` (default): per image, average VLM scores over all questions; the metric's
      :meth:`compute` returns the mean of those per-image values across ``update`` calls.
    - ``all_or_nothing``: per image, ``1.0`` only if **every** question scores strictly above
      ``0.5`` (scores equal to ``0.5`` count as failure). This matches strict GenEval-style
      reporting (all atomic checks must pass per sample; see `GenEval
      <https://arxiv.org/abs/2310.11513>`_). :class:`~pruna.evaluation.task.Task` wires this for
      the GenEval benchmark.

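    For example, per-question scores of ``[1.0, 0.0, 1.0]`` yield roughly ``0.67`` under
    ``mean`` but ``0.0`` under ``all_or_nothing``.
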
    Parameters
    ----------
    *args : Any
        Additional positional arguments.
    vlm : BaseVLM | None, optional
        Custom VLM instance. If provided, ``vlm_type`` and ``model_name`` are ignored.
    vlm_type : {"litellm", "transformers"}, optional
        VLM backend. Default is "litellm".
    model_name : str | None, optional
        Litellm model id or HuggingFace checkpoint id. **Required** when ``vlm`` is not
        provided (e.g. ``openai/gpt-4o``).
    vlm_kwargs : dict, optional
        Forwarded by ``get_vlm`` to ``LitellmVLM`` or ``TransformersVLM``. For local models,
        set ``model_load_kwargs`` for ``from_pretrained``; for litellm, pass extra API options.
    structured_output : bool, optional
        Use structured generation (litellm pydantic; transformers outlines when applicable).
        Default is True.
    device : str | torch.device | None, optional
        Device for transformers VLM.
    api_key : str | None, optional
        API key for litellm.
    call_type : str, optional
        Call type for the metric.
    **kwargs : Any
        Supports ``aggregation``: ``"mean"`` or ``"all_or_nothing"``.

    Raises
    ------
    ValueError
        If ``aggregation`` is not ``"mean"`` or ``"all_or_nothing"``.

    Examples
    --------
    Same ``hosted`` / ``local`` pattern as :func:`~pruna.evaluation.metrics.vlm_base.get_vlm`:

    .. code-block:: python

        import torch

        from pruna.evaluation.metrics import QAAccuracyMetric

        hosted = QAAccuracyMetric(vlm_type="litellm", model_name="openai/gpt-4o")
        local = QAAccuracyMetric(
            vlm_type="transformers",
            model_name="HuggingFaceTB/SmolVLM-256M-Instruct",
            device="cpu",
            vlm_kwargs={"model_load_kwargs": {"torch_dtype": torch.float32}},
        )
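
    To reproduce strict GenEval-style reporting, pass the ``aggregation`` keyword (the model
    name here is only illustrative):

    .. code-block:: python

        geneval_style = QAAccuracyMetric(
            vlm_type="litellm",
            model_name="openai/gpt-4o",
            aggregation="all_or_nothing",
        )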
    """

    scores: list[float]
    default_call_type: str = "y_gt"
    higher_is_better: bool = True
    metric_units: str = "accuracy"
    metric_name: str = "qa_accuracy"

    def __init__(
        self,
        *args,
        vlm: BaseVLM | None = None,
        vlm_type: Literal["litellm", "transformers"] = "litellm",
        model_name: str | None = None,
        vlm_kwargs: dict | None = None,
        structured_output: bool = True,
        device: str | torch.device | None = None,
        api_key: str | None = None,
        call_type: str = SINGLE,
        **kwargs: Any,
    ) -> None:
        super().__init__(device=device)
        self.response_format = VQAnswer if structured_output else None
        self.aggregation = kwargs.pop("aggregation", "mean")
        if self.aggregation not in {"mean", "all_or_nothing"}:
            raise ValueError(
                f"qa_accuracy aggregation must be one of {{'mean', 'all_or_nothing'}}. Got: {self.aggregation!r}."
            )
        self.metric_units = type(self).metric_units
        self._init_vlm_scores(
            vlm=vlm,
            vlm_type=vlm_type,
            model_name=model_name,
            vlm_kwargs=vlm_kwargs,
            structured_output=structured_output,
            device=device,
            api_key=api_key,
            call_type=call_type,
        )

    def _extract_questions(self, gt: Any, n: int) -> list[list[str]]:
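        # The ground truth is expected to be a sequence of auxiliary dicts whose "questions"
        # entry is either a mapping (question id -> question text) or a plain list of
        # question strings. Entries without questions map to an empty list here;
        # `update` raises for those.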
        if isinstance(gt, (list, tuple)) and len(gt) >= n:
            out = []
            for i in range(n):
                v = gt[i]
                if isinstance(v, dict) and "questions" in v:
                    qs = v["questions"]
                    out.append(list(qs.values()) if isinstance(qs, dict) else list(qs))
                else:
                    out.append([])
            return out
        return [[] for _ in range(n)]

    def update(self, x: list[Any] | torch.Tensor, gt: torch.Tensor, outputs: torch.Tensor) -> None:
        """
        Update the metric with new batch data.

        Parameters
        ----------
        x : list[Any] | torch.Tensor
            The input data.
        gt : torch.Tensor
            The ground truth (questions per image).
        outputs : torch.Tensor
            The output images.
        """
        inputs = metric_data_processor(x, gt, outputs, self.call_type)
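        # With the default "y_gt" call type, inputs[0] holds the generated images to judge
        # and inputs[1] the ground-truth auxiliaries carrying the questions.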
        images = _process_images(inputs[0])
        auxiliaries = inputs[1] if len(inputs) > 1 else []
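        # e.g. auxiliaries[i] = {"questions": {"q0": "Is there a red car?", ...}}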
        questions_per_image = self._extract_questions(auxiliaries, len(images))
        for i, image in enumerate(images):
            questions = questions_per_image[i] if i < len(questions_per_image) else []
            if not questions:
                aux = auxiliaries[i] if i < len(auxiliaries) else {}
                raise ValueError(
                    "qa_accuracy requires 'questions' in auxiliaries. "
                    "Use a benchmark that provides it (e.g. GenEval, DPG, OneIG). "
                    f"Got aux keys: {list(aux.keys()) if isinstance(aux, dict) else 'not a dict'}."
                )
            scores = self.vlm.score(
                [image] * len(questions),
                questions,
                ["Yes"] * len(questions),
                response_format=self.response_format,
            )
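            # Collapse the per-question scores into a single per-image score.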
            if self.aggregation == "all_or_nothing":
                score = 1.0 if all(s > 0.5 for s in scores) else 0.0
            else:
                score = float(np.mean(scores))
            self.scores.append(score)

    def compute(self) -> MetricResult:
        """
        Compute the QA accuracy score.

        Returns
        -------
        MetricResult
            The mean QA accuracy across all updates.
        """
        return self.compute_mean_of_scores()