Commit 04ab2e5

feat(text-metrics): split qa_accuracy into dedicated PR branch
Isolates the qa_accuracy metric implementation and GenEval benchmark wiring so it can be reviewed independently before stacking the remaining text metrics.

Made-with: Cursor
1 parent 7054e53 commit 04ab2e5

2 files changed: 205 additions & 1 deletion

File tree

src/pruna/evaluation/benchmarks.py

src/pruna/evaluation/benchmarks.py (1 addition, 1 deletion)
@@ -226,7 +226,7 @@ def list(cls, task_type: str | None = None) -> list[str]:
                 "counting, colors, position, color attributes. Evaluates fine-grained alignment "
                 "between prompts and generated images via VQA-style questions."
             ),
-            metrics=["clip_score"],  # §3.2: Mask2Former; not in Pruna
+            metrics=["qa_accuracy", "clip_score"],  # strict QA + CLIP score
             task_type="text_to_image",
             reference="https://arxiv.org/abs/2310.11513",
         ),
Lines changed: 204 additions & 0 deletions

@@ -0,0 +1,204 @@
# Copyright 2025 - Pruna AI GmbH. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""QA Accuracy metric using VLM for image understanding evaluation."""

from __future__ import annotations

from typing import Any, Literal

import numpy as np
import torch

from pruna.evaluation.metrics.registry import MetricRegistry
from pruna.evaluation.metrics.result import MetricResult
from pruna.evaluation.metrics.utils import (
    SINGLE,
    metric_data_processor,
)
from pruna.evaluation.metrics.vlm_base import BaseVLM, StatefulVLMMeanScoresMetric
from pruna.evaluation.metrics.vlm_utils import VQAnswer, _process_images


@MetricRegistry.register("qa_accuracy")
class QAAccuracyMetric(StatefulVLMMeanScoresMetric):
    """
    QA Accuracy metric.

    Uses a VLM to score yes/no alignment between each question and the generated image.
    Higher scores indicate better image understanding.

    **Multiple questions** come from each auxiliary dict's ``questions`` mapping (e.g. GenEval
    atomic probes, OneIG items). Each question is scored independently via :meth:`BaseVLM.score`
    with expected answer ``"Yes"``.

    **Aggregation** (``aggregation`` kwarg):

    - ``mean`` (default): per image, average VLM scores over all questions; the metric's
      :meth:`compute` returns the mean of those per-image values across ``update`` calls.
    - ``all_or_nothing``: per image, ``1.0`` only if **every** question scores strictly above
      ``0.5`` (scores equal to ``0.5`` count as failure). This matches strict GenEval-style
      reporting (all atomic checks must pass per sample; see `GenEval
      <https://arxiv.org/abs/2310.11513>`_). :class:`~pruna.evaluation.task.Task` wires this for
      the GenEval benchmark.

    Parameters
    ----------
    *args : Any
        Additional positional arguments.
    vlm : BaseVLM | None, optional
        Custom VLM instance. If provided, ``vlm_type`` and ``model_name`` are ignored.
    vlm_type : {"litellm", "transformers"}, optional
        VLM backend. Default is "litellm".
    model_name : str | None, optional
        Litellm model id or HuggingFace checkpoint id. **Required** when ``vlm`` is not
        provided (e.g. ``openai/gpt-4o``).
    vlm_kwargs : dict, optional
        Forwarded by ``get_vlm`` to ``LitellmVLM`` or ``TransformersVLM``. For local models,
        set ``model_load_kwargs`` for ``from_pretrained``; for litellm, pass extra API options.
    structured_output : bool, optional
        Use structured generation (litellm pydantic; transformers outlines when applicable).
        Default is True.
    device : str | torch.device | None, optional
        Device for transformers VLM.
    api_key : str | None, optional
        API key for litellm.
    call_type : str, optional
        Call type for the metric.
    **kwargs : Any
        Supports ``aggregation``: ``"mean"`` or ``"all_or_nothing"``.

    Raises
    ------
    ValueError
        If ``aggregation`` is not ``"mean"`` or ``"all_or_nothing"``.

    Examples
    --------
    Same ``hosted`` / ``local`` pattern as :func:`~pruna.evaluation.metrics.vlm_base.get_vlm`:

    .. code-block:: python

        import torch

        from pruna.evaluation.metrics import QAAccuracyMetric

        hosted = QAAccuracyMetric(vlm_type="litellm", model_name="openai/gpt-4o")
        local = QAAccuracyMetric(
            vlm_type="transformers",
            model_name="HuggingFaceTB/SmolVLM-256M-Instruct",
            device="cpu",
            vlm_kwargs={"model_load_kwargs": {"torch_dtype": torch.float32}},
        )
    """

    scores: list[float]
    default_call_type: str = "y_gt"
    higher_is_better: bool = True
    metric_units: str = "accuracy"
    metric_name: str = "qa_accuracy"

    def __init__(
        self,
        *args,
        vlm: BaseVLM | None = None,
        vlm_type: Literal["litellm", "transformers"] = "litellm",
        model_name: str | None = None,
        vlm_kwargs: dict | None = None,
        structured_output: bool = True,
        device: str | torch.device | None = None,
        api_key: str | None = None,
        call_type: str = SINGLE,
        **kwargs: Any,
    ) -> None:
        super().__init__(device=device)
        self.response_format = VQAnswer if structured_output else None
        self.aggregation = kwargs.pop("aggregation", "mean")
        if self.aggregation not in {"mean", "all_or_nothing"}:
            raise ValueError(
                f"qa_accuracy aggregation must be one of {{'mean', 'all_or_nothing'}}. Got: {self.aggregation!r}."
            )
        self.metric_units = type(self).metric_units
        self._init_vlm_scores(
            vlm=vlm,
            vlm_type=vlm_type,
            model_name=model_name,
            vlm_kwargs=vlm_kwargs,
            structured_output=structured_output,
            device=device,
            api_key=api_key,
            call_type=call_type,
        )

    def _extract_questions(self, gt: Any, n: int) -> list[list[str]]:
        if isinstance(gt, (list, tuple)) and len(gt) >= n:
            out = []
            for i in range(n):
                v = gt[i]
                if isinstance(v, dict) and "questions" in v:
                    qs = v["questions"]
                    out.append(list(qs.values()) if isinstance(qs, dict) else list(qs))
                else:
                    out.append([])
            return out
        return [[] for _ in range(n)]

    def update(self, x: list[Any] | torch.Tensor, gt: torch.Tensor, outputs: torch.Tensor) -> None:
        """
        Update the metric with new batch data.

        Parameters
        ----------
        x : list[Any] | torch.Tensor
            The input data.
        gt : torch.Tensor
            The ground truth (questions per image).
        outputs : torch.Tensor
            The output images.
        """
        inputs = metric_data_processor(x, gt, outputs, self.call_type)
        images = _process_images(inputs[0])
        auxiliaries = inputs[1] if len(inputs) > 1 else []
        questions_per_image = self._extract_questions(auxiliaries, len(images))
        for i, image in enumerate(images):
            questions = questions_per_image[i] if i < len(questions_per_image) else []
            if not questions:
                aux = auxiliaries[i] if i < len(auxiliaries) else {}
                raise ValueError(
                    "qa_accuracy requires 'questions' in auxiliaries. "
                    "Use a benchmark that provides it (e.g. GenEval, DPG, OneIG). "
                    f"Got aux keys: {list(aux.keys()) if isinstance(aux, dict) else 'not a dict'}."
                )
            scores = self.vlm.score(
                [image] * len(questions),
                questions,
                ["Yes"] * len(questions),
                response_format=self.response_format,
            )
            if self.aggregation == "all_or_nothing":
                score = 1.0 if all(s > 0.5 for s in scores) else 0.0
            else:
                score = float(np.mean(scores))
            self.scores.append(score)

    def compute(self) -> MetricResult:
        """
        Compute the QA accuracy score.

        Returns
        -------
        MetricResult
            The mean QA accuracy across all updates.
        """
        return self.compute_mean_of_scores()
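
For reviewers, a small standalone sketch of the two aggregation modes implemented in update() above. The per-question scores are hypothetical placeholders; only the arithmetic mirrors the committed code.

    import numpy as np

    # Hypothetical VLM scores for one image's GenEval-style questions,
    # each scored against the expected answer "Yes".
    question_scores = [0.9, 0.8, 0.4]

    # aggregation="mean" (default): average over the image's questions -> ~0.70.
    mean_score = float(np.mean(question_scores))

    # aggregation="all_or_nothing": 1.0 only if every score is strictly above 0.5;
    # here the third question fails, so the whole image scores 0.0.
    strict_score = 1.0 if all(s > 0.5 for s in question_scores) else 0.0

    print(mean_score, strict_score)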

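The strict mode can also be requested directly through the aggregation keyword documented under **kwargs; a sketch reusing the hosted-model id from the docstring example (any litellm-compatible model id would do):

    from pruna.evaluation.metrics import QAAccuracyMetric

    # Strict GenEval-style scoring: every question for an image must pass.
    strict_metric = QAAccuracyMetric(
        vlm_type="litellm",
        model_name="openai/gpt-4o",
        aggregation="all_or_nothing",
    )

Per the class docstring, Task wires this aggregation automatically for the GenEval benchmark, so the explicit keyword mainly matters for standalone use of the metric.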