Skip to content

Commit c51653e

Browse files
feat(text-metrics): split oneig_alignment into dedicated branch
Adds oneig_alignment metric implementation, its focused tests, and benchmark subset wiring while keeping reasoning and text-rendering metrics for later stacked PRs. Made-with: Cursor
1 parent 04ab2e5 commit c51653e

3 files changed

Lines changed: 394 additions & 7 deletions

File tree

src/pruna/evaluation/benchmarks.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -272,13 +272,30 @@ def list(cls, task_type: str | None = None) -> list[str]:
272272
reference="https://arxiv.org/abs/2504.17761",
273273
),
274274
Benchmark(
275-
name="OneIG",
276-
description=(
277-
"Omni-dimensional benchmark for text-to-image evaluation. Six dataset categories "
278-
"(Anime_Stylization, General_Object, Knowledge_Reasoning, Multilingualism, Portrait, "
279-
"Text_Rendering) plus fine-grained style classes. Includes alignment questions."
280-
),
281-
metrics=[], # Paper uses dimension-specific metrics; not in Pruna
275+
name="OneIG Anime Stylization",
276+
description="OneIG subset: anime and stylized imagery.",
277+
metrics=["oneig_alignment"],
278+
task_type="text_to_image",
279+
reference="https://arxiv.org/abs/2506.07977",
280+
),
281+
Benchmark(
282+
name="OneIG General Object",
283+
description="OneIG subset: everyday objects and scenes.",
284+
metrics=["oneig_alignment"],
285+
task_type="text_to_image",
286+
reference="https://arxiv.org/abs/2506.07977",
287+
),
288+
Benchmark(
289+
name="OneIG Multilingualism",
290+
description="OneIG subset: multilingual prompts (incl. Chinese splits).",
291+
metrics=["oneig_alignment"],
292+
task_type="text_to_image",
293+
reference="https://arxiv.org/abs/2506.07977",
294+
),
295+
Benchmark(
296+
name="OneIG Portrait",
297+
description="OneIG subset: people and portraits.",
298+
metrics=["oneig_alignment"],
282299
task_type="text_to_image",
283300
reference="https://arxiv.org/abs/2506.07977",
284301
),
Lines changed: 234 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,234 @@
1+
# Copyright 2025 - Pruna AI GmbH. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""OneIG alignment scoring with dependency masking (parent ``No`` gates children)."""
16+
17+
from __future__ import annotations
18+
19+
from typing import Any, Mapping
20+
21+
import torch
22+
23+
from pruna.evaluation.metrics.metric_qa_accuracy import QAAccuracyMetric
24+
from pruna.evaluation.metrics.registry import MetricRegistry
25+
from pruna.evaluation.metrics.utils import metric_data_processor
26+
from pruna.evaluation.metrics.vlm_utils import _process_images
27+
28+
29+
def _int_dict_keys(mapping: Mapping[Any, Any]) -> dict[int, Any]:
30+
return {int(k): v for k, v in mapping.items()}
31+
32+
33+
def _normalize_dependencies(deps: Any) -> dict[int, list[int]]:
34+
if not isinstance(deps, Mapping):
35+
return {}
36+
out: dict[int, list[int]] = {}
37+
for k, v in deps.items():
38+
key = int(k)
39+
if isinstance(v, list):
40+
out[key] = [int(p) for p in v]
41+
else:
42+
out[key] = []
43+
return out
44+
45+
46+
def _active_oneig_question_ids(qmap: dict[int, Any]) -> list[int]:
47+
"""Question ids with real prompt text (excludes HF ``datasets`` padding and empty slots)."""
48+
active: list[int] = []
49+
for qi in sorted(qmap):
50+
text = qmap[qi]
51+
if text is None:
52+
continue
53+
s = str(text).strip()
54+
if not s or s == "None":
55+
continue
56+
active.append(qi)
57+
return active
58+
59+
60+
def apply_oneig_dependency_mask(
    raw_scores: Mapping[int, float],
    dependencies: Mapping[int, list[int]],
) -> dict[int, float]:
    """
    Zero out dependent question scores per the OneIG ``filter_score`` rule (single grid cell).

    A child question is forced to ``0`` when any of its non-root parents
    received a semantic ``No`` (raw score ``0``). Parent id ``0`` marks a
    root and is skipped, matching the reference script; parents absent from
    ``raw_scores`` are likewise ignored. Parents are judged on their
    *pre-mask* values, so masking never cascades through the graph.

    Parameters
    ----------
    raw_scores : Mapping[int, float]
        Map question id → VLM score in ``{0, 1}`` (or float) before masking.
    dependencies : Mapping[int, list[int]]
        Map child question id → list of parent question ids (use ``[0]`` for roots).

    Returns
    -------
    dict[int, float]
        New mapping with gated children zeroed.
    """
    masked = {int(qid): float(score) for qid, score in raw_scores.items()}
    pre_mask = dict(masked)  # snapshot: parents are judged on unmasked values

    # Normalize the dependency mapping inline: int keys, list values only.
    graph: dict[int, list[int]] = {}
    if isinstance(dependencies, Mapping):
        for key, parents in dependencies.items():
            graph[int(key)] = [int(p) for p in parents] if isinstance(parents, list) else []

    for child, parents in graph.items():
        if child not in masked:
            continue
        gated = any(pre_mask[p] == 0.0 for p in parents if p != 0 and p in pre_mask)
        if gated:
            masked[child] = 0.0
    return masked
