# Copyright 2025 - Pruna AI GmbH. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""OneIG alignment scoring with dependency masking (parent ``No`` gates children)."""

from __future__ import annotations

from typing import Any, Mapping

import torch

from pruna.evaluation.metrics.metric_qa_accuracy import QAAccuracyMetric
from pruna.evaluation.metrics.registry import MetricRegistry
from pruna.evaluation.metrics.utils import metric_data_processor
from pruna.evaluation.metrics.vlm_utils import _process_images


def _int_dict_keys(mapping: Mapping[Any, Any]) -> dict[int, Any]:
    """
    Return a copy of ``mapping`` with keys cast to ``int``.
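
    HF ``datasets`` may deliver integer ids as strings:

    >>> _int_dict_keys({"1": "Is there a cat?"})
    {1: 'Is there a cat?'}
    """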
    return {int(k): v for k, v in mapping.items()}


def _normalize_dependencies(deps: Any) -> dict[int, list[int]]:
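    """
    Coerce a raw dependency mapping to ``dict[int, list[int]]``.

    Non-list parent values collapse to an empty list, and string ids are cast:

    >>> _normalize_dependencies({"2": ["1"], "3": None})
    {2: [1], 3: []}
    """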
    if not isinstance(deps, Mapping):
        return {}
    out: dict[int, list[int]] = {}
    for k, v in deps.items():
        key = int(k)
        if isinstance(v, list):
            out[key] = [int(p) for p in v]
        else:
            out[key] = []
    return out


def _active_oneig_question_ids(qmap: dict[int, Any]) -> list[int]:
    """
    Question ids with real prompt text (excludes HF ``datasets`` padding and empty slots).
    active: list[int] = []
    for qi in sorted(qmap):
        text = qmap[qi]
        if text is None:
            continue
        s = str(text).strip()
        if not s or s == "None":
            continue
        active.append(qi)
    return active


def apply_oneig_dependency_mask(
    raw_scores: Mapping[int, float],
    dependencies: Mapping[int, list[int]],
) -> dict[int, float]:
    """
    Apply OneIG ``filter_score`` logic per dependency graph (single grid cell).

    Parents with semantic answer ``No`` (score ``0``) force dependent question
    scores to ``0``. Parent id ``0`` is ignored, matching the reference script.

    Parameters
    ----------
    raw_scores : Mapping[int, float]
        Map question id → VLM score in ``{0, 1}`` (or float) before masking.
    dependencies : Mapping[int, list[int]]
        Map child question id → list of parent question ids (use ``[0]`` for roots).

    Returns
    -------
    dict[int, float]
        Copy of scores with dependent questions zeroed when any non-zero parent
        scored ``0``.
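
    Examples
    --------
    Hypothetical ids: question ``2`` depends on question ``1``, so a parent
    ``No`` zeroes the child; a root parent (id ``0``) never gates:

    >>> apply_oneig_dependency_mask({1: 0.0, 2: 1.0, 3: 1.0}, {2: [1], 3: [0]})
    {1: 0.0, 2: 0.0, 3: 1.0}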
| 82 | + """ |
| 83 | + filtered = {int(k): float(v) for k, v in raw_scores.items()} |
| 84 | + deps = _normalize_dependencies(dependencies) |
| 85 | + raw = dict(filtered) |
| 86 | + for child_id, parent_ids in deps.items(): |
| 87 | + if child_id not in filtered: |
| 88 | + continue |
| 89 | + any_parent_no = False |
| 90 | + for parent_id in parent_ids: |
| 91 | + if parent_id == 0: |
| 92 | + continue |
| 93 | + if parent_id not in raw: |
| 94 | + continue |
| 95 | + if raw[parent_id] == 0.0: |
| 96 | + any_parent_no = True |
| 97 | + break |
| 98 | + if any_parent_no: |
| 99 | + filtered[child_id] = 0.0 |
| 100 | + return filtered |
| 101 | + |
| 102 | + |
| 103 | +def aggregate_oneig_alignment_per_cell(filtered_scores: Mapping[int, float], question_ids: list[int]) -> float: |
| 104 | + """ |
| 105 | + Mean filtered score over all questions in the prompt (one grid cell). |
| 106 | +
|
| 107 | + Parameters |
| 108 | + ---------- |
| 109 | + filtered_scores : Mapping[int, float] |
| 110 | + Post-mask scores for each question id. |
| 111 | + question_ids : list[int] |
| 112 | + Ordered ids (typically sorted ascending) defining the denominator. |
| 113 | +
|
| 114 | + Returns |
| 115 | + ------- |
| 116 | + float |
| 117 | + Average score in ``[0, 1]`` if inputs are binary; ``0.0`` if ``question_ids`` is empty. |
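
    Examples
    --------
    >>> aggregate_oneig_alignment_per_cell({1: 1.0, 2: 0.0}, [1, 2])
    0.5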
| 118 | + """ |
| 119 | + if not question_ids: |
| 120 | + return 0.0 |
| 121 | + s = sum(float(filtered_scores[qid]) for qid in question_ids) |
| 122 | + return s / float(len(question_ids)) |
| 123 | + |
| 124 | + |
| 125 | +@MetricRegistry.register("oneig_alignment") |
| 126 | +class OneIGAlignmentMetric(QAAccuracyMetric): |
| 127 | + """ |
| 128 | + OneIG alignment with dependency-aware aggregation. |
| 129 | +
|
| 130 | + Reuses :class:`QAAccuracyMetric` VLM Yes/No scoring but aggregates like |
| 131 | + ``OneIG-Benchmark`` ``alignment_score.py`` for a **single** grid cell (no |
| 132 | + ``split_mxn_grid``): question ids are sorted numerically, raw scores are |
| 133 | + masked when any non-root parent is ``No``, then the mean over all questions |
| 134 | + is stored per image. Entries with null or blank question text (HF ``datasets`` |
| 135 | + schema padding) are omitted from scoring. |
| 136 | +
|
| 137 | + Numerical parity with upstream also depends on the VLM (e.g. ``openai/gpt-4o`` via |
| 138 | + litellm vs reference Qwen2.5-VL). |
| 139 | +
|
| 140 | + Parameters |
| 141 | + ---------- |
| 142 | + *args : Any |
| 143 | + Additional positional arguments for :class:`QAAccuracyMetric`. |
| 144 | + vlm : BaseVLM | None, optional |
| 145 | + Custom VLM instance. If provided, ``vlm_type`` and ``model_name`` are ignored. |
| 146 | + vlm_type : {"litellm", "transformers"}, optional |
| 147 | + VLM backend. Default is ``"litellm"``. |
| 148 | + model_name : str | None, optional |
| 149 | + Litellm model id or HuggingFace checkpoint id. **Required** when ``vlm`` is not |
| 150 | + provided (e.g. ``openai/gpt-4o``). |
| 151 | + vlm_kwargs : dict, optional |
| 152 | + Forwarded by ``get_vlm`` to ``LitellmVLM`` or ``TransformersVLM``. For local models, |
| 153 | + set ``model_load_kwargs`` for ``from_pretrained``; for litellm, pass extra API options. |
| 154 | + structured_output : bool, optional |
| 155 | + Use structured generation (litellm pydantic; transformers outlines when applicable). |
| 156 | + Default is True. |
| 157 | + device : str | torch.device | None, optional |
| 158 | + Device for transformers VLM. |
| 159 | + api_key : str | None, optional |
| 160 | + API key for litellm. |
| 161 | + call_type : str, optional |
| 162 | + Call type for the metric. |
| 163 | + **kwargs : Any |
| 164 | + Additional keyword arguments for :class:`QAAccuracyMetric`. |
| 165 | +
|
| 166 | + Examples |
| 167 | + -------- |
| 168 | + Same ``hosted`` / ``local`` pattern as ``QAAccuracyMetric`` and |
| 169 | + :func:`~pruna.evaluation.metrics.vlm_base.get_vlm`: |
| 170 | +
|
| 171 | + .. code-block:: python |
| 172 | +
|
| 173 | + import torch |
| 174 | +
|
| 175 | + from pruna.evaluation.metrics import OneIGAlignmentMetric |
| 176 | +
|
| 177 | + hosted = OneIGAlignmentMetric(vlm_type="litellm", model_name="openai/gpt-4o") |
| 178 | + local = OneIGAlignmentMetric( |
| 179 | + vlm_type="transformers", |
| 180 | + model_name="HuggingFaceTB/SmolVLM-256M-Instruct", |
| 181 | + device="cpu", |
| 182 | + vlm_kwargs={"model_load_kwargs": {"torch_dtype": torch.float32}}, |
| 183 | + ) |
| 184 | + """ |

    metric_name: str = "oneig_alignment"
    metric_units: str = "alignment"

    def update(self, x: list[Any] | torch.Tensor, gt: torch.Tensor, outputs: torch.Tensor) -> None:
        """
        Score each question with the VLM, apply dependency masking, append per-cell mean.

        Parameters
        ----------
        x : list[Any] | torch.Tensor
            Unused batch metadata (kept for the metric interface).
        gt : torch.Tensor
            Ground-truth slot holding per-sample aux dicts with ``questions`` and
            optionally ``dependencies``.
        outputs : torch.Tensor
            Model outputs (images) evaluated against the questions.
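
        Examples
        --------
        A hypothetical per-sample aux dict (ids may arrive as strings from HF
        ``datasets``; parent id ``0`` marks a root question):

        .. code-block:: python

            {
                "questions": {"1": "Is there a cat?", "2": "Is the cat black?"},
                "dependencies": {"1": [0], "2": [1]},
            }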
| 202 | + """ |
| 203 | + inputs = metric_data_processor(x, gt, outputs, self.call_type) |
| 204 | + images = _process_images(inputs[0]) |
| 205 | + aux_list = inputs[1] if len(inputs) > 1 else [] |
| 206 | + if isinstance(aux_list, torch.Tensor): |
| 207 | + aux_list = aux_list.tolist() |
| 208 | + for i, image in enumerate(images): |
| 209 | + aux = aux_list[i] if i < len(aux_list) else {} |
| 210 | + if not isinstance(aux, dict): |
| 211 | + raise ValueError( |
| 212 | + "oneig_alignment requires aux[{}] to be a dict with 'questions'. Got: {!r}.".format(i, type(aux)) |
| 213 | + ) |
| 214 | + qs = aux.get("questions") |
| 215 | + if not isinstance(qs, dict) or not qs: |
| 216 | + raise ValueError( |
| 217 | + f"oneig_alignment requires 'questions' as a non-empty dict on aux. Got keys: {list(aux.keys())}." |
| 218 | + ) |
| 219 | + qmap = _int_dict_keys(qs) |
| 220 | + qids = _active_oneig_question_ids(qmap) |
| 221 | + if not qids: |
| 222 | + self.scores.append(0.0) |
| 223 | + continue |
| 224 | + question_texts = [str(qmap[qi]) for qi in qids] |
| 225 | + deps = _normalize_dependencies(aux.get("dependencies", {})) |
| 226 | + raw_scores_list = self.vlm.score( |
| 227 | + [image] * len(question_texts), |
| 228 | + question_texts, |
| 229 | + ["Yes"] * len(question_texts), |
| 230 | + response_format=self.response_format, |
| 231 | + ) |
| 232 | + raw_map = {qid: float(raw_scores_list[j]) for j, qid in enumerate(qids)} |
| 233 | + filtered = apply_oneig_dependency_mask(raw_map, deps) |
| 234 | + self.scores.append(aggregate_oneig_alignment_per_cell(filtered, qids)) |