
Commit 3d76a02

feat(vision-metrics): split vqa into dedicated branch

Introduces VQAMetric with GenAI Bench benchmark wiring and focused VQA unit coverage as the first PR in the vision metric stack.

Made-with: Cursor

Parent: e586366

3 files changed: 185 additions, 1 deletion

src/pruna/evaluation/benchmarks.py (1 addition, 1 deletion)
```diff
@@ -174,7 +174,7 @@ def list(cls, task_type: str | None = None) -> list[str]:
                 "Covers basic skills (scene, attributes, spatial relationships) to advanced reasoning "
                 "(counting, comparison, logic/negation) with over 24k human ratings."
             ),
-            metrics=[],  # Paper uses VQAScore only; not in Pruna
+            metrics=["vqa", "clip_score"],
             task_type="text_to_image",
             reference="https://arxiv.org/abs/2406.13743",
         ),
```
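
For context, the names wired in here are consumed by resolving them back to metric classes. A minimal sketch of that flow; only ``MetricRegistry.register("vqa")`` appears in this commit, so the by-name ``get`` lookup below is an assumed registry API, not confirmed by the diff:

```python
# Hypothetical sketch: MetricRegistry.get is an assumed lookup method;
# only the @MetricRegistry.register("vqa") decorator appears in this commit.
from pruna.evaluation.metrics.registry import MetricRegistry

for name in ["vqa", "clip_score"]:  # the names wired into the GenAI Bench entry
    metric_cls = MetricRegistry.get(name)  # assumed by-name lookup
    print(name, "->", metric_cls)
```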
src/pruna/evaluation/metrics/metric_vqa.py (new file, 158 additions)
```python
# Copyright 2025 - Pruna AI GmbH. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
VQA (Visual Question Answering) metric.

Reference: VQAScore - Evaluating Text-to-Visual Generation with Image-to-Text Generation
https://arxiv.org/abs/2404.01291

Note: VQAScore uses P(Yes) (the probability of a "Yes" answer) for ranking. With litellm,
``use_probability=True`` (the default) requests logprobs for soft scores when the provider
supports it; set ``use_probability=False`` for binary 0/1. With ``transformers``,
``use_probability=True`` uses next-token softmax mass on yes/no prefix tokens
(VQAScore-style); ``False`` uses generation plus binary matching.

For API keys, LiteLLM vs local ``transformers``, and hosted vs local construction, see
:doc:`Evaluate a model </docs_pruna/user_manual/evaluate>` (Vision-language judge metrics) and
:func:`~pruna.evaluation.metrics.vlm_base.get_vlm`.
"""

from __future__ import annotations

from typing import Any, Literal

import torch

from pruna.evaluation.metrics.registry import MetricRegistry
from pruna.evaluation.metrics.result import MetricResult
from pruna.evaluation.metrics.utils import SINGLE, metric_data_processor
from pruna.evaluation.metrics.vlm_base import BaseVLM, StatefulVLMMeanScoresMetric, prompts_from_y_x_inputs
from pruna.evaluation.metrics.vlm_utils import VQAnswer, _process_images


@MetricRegistry.register("vqa")
class VQAMetric(StatefulVLMMeanScoresMetric):
    """
    VQA (Visual Question Answering) metric.

    Uses a VLM to answer "Does this image show '{prompt}'?" and scores alignment.
    Higher scores indicate better image-text alignment.

    VQAScore (arXiv:2404.01291) uses P(Yes) for ranking. The default
    ``use_probability=True`` with litellm requests logprobs for soft scores when supported.

    Parameters
    ----------
    vlm : BaseVLM | None, optional
        Custom VLM instance. If provided, ``vlm_type`` and ``model_name`` are ignored.
    vlm_type : {"litellm", "transformers"}, optional
        VLM backend to use. Default is "litellm".
    model_name : str | None, optional
        Litellm model id or HuggingFace checkpoint id (e.g. ``openai/gpt-4o``).
        **Required** when ``vlm`` is not provided.
    vlm_kwargs : dict, optional
        Forwarded by ``get_vlm`` to ``LitellmVLM`` or ``TransformersVLM``. For local models,
        set ``model_load_kwargs`` for ``from_pretrained``; for litellm, pass extra API options.
    structured_output : bool, optional
        Use structured generation for stable outputs (litellm pydantic; transformers outlines
        when a string format is used). Default is True.
    device : str | torch.device | None, optional
        Device for the transformers VLM.
    api_key : str | None, optional
        API key for litellm.
    call_type : str, optional
        Call type for the metric.
    use_probability : bool, optional
        If True, use P(Yes) when the backend supports logprobs (litellm). Otherwise binary
        0/1. Default is True for paper alignment.
    **kwargs : Any
        Additional arguments.

    Notes
    -----
    For strict binary scoring without logprobs, pass ``use_probability=False``. Hosted vs
    local setup: :doc:`Evaluate a model </docs_pruna/user_manual/evaluate>` (Vision-language judge metrics).
    """

    scores: list[float]
    default_call_type: str = "y_x"
    higher_is_better: bool = True
    metric_name: str = "vqa"

    def __init__(
        self,
        vlm: BaseVLM | None = None,
        vlm_type: Literal["litellm", "transformers"] = "litellm",
        model_name: str | None = None,
        vlm_kwargs: dict | None = None,
        structured_output: bool = True,
        device: str | torch.device | None = None,
        api_key: str | None = None,
        call_type: str = SINGLE,
        use_probability: bool = True,
        **kwargs: Any,
    ) -> None:
        super().__init__(device=device)
        self.use_probability = use_probability
        self.response_format = VQAnswer if structured_output else None
        self._init_vlm_scores(
            vlm=vlm,
            vlm_type=vlm_type,
            model_name=model_name,
            vlm_kwargs=vlm_kwargs,
            structured_output=structured_output,
            device=device,
            api_key=api_key,
            call_type=call_type,
        )

    def update(self, x: list[Any] | torch.Tensor, gt: torch.Tensor, outputs: torch.Tensor) -> None:
        """
        Update the metric with new batch data.

        Parameters
        ----------
        x : list[Any] | torch.Tensor
            The input data (prompts).
        gt : torch.Tensor
            The ground truth (unused; present for call-type compatibility).
        outputs : torch.Tensor
            The output images.
        """
        inputs = metric_data_processor(x, gt, outputs, self.call_type)
        images = _process_images(inputs[0])
        prompts = prompts_from_y_x_inputs(inputs, len(images))
        for i, image in enumerate(images):
            prompt = prompts[i] if i < len(prompts) else ""
            question = f'Does this image show "{prompt}"?'
            score = self.vlm.score(
                [image],
                [question],
                ["Yes"],
                response_format=self.response_format,
                use_probability=self.use_probability,
            )[0]
            self.scores.append(score)

    def compute(self) -> MetricResult:
        """
        Compute the VQA score.

        Returns
        -------
        MetricResult
            The mean VQA score across all updates.
        """
        return self.compute_mean_of_scores()
```
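
For context, here is a minimal usage sketch of the new metric. The model id reuses the ``openai/gpt-4o`` example from the docstring; the prompt and the random image tensor are placeholders, and a real run needs a valid API key for the hosted backend:

```python
import torch

from pruna.evaluation.metrics.metric_vqa import VQAMetric

# Hosted judge via litellm; "openai/gpt-4o" is the docstring's example id.
metric = VQAMetric(
    vlm_type="litellm",
    model_name="openai/gpt-4o",
    use_probability=True,  # soft P(Yes) scores when the provider returns logprobs
)

# Placeholder batch: one prompt and one generated image as an (N, C, H, W) tensor.
prompts = ["a red bicycle leaning against a wall"]
images = torch.rand(1, 3, 512, 512)

# gt is unused for the default "y_x" call type, so the images are passed twice,
# mirroring the unit test below.
metric.update(prompts, images, images)
result = metric.compute()
print(result.name, result.result)  # "vqa" and the mean P(Yes) across updates
```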
New test file (26 additions)
"""Vision metric tests split by dedicated metric PR branches."""
2+
3+
from unittest.mock import MagicMock
4+
5+
import pytest
6+
import torch
7+
8+
from pruna.evaluation.metrics.metric_vqa import VQAMetric
9+
from pruna.evaluation.metrics.vlm_base import BaseVLM
10+
11+
12+
@pytest.mark.cpu
13+
def test_vqa_uses_prompt_question_and_scores_yes_probability() -> None:
14+
"""VQA asks prompt-grounded yes/no question and stores returned score."""
15+
mock_vlm = MagicMock(spec=BaseVLM)
16+
mock_vlm.score.return_value = [0.7]
17+
18+
metric = VQAMetric(vlm=mock_vlm, vlm_type="litellm", device="cpu", use_probability=True)
19+
images = torch.rand(1, 3, 64, 64)
20+
metric.update(["a cat"], images, images)
21+
22+
result = metric.compute()
23+
assert result.name == "vqa"
24+
assert result.result == 0.7
25+
call = mock_vlm.score.call_args
26+
assert call[0][1] == ['Does this image show "a cat"?']
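
A companion case for the binary path could follow the same mocking pattern. This is a hypothetical sketch, not part of the commit: the test name is invented, and the mocked ``score`` is assumed to return a hard 1.0 when ``use_probability=False``:

```python
@pytest.mark.cpu
def test_vqa_binary_scoring_without_logprobs() -> None:
    """Hypothetical sketch: binary 0/1 scoring with use_probability=False."""
    mock_vlm = MagicMock(spec=BaseVLM)
    mock_vlm.score.return_value = [1.0]  # assumed hard match on "Yes"

    metric = VQAMetric(vlm=mock_vlm, device="cpu", use_probability=False)
    images = torch.rand(1, 3, 64, 64)
    metric.update(["a dog"], images, images)

    # The flag should be forwarded to the backend unchanged.
    assert mock_vlm.score.call_args.kwargs["use_probability"] is False
    assert metric.compute().result == 1.0
```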