101+
102+
103+
def aggregate_oneig_alignment_per_cell(filtered_scores: Mapping[int, float], question_ids: list[int]) -> float:
    """
    Average the post-mask scores over every question in one grid cell.

    Parameters
    ----------
    filtered_scores : Mapping[int, float]
        Post-mask score for each question id.
    question_ids : list[int]
        Ordered ids (typically sorted ascending); their count is the denominator.

    Returns
    -------
    float
        Mean score — in ``[0, 1]`` for binary inputs; ``0.0`` when
        ``question_ids`` is empty.
    """
    if not question_ids:
        return 0.0
    total = 0.0
    for qid in question_ids:
        total += float(filtered_scores[qid])
    return total / len(question_ids)
123+
124+
125+
@MetricRegistry.register("oneig_alignment")
class OneIGAlignmentMetric(QAAccuracyMetric):
    """
    Dependency-aware OneIG alignment metric.

    Builds on :class:`QAAccuracyMetric` Yes/No VLM scoring, then aggregates
    the way ``OneIG-Benchmark`` ``alignment_score.py`` does for a **single**
    grid cell (no ``split_mxn_grid``): question ids are ordered numerically,
    raw scores are zeroed whenever a non-root parent question was answered
    ``No``, and the mean over the prompt's questions is recorded per image.
    Questions whose text is null or blank (HF ``datasets`` schema padding)
    are excluded from scoring.

    Exact parity with the upstream numbers also depends on the chosen VLM
    (e.g. ``openai/gpt-4o`` via litellm vs the reference Qwen2.5-VL).

    Parameters
    ----------
    *args : Any
        Extra positional arguments forwarded to :class:`QAAccuracyMetric`.
    vlm : BaseVLM | None, optional
        Pre-built VLM instance. When given, ``vlm_type`` and ``model_name`` are ignored.
    vlm_type : {"litellm", "transformers"}, optional
        Backend used to build the VLM. Default is ``"litellm"``.
    model_name : str | None, optional
        Litellm model id or HuggingFace checkpoint id. **Required** unless
        ``vlm`` is supplied (e.g. ``openai/gpt-4o``).
    vlm_kwargs : dict, optional
        Passed through ``get_vlm`` to ``LitellmVLM`` or ``TransformersVLM``. For local
        models, set ``model_load_kwargs`` for ``from_pretrained``; for litellm, pass
        extra API options.
    structured_output : bool, optional
        Use structured generation (litellm pydantic; transformers outlines when
        applicable). Default is True.
    device : str | torch.device | None, optional
        Device for the transformers VLM.
    api_key : str | None, optional
        API key for litellm.
    call_type : str, optional
        Call type for the metric.
    **kwargs : Any
        Extra keyword arguments forwarded to :class:`QAAccuracyMetric`.

    Examples
    --------
    Mirrors the ``hosted`` / ``local`` usage of ``QAAccuracyMetric`` and
    :func:`~pruna.evaluation.metrics.vlm_base.get_vlm`:

    .. code-block:: python

        import torch

        from pruna.evaluation.metrics import OneIGAlignmentMetric

        hosted = OneIGAlignmentMetric(vlm_type="litellm", model_name="openai/gpt-4o")
        local = OneIGAlignmentMetric(
            vlm_type="transformers",
            model_name="HuggingFaceTB/SmolVLM-256M-Instruct",
            device="cpu",
            vlm_kwargs={"model_load_kwargs": {"torch_dtype": torch.float32}},
        )
    """

    metric_name: str = "oneig_alignment"
    metric_units: str = "alignment"

    def update(self, x: list[Any] | torch.Tensor, gt: torch.Tensor, outputs: torch.Tensor) -> None:
        """
        Run VLM scoring per question, mask by dependencies, and record the cell mean.

        Parameters
        ----------
        x : list[Any] | torch.Tensor
            Batch metadata; unused but required by the metric interface.
        gt : torch.Tensor
            Ground-truth slot carrying per-sample aux dicts with ``questions``
            and, optionally, ``dependencies``.
        outputs : torch.Tensor
            Generated images scored against the questions.
        """
        batch = metric_data_processor(x, gt, outputs, self.call_type)
        images = _process_images(batch[0])
        aux_entries = batch[1] if len(batch) > 1 else []
        if isinstance(aux_entries, torch.Tensor):
            aux_entries = aux_entries.tolist()

        for idx, image in enumerate(images):
            aux = aux_entries[idx] if idx < len(aux_entries) else {}
            if not isinstance(aux, dict):
                raise ValueError(
                    "oneig_alignment requires aux[{}] to be a dict with 'questions'. Got: {!r}.".format(idx, type(aux))
                )
            questions = aux.get("questions")
            if not isinstance(questions, dict) or not questions:
                raise ValueError(
                    f"oneig_alignment requires 'questions' as a non-empty dict on aux. Got keys: {list(aux.keys())}."
                )
            qmap = _int_dict_keys(questions)
            active_ids = _active_oneig_question_ids(qmap)
            if not active_ids:
                # No real question text for this sample (all padding) → score 0.
                self.scores.append(0.0)
                continue
            prompts = [str(qmap[qid]) for qid in active_ids]
            raw_list = self.vlm.score(
                [image] * len(prompts),
                prompts,
                ["Yes"] * len(prompts),
                response_format=self.response_format,
            )
            per_question = dict(zip(active_ids, (float(score) for score in raw_list)))
            masked = apply_oneig_dependency_mask(
                per_question, _normalize_dependencies(aux.get("dependencies", {}))
            )
            self.scores.append(aggregate_oneig_alignment_per_cell(masked, active_ids))

0 commit comments

Comments
 (0)