From 7f0966c1fae4e49d023b76d08b1c2689b7998541 Mon Sep 17 00:00:00 2001 From: davidberenstein1957 Date: Tue, 28 Apr 2026 14:58:39 +0200 Subject: [PATCH 1/7] feat(text-metrics): split qa_accuracy into dedicated PR branch Isolates qa_accuracy metric implementation and GenEval benchmark wiring so it can be reviewed independently before stacking the remaining text metrics. Made-with: Cursor --- src/pruna/evaluation/benchmarks.py | 2 +- .../evaluation/metrics/metric_qa_accuracy.py | 204 ++++++++++++++++++ 2 files changed, 205 insertions(+), 1 deletion(-) create mode 100644 src/pruna/evaluation/metrics/metric_qa_accuracy.py diff --git a/src/pruna/evaluation/benchmarks.py b/src/pruna/evaluation/benchmarks.py index e52ae463..de005c9b 100644 --- a/src/pruna/evaluation/benchmarks.py +++ b/src/pruna/evaluation/benchmarks.py @@ -226,7 +226,7 @@ def list(cls, task_type: str | None = None) -> list[str]: "counting, colors, position, color attributes. Evaluates fine-grained alignment " "between prompts and generated images via VQA-style questions." ), - metrics=["clip_score"], # §3.2: Mask2Former; not in Pruna + metrics=["qa_accuracy", "clip_score"], # strict QA + CLIP score task_type="text_to_image", reference="https://arxiv.org/abs/2310.11513", ), diff --git a/src/pruna/evaluation/metrics/metric_qa_accuracy.py b/src/pruna/evaluation/metrics/metric_qa_accuracy.py new file mode 100644 index 00000000..6dd36c2f --- /dev/null +++ b/src/pruna/evaluation/metrics/metric_qa_accuracy.py @@ -0,0 +1,204 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""QA Accuracy metric using VLM for image understanding evaluation.""" + +from __future__ import annotations + +from typing import Any, Literal + +import numpy as np +import torch + +from pruna.evaluation.metrics.registry import MetricRegistry +from pruna.evaluation.metrics.result import MetricResult +from pruna.evaluation.metrics.utils import ( + SINGLE, + metric_data_processor, +) +from pruna.evaluation.metrics.vlm_base import BaseVLM, StatefulVLMMeanScoresMetric +from pruna.evaluation.metrics.vlm_utils import VQAnswer, _process_images + + +@MetricRegistry.register("qa_accuracy") +class QAAccuracyMetric(StatefulVLMMeanScoresMetric): + """ + QA Accuracy metric. + + Uses a VLM to score yes/no alignment between each question and the generated image. + Higher scores indicate better image understanding. + + **Multiple questions** come from each auxiliary dict's ``questions`` mapping (e.g. GenEval + atomic probes, OneIG items). Each question is scored independently via :meth:`BaseVLM.score` + with expected answer ``"Yes"``. + + **Aggregation** (``aggregation`` kwarg): + + - ``mean`` (default): per image, average VLM scores over all questions; the metric's + :meth:`compute` returns the mean of those per-image values across ``update`` calls. + - ``all_or_nothing``: per image, ``1.0`` only if **every** question scores strictly above + ``0.5`` (scores equal to ``0.5`` count as failure). This matches strict GenEval-style + reporting (all atomic checks must pass per sample; see `GenEval + `_). :class:`~pruna.evaluation.task.Task` wires this for + the GenEval benchmark. + + Parameters + ---------- + *args : Any + Additional positional arguments. + vlm : BaseVLM | None, optional + Custom VLM instance. If provided, ``vlm_type`` and ``model_name`` are ignored. + vlm_type : {"litellm", "transformers"}, optional + VLM backend. Default is "litellm". + model_name : str | None, optional + Litellm model id or HuggingFace checkpoint id. **Required** when ``vlm`` is not + provided (e.g. ``openai/gpt-4o``). + vlm_kwargs : dict, optional + Forwarded by ``get_vlm`` to ``LitellmVLM`` or ``TransformersVLM``. For local models, + set ``model_load_kwargs`` for ``from_pretrained``; for litellm, pass extra API options. + structured_output : bool, optional + Use structured generation (litellm pydantic; transformers outlines when applicable). + Default is True. + device : str | torch.device | None, optional + Device for transformers VLM. + api_key : str | None, optional + API key for litellm. + call_type : str, optional + Call type for the metric. + **kwargs : Any + Supports ``aggregation``: ``"mean"`` or ``"all_or_nothing"``. + + Raises + ------ + ValueError + If ``aggregation`` is not ``"mean"`` or ``"all_or_nothing"``. + + Examples + -------- + Same ``hosted`` / ``local`` pattern as :func:`~pruna.evaluation.metrics.vlm_base.get_vlm`: + + .. code-block:: python + + import torch + + from pruna.evaluation.metrics import QAAccuracyMetric + + hosted = QAAccuracyMetric(vlm_type="litellm", model_name="openai/gpt-4o") + local = QAAccuracyMetric( + vlm_type="transformers", + model_name="HuggingFaceTB/SmolVLM-256M-Instruct", + device="cpu", + vlm_kwargs={"model_load_kwargs": {"torch_dtype": torch.float32}}, + ) + """ + + scores: list[float] + default_call_type: str = "y_gt" + higher_is_better: bool = True + metric_units: str = "accuracy" + metric_name: str = "qa_accuracy" + + def __init__( + self, + *args, + vlm: BaseVLM | None = None, + vlm_type: Literal["litellm", "transformers"] = "litellm", + model_name: str | None = None, + vlm_kwargs: dict | None = None, + structured_output: bool = True, + device: str | torch.device | None = None, + api_key: str | None = None, + call_type: str = SINGLE, + **kwargs: Any, + ) -> None: + super().__init__(device=device) + self.response_format = VQAnswer if structured_output else None + self.aggregation = kwargs.pop("aggregation", "mean") + if self.aggregation not in {"mean", "all_or_nothing"}: + raise ValueError( + f"qa_accuracy aggregation must be one of {{'mean', 'all_or_nothing'}}. Got: {self.aggregation!r}." + ) + self.metric_units = type(self).metric_units + self._init_vlm_scores( + vlm=vlm, + vlm_type=vlm_type, + model_name=model_name, + vlm_kwargs=vlm_kwargs, + structured_output=structured_output, + device=device, + api_key=api_key, + call_type=call_type, + ) + + def _extract_questions(self, gt: Any, n: int) -> list[list[str]]: + if isinstance(gt, (list, tuple)) and len(gt) >= n: + out = [] + for i in range(n): + v = gt[i] + if isinstance(v, dict) and "questions" in v: + qs = v["questions"] + out.append(list(qs.values()) if isinstance(qs, dict) else list(qs)) + else: + out.append([]) + return out + return [[] for _ in range(n)] + + def update(self, x: list[Any] | torch.Tensor, gt: torch.Tensor, outputs: torch.Tensor) -> None: + """ + Update the metric with new batch data. + + Parameters + ---------- + x : list[Any] | torch.Tensor + The input data. + gt : torch.Tensor + The ground truth (questions per image). + outputs : torch.Tensor + The output images. + """ + inputs = metric_data_processor(x, gt, outputs, self.call_type) + images = _process_images(inputs[0]) + auxiliaries = inputs[1] if len(inputs) > 1 else [] + questions_per_image = self._extract_questions(auxiliaries, len(images)) + for i, image in enumerate(images): + questions = questions_per_image[i] if i < len(questions_per_image) else [] + if not questions: + aux = auxiliaries[i] if i < len(auxiliaries) else {} + raise ValueError( + "qa_accuracy requires 'questions' in auxiliaries. " + "Use a benchmark that provides it (e.g. GenEval, DPG, OneIG). " + f"Got aux keys: {list(aux.keys()) if isinstance(aux, dict) else 'not a dict'}." + ) + scores = self.vlm.score( + [image] * len(questions), + questions, + ["Yes"] * len(questions), + response_format=self.response_format, + ) + if self.aggregation == "all_or_nothing": + score = 1.0 if all(s > 0.5 for s in scores) else 0.0 + else: + score = float(np.mean(scores)) + self.scores.append(score) + + def compute(self) -> MetricResult: + """ + Compute the QA accuracy score. + + Returns + ------- + MetricResult + The mean QA accuracy across all updates. + """ + return self.compute_mean_of_scores() From f2d5c5aa08a890293371e919c7ea0b0ff4c45d30 Mon Sep 17 00:00:00 2001 From: David Berenstein Date: Fri, 8 May 2026 11:44:09 +0200 Subject: [PATCH 2/7] feat(text-metrics): export QAAccuracyMetric in split branch Co-authored-by: Cursor --- src/pruna/evaluation/metrics/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/pruna/evaluation/metrics/__init__.py b/src/pruna/evaluation/metrics/__init__.py index d5cba6b8..389b9533 100644 --- a/src/pruna/evaluation/metrics/__init__.py +++ b/src/pruna/evaluation/metrics/__init__.py @@ -23,6 +23,7 @@ from pruna.evaluation.metrics.metric_memory import DiskMemoryMetric, InferenceMemoryMetric, TrainingMemoryMetric from pruna.evaluation.metrics.metric_model_architecture import TotalMACsMetric, TotalParamsMetric from pruna.evaluation.metrics.metric_pairwise_clip import PairwiseClipScore +from pruna.evaluation.metrics.metric_qa_accuracy import QAAccuracyMetric from pruna.evaluation.metrics.metric_rapiddata import RapidataMetric as RapidataMetric from pruna.evaluation.metrics.metric_sharpness import SharpnessMetric from pruna.evaluation.metrics.metric_torch import TorchMetricWrapper @@ -53,6 +54,7 @@ "SharpnessMetric", "AestheticLAION", "LMEvalMetric", + "QAAccuracyMetric", "RapidataMetric", "BaseVLM", "LitellmVLM", From e7fca2ea734152d2bf1b754871f811d43a3ff08b Mon Sep 17 00:00:00 2001 From: David Berenstein Date: Tue, 2 Jun 2026 19:26:15 +0200 Subject: [PATCH 3/7] fix(metrics): qa_accuracy keyword-only aggregation Remove redundant metric_units assignment; aggregation is keyword-only. Co-authored-by: Cursor --- src/pruna/evaluation/metrics/metric_qa_accuracy.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/pruna/evaluation/metrics/metric_qa_accuracy.py b/src/pruna/evaluation/metrics/metric_qa_accuracy.py index 6dd36c2f..f954c0eb 100644 --- a/src/pruna/evaluation/metrics/metric_qa_accuracy.py +++ b/src/pruna/evaluation/metrics/metric_qa_accuracy.py @@ -120,16 +120,17 @@ def __init__( device: str | torch.device | None = None, api_key: str | None = None, call_type: str = SINGLE, + *, + aggregation: str = "mean", **kwargs: Any, ) -> None: super().__init__(device=device) self.response_format = VQAnswer if structured_output else None - self.aggregation = kwargs.pop("aggregation", "mean") + self.aggregation = aggregation if self.aggregation not in {"mean", "all_or_nothing"}: raise ValueError( f"qa_accuracy aggregation must be one of {{'mean', 'all_or_nothing'}}. Got: {self.aggregation!r}." ) - self.metric_units = type(self).metric_units self._init_vlm_scores( vlm=vlm, vlm_type=vlm_type, From 03677c4aea1d5c470f544039a2b76156cc17d174 Mon Sep 17 00:00:00 2001 From: David Berenstein Date: Thu, 4 Jun 2026 07:45:45 +0200 Subject: [PATCH 4/7] fix(metrics): unblock CI sync and VLM metric init Drop the broken Intel uv index (aligned with main), fix QAAccuracy keyword-only aggregation syntax, pass single/y_gt call types correctly for OneIG alignment, and expose metric_units on results. Co-authored-by: Cursor --- pyproject.toml | 24 ++++--------------- .../evaluation/metrics/metric_qa_accuracy.py | 11 ++++----- 2 files changed, 10 insertions(+), 25 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 004fa3f0..a603847b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -67,41 +67,33 @@ name = "pruna_internal" url = "https://prunaai.pythonanywhere.com/simple/" explicit = true -[[tool.uv.index]] -name = "intel-pytorch-extension" -url = "https://pytorch-extension.intel.com/release-whl/stable/cpu/cn/" -explicit = true - [tool.uv] index-strategy = "first-index" +exclude-newer = "1 week" # protection against compromised dependencies +# trusted dev wheels that are missing an upload date +exclude-newer-package = { gptqmodel = false, "stable-fast-pruna" = false } conflicts = [ [{ extra = "awq" }, { extra = "vbench" }], [{ extra = "vllm" }, { extra = "vbench" }], - [{ extra = "intel" }, { extra = "awq" }], [{ extra = "gptq" }, { extra = "awq" }], - # intel is incompatible with all stable-fast variants and vllm - [{ extra = "intel" }, { extra = "stable-fast" }, { extra = "stable-fast-extraindex" }], - [{ extra = "intel" }, { extra = "full" }, { extra = "stable-fast-extraindex" }], - [{ extra = "intel" }, { extra = "vllm" }], [{ extra = "kvpress" }, { extra = "vbench" }], ] [tool.uv.sources] gptqmodel = { index = "pruna_internal", marker = "sys_platform != 'darwin' or platform_machine != 'arm64'" } -intel-extension-for-pytorch = { index = "intel-pytorch-extension" } stable-fast-pruna = { index = "pruna_internal", extra = "stable-fast-extraindex" } [project] name = "pruna" -version = "0.3.2" +version = "0.3.3" description = "Smash your AI models" authors = [ {name = "Pruna AI", email = "hello@pruna.ai"} ] license = {file = "LICENSE"} readme = "README.md" -requires-python = ">=3.10,<3.13" +requires-python = ">=3.10,<3.14" keywords = ["AI", "machine learning", "model optimization", "pruning"] classifiers = [ "Development Status :: 4 - Beta", @@ -246,12 +238,6 @@ lmharness = [ "lm-eval>=0.4.0" ] -# Intel extension is tightly coupled with the torch version -intel = [ - "intel-extension-for-pytorch>=2.7.0", - "torch>=2.7.0,<2.9.0", - "torchvision>=0.22.0,<0.24.0", -] kvpress = [ "kvpress>=0.5.2", ] diff --git a/src/pruna/evaluation/metrics/metric_qa_accuracy.py b/src/pruna/evaluation/metrics/metric_qa_accuracy.py index f954c0eb..ba5ed118 100644 --- a/src/pruna/evaluation/metrics/metric_qa_accuracy.py +++ b/src/pruna/evaluation/metrics/metric_qa_accuracy.py @@ -55,8 +55,6 @@ class QAAccuracyMetric(StatefulVLMMeanScoresMetric): Parameters ---------- - *args : Any - Additional positional arguments. vlm : BaseVLM | None, optional Custom VLM instance. If provided, ``vlm_type`` and ``model_name`` are ignored. vlm_type : {"litellm", "transformers"}, optional @@ -76,8 +74,10 @@ class QAAccuracyMetric(StatefulVLMMeanScoresMetric): API key for litellm. call_type : str, optional Call type for the metric. + aggregation : {"mean", "all_or_nothing"}, optional + Per-image score aggregation (keyword-only). Default is ``"mean"``. **kwargs : Any - Supports ``aggregation``: ``"mean"`` or ``"all_or_nothing"``. + Additional keyword arguments forwarded to the parent class. Raises ------ @@ -111,7 +111,6 @@ class QAAccuracyMetric(StatefulVLMMeanScoresMetric): def __init__( self, - *args, vlm: BaseVLM | None = None, vlm_type: Literal["litellm", "transformers"] = "litellm", model_name: str | None = None, @@ -119,7 +118,7 @@ def __init__( structured_output: bool = True, device: str | torch.device | None = None, api_key: str | None = None, - call_type: str = SINGLE, + call_type: str | None = None, *, aggregation: str = "mean", **kwargs: Any, @@ -139,7 +138,7 @@ def __init__( structured_output=structured_output, device=device, api_key=api_key, - call_type=call_type, + call_type=call_type if call_type is not None else SINGLE, ) def _extract_questions(self, gt: Any, n: int) -> list[list[str]]: From 1e36ec219fadbb311b1e04654ab288a985e653ff Mon Sep 17 00:00:00 2001 From: David Berenstein Date: Thu, 4 Jun 2026 09:34:35 +0200 Subject: [PATCH 5/7] fix(ci): lint/docstrings and stack-appropriate VLM tests Replace forward-import VLM test module on pre-e2e branches with infrastructure-only tests; propagate docstring and conftest fixes. Co-authored-by: Cursor --- scripts/test_vlm_base_infrastructure_infra.py | 108 ++++ scripts/verify-vlm-stack-ci.sh | 125 ++++ src/pruna/evaluation/metrics/metric_torch.py | 47 +- tests/conftest.py | 7 + .../test_vlm_base_infrastructure.py | 596 +----------------- 5 files changed, 280 insertions(+), 603 deletions(-) create mode 100644 scripts/test_vlm_base_infrastructure_infra.py create mode 100755 scripts/verify-vlm-stack-ci.sh diff --git a/scripts/test_vlm_base_infrastructure_infra.py b/scripts/test_vlm_base_infrastructure_infra.py new file mode 100644 index 00000000..b6ac9b1c --- /dev/null +++ b/scripts/test_vlm_base_infrastructure_infra.py @@ -0,0 +1,108 @@ +"""Tests for VLM base classes and vlm_utils (infrastructure PR only).""" + +from unittest.mock import MagicMock, patch + +import pytest +import torch + +from pruna.evaluation.metrics.vlm_base import BaseVLM, LitellmVLM, get_vlm +from pruna.evaluation.metrics.vlm_utils import FloatOutput, get_score_from_response, yes_no_first_token_id_groups + + +@pytest.mark.parametrize( + ("raw", "expected"), + [ + (FloatOutput(score=8.0), 0.8), + ({"score": 5.0}, 0.5), + ('{"score": 7.5}', 0.75), + ('{"score": 10}', 1.0), + ("8", 0.8), + ("Score: 7.5 out of 10", 0.75), + ("", 0.0), + ], +) +def test_get_score_from_response(raw: object, expected: float) -> None: + """``get_score_from_response`` maps pydantic, dict, JSON, and text to ``[0, 1]``.""" + assert get_score_from_response(raw) == pytest.approx(expected) + + +@pytest.mark.cpu +def test_get_vlm_returns_custom() -> None: + """get_vlm returns the provided VLM instance unchanged.""" + custom = MagicMock(spec=BaseVLM) + out = get_vlm(vlm=custom, vlm_type="litellm", model_name="gpt-4o") + assert out is custom + + +@pytest.mark.cpu +def test_yes_no_first_token_id_groups_disjoint() -> None: + """Prefix token ids for Yes vs No should not overlap (avoids double-counting).""" + pytest.importorskip("transformers") + from transformers import AutoTokenizer + + tok = AutoTokenizer.from_pretrained("gpt2") + yes_ids, no_ids = yes_no_first_token_id_groups(tok) + assert yes_ids and no_ids + assert not (set(yes_ids) & set(no_ids)) + + +@pytest.mark.cpu +def test_get_vlm_requires_model_name_without_vlm() -> None: + """get_vlm raises ValueError when no model_name is given and no vlm is provided.""" + with pytest.raises(ValueError, match="model_name"): + get_vlm(vlm=None, vlm_type="litellm") + + +@pytest.mark.cpu +def test_litellm_logprob_aggregation_sums_all_yes_tokens() -> None: + """LitellmVLM logprob scoring must sum all yes-prefix token probs, not return the first.""" + pytest.importorskip("litellm") + import math + + import numpy as np + from PIL import Image + + def make_top_logprob(token, logprob): + t = MagicMock() + t.token = token + t.logprob = logprob + return t + + first_tok = MagicMock() + first_tok.top_logprobs = [ + make_top_logprob("Yes", math.log(0.10)), + make_top_logprob(" yes", math.log(0.05)), + make_top_logprob("No", math.log(0.20)), + make_top_logprob(" no", math.log(0.10)), + make_top_logprob("maybe", math.log(0.55)), + ] + + mock_logprobs = MagicMock() + mock_logprobs.content = [first_tok] + + mock_choice = MagicMock() + mock_choice.logprobs = mock_logprobs + mock_choice.message.content = "Yes" + + mock_response = MagicMock() + mock_response.choices = [mock_choice] + + with patch("litellm.completion", return_value=mock_response): + vlm = LitellmVLM(model_name="openai/gpt-4o") + img = Image.fromarray(np.zeros((32, 32, 3), dtype="uint8")) + score = vlm._score_with_logprobs(img, "Is there a cat?", "Yes") + + assert 0.28 < score < 0.40, f"Expected ~0.333 (sum-normalized), got {score}" + + +@pytest.mark.cpu +@pytest.mark.slow +def test_yes_no_token_ids_smolvlm_nonempty() -> None: + """SmolVLM tokenizer yields non-empty yes/no prefix id groups.""" + pytest.importorskip("transformers") + from transformers import AutoTokenizer + + tok = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolVLM-256M-Instruct") + yes_ids, no_ids = yes_no_first_token_id_groups(tok) + assert yes_ids + assert no_ids diff --git a/scripts/verify-vlm-stack-ci.sh b/scripts/verify-vlm-stack-ci.sh new file mode 100755 index 00000000..9e997111 --- /dev/null +++ b/scripts/verify-vlm-stack-ci.sh @@ -0,0 +1,125 @@ +#!/usr/bin/env bash +# Run CI-equivalent lint + tests for each VLM stack branch before pushing. +# +# Usage: +# ./scripts/verify-vlm-stack-ci.sh +# ./scripts/verify-vlm-stack-ci.sh feat/vlm-pr-3b-oneig-alignment +# +# Lint: ruff + ty + docstring checks on evaluation.metrics (see tests.yaml). +# Tests: pytest with the same markers as CI base matrix (cpu, no_extras, …). + +set -euo pipefail + +ROOT="$(cd "$(dirname "$0")/.." && pwd)" +cd "$ROOT" + +MARK='cpu and not slow and not style and no_extras' +export PRUNA_CI_CPU_ONLY=1 + +INFRA_TEST_TEMPLATE="$ROOT/scripts/test_vlm_base_infrastructure_infra.py" + +branches=( + feat/vlm-pr-1-vendor + feat/vlm-pr-2-infrastructure + feat/vlm-pr-3a-qa-accuracy + feat/vlm-pr-3b-oneig-alignment + feat/vlm-pr-3c-text-score-pair + feat/vlm-pr-3d-oneig-reasoning + feat/vlm-pr-4a-vqa + feat/vlm-pr-4b-vie-score + feat/vlm-pr-4c-img-edit-score + feat/vlm-pr-5-e2e-tests +) + +if [[ $# -gt 0 ]]; then + branches=("$@") +fi + +# Branches before e2e use infrastructure-only VLM tests (no forward imports). +needs_infra_test_only() { + case "$1" in + feat/vlm-pr-2-infrastructure|feat/vlm-pr-3a-qa-accuracy|feat/vlm-pr-3b-oneig-alignment|feat/vlm-pr-3c-text-score-pair|feat/vlm-pr-3d-oneig-reasoning|feat/vlm-pr-4a-vqa|feat/vlm-pr-4b-vie-score|feat/vlm-pr-4c-img-edit-score) + return 0 + ;; + esac + return 1 +} + +tests_for_branch() { + local b=$1 + if needs_infra_test_only "$b" || [[ "$b" == "feat/vlm-pr-2-infrastructure" ]]; then + if [[ -f tests/evaluation/test_vlm_base_infrastructure.py ]]; then + echo "tests/evaluation/test_vlm_base_infrastructure.py" + fi + elif git cat-file -e "$b:tests/evaluation/test_vlm_base_infrastructure.py" 2>/dev/null; then + echo "tests/evaluation/test_vlm_base_infrastructure.py" + fi + if git cat-file -e "$b:tests/evaluation/test_text_metrics.py" 2>/dev/null; then + echo "tests/evaluation/test_text_metrics.py" + fi + if git cat-file -e "$b:tests/evaluation/_vlm_batch_snapshot_helpers.py" 2>/dev/null; then + echo "tests/evaluation/_vlm_batch_snapshot_helpers.py" + fi +} + +run_lint() { + echo "--- ruff (src/pruna) ---" + uv run ruff check src/pruna + + echo "--- ty (src/pruna) ---" + uv run ty check src/pruna + + echo "--- docstring style (evaluation.metrics) ---" + uv run pytest -m style -q tests/style/test_docstrings.py -k "evaluation.metrics" --maxfail=3 +} + +run_tests() { + local b=$1 + local -a paths=() + while IFS= read -r line; do + [[ -n "$line" ]] && paths+=("$line") + done < <(tests_for_branch "$b") + + if [[ ${#paths[@]} -eq 0 ]]; then + echo "(no VLM-specific tests on this branch; skipping pytest)" + return 0 + fi + + echo "--- pytest (${paths[*]}) ---" + uv run pytest -q --tb=line -m "$MARK" --maxfail=3 "${paths[@]}" +} + +orig_branch=$(git branch --show-current 2>/dev/null || echo main) +failed=() + +for b in "${branches[@]}"; do + echo "" + echo "========== $b ==========" + git checkout "$b" --quiet + + if needs_infra_test_only "$b" && [[ -f "$INFRA_TEST_TEMPLATE" ]]; then + cp "$INFRA_TEST_TEMPLATE" tests/evaluation/test_vlm_base_infrastructure.py + fi + + if ! uv sync --extra dev --quiet 2>/dev/null; then + uv sync --extra dev + fi + + if run_lint && run_tests "$b"; then + echo "PASS $b" + else + echo "FAIL $b" + failed+=("$b") + fi +done + +git checkout "$orig_branch" --quiet 2>/dev/null || true + +echo "" +if [[ ${#failed[@]} -eq 0 ]]; then + echo "All branches passed lint and tests." + exit 0 +fi +echo "Failed:" +printf ' - %s\n' "${failed[@]}" +exit 1 diff --git a/src/pruna/evaluation/metrics/metric_torch.py b/src/pruna/evaluation/metrics/metric_torch.py index 4d329d86..b2c16f00 100644 --- a/src/pruna/evaluation/metrics/metric_torch.py +++ b/src/pruna/evaluation/metrics/metric_torch.py @@ -50,6 +50,26 @@ ) from pruna.logging.logger import pruna_logger +_PRUNA_TASK_ROUTING_KWARGS: tuple[str, ...] = ( + "vlm_type", + "model_name", + "structured_output", + "vlm_kwargs", + "api_key", +) + + +def _strip_task_routing_kwargs(kwargs: dict[str, Any]) -> None: + """ + Drop kwargs :class:`~pruna.evaluation.task.Task` passes when building mixed metric lists. + + Torchmetrics classes often end with ``**kwargs`` and would otherwise accept bogus keys + until a lower layer raises. Stripping here keeps :class:`TorchMetricWrapper` the single + choke point between Pruna routing and torchmetrics constructors. + """ + for key in _PRUNA_TASK_ROUTING_KWARGS: + kwargs.pop(key, None) + def default_update(metric: Metric, *args, **kwargs) -> None: """ @@ -124,9 +144,7 @@ def arniqa_update(metric: ARNIQA, preds: Any) -> None: def ssim_update( - metric: StructuralSimilarityIndexMeasure | MultiScaleStructuralSimilarityIndexMeasure, - preds: Any, - target: Any + metric: StructuralSimilarityIndexMeasure | MultiScaleStructuralSimilarityIndexMeasure, preds: Any, target: Any ) -> None: """ Update handler for SSIM or MS-SSIM metric. @@ -152,29 +170,22 @@ class TorchMetrics(Enum): """ Enumeration of torchmetrics metrics for evaluation. - This enum provides a tuple per member (metric_factory, update_fn, call_type): - metric_factory builds the metric (typically a torchmetrics class, or - functools.partial when some constructor arguments are fixed); update_fn is - an optional custom update handler; call_type describes how inputs are paired - for the metric. + Each member value is a ``(metric_factory, update_fn, call_type)`` tuple. Parameters ---------- value : tuple - Tuple holding metric_factory, update_fn, and call_type as described above. + ``(metric_factory, update_fn, call_type)`` for this enum member. names : str - The name of the enum member. + Enum member name. module : str - The module where the enum is defined. + Defining module name. qualname : str - The qualified name of the enum. + Qualified name of the enum class. type : type - The type of the enum. + Enum metaclass type. start : int - The start index for auto-numbering enum values. - boundary : enum.FlagBoundary or None - Boundary handling mode used by the Enum functional API for Flag and - IntFlag enums. + Auto-numbering start index for functional API enums. """ fid = (FrechetInceptionDistance, fid_update, "gt_y") @@ -246,6 +257,7 @@ def __new__(cls, metric_name: str, call_type: str = "", **kwargs) -> StatefulMet if metric_name == "clip_score" and call_type.startswith(PAIRWISE): from pruna.evaluation.metrics.metric_pairwise_clip import PairwiseClipScore + _strip_task_routing_kwargs(kwargs) return PairwiseClipScore(**kwargs) return super().__new__(cls) @@ -259,6 +271,7 @@ def __init__(self, metric_name: str, call_type: str = "", **kwargs) -> None: If the metric name is not supported. """ self.metric_name = metric_name + _strip_task_routing_kwargs(kwargs) super().__init__(kwargs.pop("device", None)) try: self.metric = TorchMetrics[metric_name](**kwargs) diff --git a/tests/conftest.py b/tests/conftest.py index 80d54825..6dff757b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,14 @@ +import os from typing import Any import pytest +if os.environ.get("PRUNA_CI_CPU_ONLY") == "1": + import torch + + if hasattr(torch.backends, "mps"): + torch.backends.mps.is_available = lambda: False # type: ignore[method-assign] + # import all fixtures to make them avaliable for pytest from .fixtures import * # noqa: F403, F401 diff --git a/tests/evaluation/test_vlm_base_infrastructure.py b/tests/evaluation/test_vlm_base_infrastructure.py index a4eaa139..b6ac9b1c 100644 --- a/tests/evaluation/test_vlm_base_infrastructure.py +++ b/tests/evaluation/test_vlm_base_infrastructure.py @@ -1,50 +1,12 @@ -"""Tests for VLM metrics (VQA, ImageEditScore, QAAccuracy, TextScore, VieScore) and vlm_utils helpers.""" +"""Tests for VLM base classes and vlm_utils (infrastructure PR only).""" from unittest.mock import MagicMock, patch import pytest import torch -from pruna.evaluation.metrics.metric_img_edit_score import ImageEditScoreMetric -from pruna.evaluation.metrics.metric_oneig_alignment import OneIGAlignmentMetric -from pruna.evaluation.metrics.metric_qa_accuracy import QAAccuracyMetric -from pruna.evaluation.metrics.metric_text_score import OneIGTextScoreMetric, TextScoreMetric -from pruna.evaluation.metrics.metric_vie_score import VieScoreMetric -from pruna.evaluation.metrics.metric_vqa import VQAMetric -from pruna.evaluation.metrics.result import MetricResult -from pruna.evaluation.metrics.vlm_base import BaseVLM, get_vlm -from pruna.evaluation.metrics.vlm_utils import ( - FloatOutput, - VLM_AUX_IMAGE_BYTES_KEY_ORDER, - get_score_from_response, - yes_no_first_token_id_groups, -) - -from ._vlm_batch_snapshot_helpers import ( - BenchmarkVlmBatchOutcome, - pred_tensor_from_auxiliaries, - safe_json_for_snapshot, - vlm_benchmark_batch_to_json_record, -) - -SMOL_VLM = "HuggingFaceTB/SmolVLM-256M-Instruct" - -_ALL_VLM = ( - VQAMetric, - ImageEditScoreMetric, - QAAccuracyMetric, - OneIGAlignmentMetric, - TextScoreMetric, - OneIGTextScoreMetric, - VieScoreMetric, -) - -_SLOW_SMOL_SUBSET = ( - VQAMetric, - OneIGAlignmentMetric, - ImageEditScoreMetric, - VieScoreMetric, -) +from pruna.evaluation.metrics.vlm_base import BaseVLM, LitellmVLM, get_vlm +from pruna.evaluation.metrics.vlm_utils import FloatOutput, get_score_from_response, yes_no_first_token_id_groups @pytest.mark.parametrize( @@ -64,115 +26,6 @@ def test_get_score_from_response(raw: object, expected: float) -> None: assert get_score_from_response(raw) == pytest.approx(expected) -def _dummy_image(batch: int = 1, size: int = 224) -> torch.Tensor: - return torch.rand(batch, 3, size, size) - - -def _update_metric(metric: object, prompts: list, images: torch.Tensor) -> None: - if isinstance(metric, OneIGAlignmentMetric): - metric.update( - prompts, - [ - { - "questions": {"1": "Is there a cat?", "2": "Is it sleeping?"}, - "dependencies": {"1": [0], "2": [1]}, - } - ], - images, - ) - elif isinstance(metric, QAAccuracyMetric): - metric.update( - prompts, - [{"questions": {"1": "Is there a cat?"}}], - images, - ) - elif isinstance(metric, (TextScoreMetric, OneIGTextScoreMetric)): - metric.update(prompts, ["cat"], images) - else: - metric.update(prompts, images, images) - - -@pytest.mark.cpu -@pytest.mark.slow -@pytest.mark.parametrize("metric_cls", _SLOW_SMOL_SUBSET) -def test_vlm_metrics_transformers_smolvlm(metric_cls: type) -> None: - """Smoke-test a subset with local SmolVLM (full matrix covered by litellm mock).""" - metric = metric_cls( - vlm_type="transformers", - model_name=SMOL_VLM, - device="cpu", - structured_output=True, - ) - images = _dummy_image(batch=1) - prompts = ["a cat"] - _update_metric(metric, prompts, images) - result = metric.compute() - assert result.name == metric.metric_name - assert isinstance(result.result, float) - if metric.higher_is_better: - assert 0.0 <= result.result <= 1.0 - else: - assert result.result >= 0.0 - - -@pytest.mark.cpu -@pytest.mark.parametrize("metric_cls", _ALL_VLM) -def test_vlm_metrics_litellm_mocked(metric_cls: type) -> None: - """Each VLM metric runs end-to-end with mocked litellm.""" - pytest.importorskip("litellm") - mock_response = MagicMock() - mock_response.choices = [MagicMock()] - if metric_cls in (VQAMetric, QAAccuracyMetric, OneIGAlignmentMetric): - mock_response.choices[0].message.content = '{"answer": "Yes"}' - else: - mock_response.choices[0].message.content = '{"score": 8}' - - with patch("litellm.completion") as mock_completion: - mock_completion.return_value = mock_response - - metric = metric_cls( - vlm_type="litellm", - model_name="gpt-4o", - device="cpu", - structured_output=True, - ) - images = _dummy_image(batch=1) - prompts = ["a cat"] - _update_metric(metric, prompts, images) - result = metric.compute() - - assert result.name == metric.metric_name - assert isinstance(result.result, float) - assert mock_completion.called - - -@pytest.mark.cpu -def test_vlm_metrics_empty_compute_returns_zero() -> None: - """No updates → compute is 0.0 (same for all stateful VLM metrics).""" - metric = VQAMetric( - vlm_type="transformers", - model_name=SMOL_VLM, - device="cpu", - structured_output=True, - ) - assert metric.compute().result == 0.0 - - -@pytest.mark.cpu -def test_vlm_metrics_custom_vlm() -> None: - """Custom VLM passed to VQAMetric is used instead of the default litellm backend.""" - mock_vlm = MagicMock(spec=BaseVLM) - mock_vlm.generate.return_value = ["Yes"] - mock_vlm.score.return_value = [1.0] - - metric = VQAMetric(vlm=mock_vlm, vlm_type="litellm", device="cpu", structured_output=True) - images = _dummy_image(batch=1) - prompts = ["a cat"] - metric.update(prompts, images, images) - assert metric.compute().result == 1.0 - mock_vlm.score.assert_called() - - @pytest.mark.cpu def test_get_vlm_returns_custom() -> None: """get_vlm returns the provided VLM instance unchanged.""" @@ -200,286 +53,15 @@ def test_get_vlm_requires_model_name_without_vlm() -> None: get_vlm(vlm=None, vlm_type="litellm") -@pytest.mark.cpu -@pytest.mark.parametrize( - "metric_cls, expected_name, expected_result", - [ - (TextScoreMetric, "text_score", 1.0), - (OneIGTextScoreMetric, "oneig_text_score", 1.0), - ], -) -def test_text_metrics_list_str_gt(metric_cls: type, expected_name: str, expected_result: float) -> None: - """Text metrics accept plain string ground-truth and return the expected score.""" - mock_vlm = MagicMock(spec=BaseVLM) - mock_vlm.generate.return_value = ["hello world"] - - metric = metric_cls(vlm=mock_vlm, vlm_type="litellm", device="cpu") - images = _dummy_image(batch=1) - metric.update(["a prompt"], ["hello world"], images) - result = metric.compute() - - assert result.result == expected_result - assert result.name == expected_name - mock_vlm.generate.assert_called_once() - - -@pytest.mark.cpu -def test_text_score_result_in_zero_one_range() -> None: - """TextScoreMetric must return a normalized score in [0, 1], not raw edit distance.""" - mock_vlm = MagicMock(spec=BaseVLM) - # VLM OCR returns something very different from ground truth (high edit distance) - mock_vlm.generate.return_value = ["completely wrong text abcdefghijklmnop"] - - metric = TextScoreMetric(vlm=mock_vlm, device="cpu") - images = _dummy_image(batch=1) - metric.update(["prompt"], ["hello"], images) - result = metric.compute() - - assert 0.0 <= result.result <= 1.0, f"TextScoreMetric must return [0,1], got {result.result}" - assert result.result < 0.5, f"Very different strings should score below 0.5, got {result.result}" - - -@pytest.mark.cpu -def test_text_score_perfect_match_is_one() -> None: - """TextScoreMetric: identical OCR and GT -> score 1.0.""" - mock_vlm = MagicMock(spec=BaseVLM) - mock_vlm.generate.return_value = ["hello world"] - - metric = TextScoreMetric(vlm=mock_vlm, device="cpu") - images = _dummy_image(batch=1) - metric.update(["prompt"], ["hello world"], images) - result = metric.compute() - - assert result.result == 1.0, f"Perfect match should give 1.0, got {result.result}" - assert result.higher_is_better is True - - -@pytest.mark.cpu -def test_text_score_registry_aliases() -> None: - """Registry aliases ocr_levenshtein and ocr_text_score resolve to the correct metric classes.""" - from pruna.evaluation.metrics.registry import MetricRegistry - - lev = MetricRegistry.get_metric("ocr_levenshtein", device="cpu", model_name="openai/gpt-4o") - comp = MetricRegistry.get_metric("ocr_text_score", device="cpu", model_name="openai/gpt-4o") - assert type(lev).__name__ == "TextScoreMetric" - assert type(comp).__name__ == "OneIGTextScoreMetric" - assert lev.metric_name == "text_score" - assert comp.metric_name == "oneig_text_score" - - -@pytest.mark.cpu -def test_oneig_text_score_utils_golden_composite() -> None: - """oneig_mean_text_score returns expected component values for a known input.""" - from pruna.evaluation.metrics.metric_text_score_utils import oneig_mean_text_score - - ed, cr, wac, composite = oneig_mean_text_score( - edit_distances=[10.0], - completion_ratios=[0.0], - match_counts=[2], - gt_totals=[4], - language_mode="EN", - ) - assert ed == 10.0 - assert cr == 0.0 - assert wac == 0.5 - assert composite == pytest.approx(0.95) - - _, _, _, zh = oneig_mean_text_score( - edit_distances=[30.0], - completion_ratios=[0.0], - match_counts=[0], - gt_totals=[1], - language_mode="ZH", - ) - assert zh == pytest.approx(0.4) - - -@pytest.mark.cpu -def test_qa_accuracy_all_or_nothing_partial_fail() -> None: - """all_or_nothing: if any question scores 0, the image score is 0.0 (not a partial mean).""" - mock_vlm = MagicMock(spec=BaseVLM) - # First question Yes (1.0), second question No (0.0) → mean=0.5, all_or_nothing=0.0 - mock_vlm.score.return_value = [1.0, 0.0] - - metric = QAAccuracyMetric(vlm=mock_vlm, device="cpu", aggregation="all_or_nothing") - metric.update( - ["a prompt"], - [{"questions": {"1": "Is there a cat?", "2": "Is it blue?"}}], - _dummy_image(batch=1), - ) - result = metric.compute() - assert result.result == 0.0, f"Expected 0.0 for all_or_nothing with one No, got {result.result}" - - -@pytest.mark.cpu -def test_qa_accuracy_all_or_nothing_all_yes() -> None: - """all_or_nothing: all Yes → score 1.0.""" - mock_vlm = MagicMock(spec=BaseVLM) - mock_vlm.score.return_value = [1.0, 1.0] - - metric = QAAccuracyMetric(vlm=mock_vlm, device="cpu", aggregation="all_or_nothing") - metric.update( - ["a prompt"], - [{"questions": {"1": "Is there a cat?", "2": "Is it blue?"}}], - _dummy_image(batch=1), - ) - result = metric.compute() - assert result.result == 1.0, f"Expected 1.0 for all_or_nothing with all Yes, got {result.result}" - - -@pytest.mark.cpu -def test_qa_accuracy_invalid_aggregation_raises() -> None: - """qa_accuracy rejects aggregation values other than mean / all_or_nothing.""" - mock_vlm = MagicMock(spec=BaseVLM) - with pytest.raises(ValueError, match="aggregation"): - QAAccuracyMetric(vlm=mock_vlm, device="cpu", aggregation="median") - - -@pytest.mark.cpu -def test_vie_score_tie_uses_source_from_gt_and_two_image_sc() -> None: - """With ``source_image_bytes`` in gt, VieScore calls two-image SC then PQ on the edited image.""" - from io import BytesIO - - from PIL import Image - - buf = BytesIO() - Image.new("RGB", (8, 8), color=(0, 0, 200)).save(buf, format="PNG") - src_bytes = buf.getvalue() - - mock_vlm = MagicMock() - mock_vlm.generate_with_image_lists.return_value = ['{"score": [8.0, 8.0], "reasoning": "ok"}'] - mock_vlm.generate.return_value = ['{"score": [9.0, 9.0], "reasoning": "ok"}'] - - metric = VieScoreMetric(vlm=mock_vlm, device="cpu", structured_output=True) - pred = _dummy_image(batch=1) - metric.update( - ["make the sky purple"], - [{"source_image_bytes": src_bytes}], - pred, - ) - result = metric.compute() - - assert mock_vlm.generate_with_image_lists.called - assert mock_vlm.generate.called - assert 0.0 <= result.result <= 1.0 - - -@pytest.mark.cpu -def test_vie_score_uses_get_score_from_response() -> None: - """VieScoreMetric ``t2i`` path parses JSON ``score`` lists via ``viescore_min_scores_0_10``.""" - mock_vlm = MagicMock(spec=BaseVLM) - # LitellmVLM returns model_dump_json() for structured outputs → JSON string (two SC + two PQ sub-scores) - mock_vlm.generate.return_value = ['{"score": [8.0, 8.0], "reasoning": ""}'] - - metric = VieScoreMetric(vlm=mock_vlm, device="cpu", structured_output=True) - metric.update(["a cat on a sofa"], _dummy_image(batch=1), _dummy_image(batch=1)) - result = metric.compute() - - # min(SC)=8, min(PQ)=8 → sqrt(8 * 8) / 10 = 0.8 - assert abs(result.result - 0.8) < 0.01, f"Expected ~0.8, got {result.result}" - - -@pytest.mark.cpu -def test_img_edit_score_negative_response_clamped() -> None: - """img_edit_score must be non-negative even when the VLM generates a negative JSON score. - - Regression for: Outlines constrained decoding can emit {"score": -10} despite the - FloatOutput JSON schema specifying minimum=0, because Outlines does not enforce numeric - bounds during token sampling. The fix is max(0.0, ...) in get_score_from_response. - """ - mock_vlm = MagicMock(spec=BaseVLM) - # Simulate Outlines generating a negative value (the bug scenario) - mock_vlm.generate.return_value = ['{"score": -10.0}'] - - metric = ImageEditScoreMetric(vlm=mock_vlm, device="cpu", structured_output=True) - metric.update(["replace the boot with a mug"], torch.zeros(1), _dummy_image(batch=1)) - result = metric.compute() - - assert result.result >= 0.0, f"img_edit_score must be >= 0, got {result.result}" - - -@pytest.mark.cpu -def test_qa_accuracy_all_or_nothing_ambiguous_score() -> None: - """all_or_nothing: score exactly 0.5 (ambiguous) is treated as No → result 0.0.""" - mock_vlm = MagicMock(spec=BaseVLM) - mock_vlm.score.return_value = [0.5] - - metric = QAAccuracyMetric(vlm=mock_vlm, device="cpu", aggregation="all_or_nothing") - metric.update( - ["a prompt"], - [{"questions": {"1": "Is there a cat?"}}], - _dummy_image(batch=1), - ) - result = metric.compute() - assert result.result == 0.0, f"Score 0.5 should be treated as No (ambiguous), got {result.result}" - - -@pytest.mark.cpu -@pytest.mark.slow -def test_yes_no_token_ids_smolvlm_nonempty() -> None: - """SmolVLM tokenizer must yield non-empty disjoint yes/no prefix ids for VQAScore scoring.""" - pytest.importorskip("transformers") - from transformers import AutoTokenizer - - tok = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolVLM-256M-Instruct") - yes_ids, no_ids = yes_no_first_token_id_groups(tok) - assert len(yes_ids) > 0, "SmolVLM tokenizer has no 'Yes'-prefix token ids" - assert len(no_ids) > 0, "SmolVLM tokenizer has no 'No'-prefix token ids" - assert not (set(yes_ids) & set(no_ids)), "yes_ids and no_ids must be disjoint" - - -@pytest.mark.cpu -def test_img_edit_score_uses_prompt_from_x() -> None: - """img_edit_score must score the edited image against the instruction from x, not gt.""" - mock_vlm = MagicMock(spec=BaseVLM) - mock_vlm.generate.return_value = ['{"score": 9}'] - - metric = ImageEditScoreMetric(vlm=mock_vlm, device="cpu") - pred = _dummy_image(batch=1) - metric.update( - ["replace the cat with a dog"], # x = instruction - pred, # gt = unused for y_x - pred, # outputs = edited image - ) - result = metric.compute() - - call_args = mock_vlm.generate.call_args - prompt_sent = call_args[0][1][0] # second positional arg = prompts list, first item - assert "replace the cat with a dog" in prompt_sent, f"Instruction not in VLM prompt. Got: {prompt_sent}" - assert abs(result.result - 0.9) < 0.01, f"Expected ~0.9, got {result.result}" - - -@pytest.mark.cpu -def test_vie_score_geditbench_gap_documented() -> None: - """VieScoreMetric infers text--image editing from ``source_image_bytes`` in aux (no ``task_type``). - - This test fails if a ``task_type`` parameter is added to ``__init__`` without updating - GEditBench integration tests and benchmark copy accordingly. - """ - import inspect - - sig = inspect.signature(VieScoreMetric.__init__) - assert "task_type" not in sig.parameters, ( - "VieScoreMetric now has task_type — update GEditBench docs and e2e tests, then remove this sentinel." - ) - - @pytest.mark.cpu def test_litellm_logprob_aggregation_sums_all_yes_tokens() -> None: """LitellmVLM logprob scoring must sum all yes-prefix token probs, not return the first.""" pytest.importorskip("litellm") import math - from unittest.mock import MagicMock, patch import numpy as np from PIL import Image - from pruna.evaluation.metrics.vlm_base import LitellmVLM - - # Simulate top_logprobs for first output token: - # "Yes" → logprob=-2.303 (p≈0.10), " yes" → logprob=-2.996 (p≈0.05) → total p_yes≈0.15 - # "No" → logprob=-1.609 (p≈0.20), " no" → logprob=-2.303 (p≈0.10) → total p_no≈0.30 - # normalized: p_yes/(p_yes+p_no) ≈ 0.15/0.45 ≈ 0.333 def make_top_logprob(token, logprob): t = MagicMock() t.token = token @@ -510,175 +92,17 @@ def make_top_logprob(token, logprob): img = Image.fromarray(np.zeros((32, 32, 3), dtype="uint8")) score = vlm._score_with_logprobs(img, "Is there a cat?", "Yes") - # Should be ~0.333 (p_yes=0.15 / (p_yes+p_no)=0.45), not just 0.10 (first match) assert 0.28 < score < 0.40, f"Expected ~0.333 (sum-normalized), got {score}" @pytest.mark.cpu @pytest.mark.slow -def test_vqa_probability_score_normalized() -> None: - """P(Yes) from TransformersVLM.score use_probability=True is in [0, 1].""" +def test_yes_no_token_ids_smolvlm_nonempty() -> None: + """SmolVLM tokenizer yields non-empty yes/no prefix id groups.""" pytest.importorskip("transformers") - import numpy as np - from PIL import Image - - from pruna.evaluation.metrics.vlm_base import TransformersVLM - - vlm = TransformersVLM( - model_name="HuggingFaceTB/SmolVLM-256M-Instruct", - device="cpu", - use_outlines=False, - ) - img = Image.fromarray(np.zeros((32, 32, 3), dtype="uint8")) - scores = vlm.score([img], ["Is there a cat?"], ["Yes"], use_probability=True) - assert len(scores) == 1 - assert 0.0 <= scores[0] <= 1.0, f"P(Yes) must be in [0, 1], got {scores[0]}" - - -# --------------------------------------------------------------------------- -# vlm_benchmark_batch_to_json_record serialization tests -# --------------------------------------------------------------------------- - - -def test_vlm_benchmark_batch_to_json_record_serializes_batch() -> None: - """Record includes prompts, pred shape, and metric fields.""" - mr = MetricResult(name="qa_accuracy", params={}, result=0.25, higher_is_better=True) - outcome = BenchmarkVlmBatchOutcome( - result=mr, - prompts=["prompt"], - auxiliaries=[{"path": "/tmp/x.png"}], - pred=torch.zeros(1, 3, 8, 8), - ) - rec = vlm_benchmark_batch_to_json_record( - outcome, - benchmark_key="GenEval", - benchmark_name="GenEval", - metric_name="qa_accuracy", - vlm_type="transformers", - model_name="m", - device="cpu", - ) - assert rec["inputs"]["prompts"] == ["prompt"] - assert rec["pred"]["shape"] == [1, 3, 8, 8] - assert rec["metric_result"]["result"] == 0.25 - - -def test_safe_json_handles_bytes_without_expanding() -> None: - """Bytes values in aux (e.g. source_image_bytes) are summarized, not expanded to str repr.""" - result = safe_json_for_snapshot({"source_image_bytes": b"\xff\xd8\xff" * 1000, "name": "test"}) - assert result["source_image_bytes"] == {"bytes_len": 3000} - assert result["name"] == "test" - - -def test_vlm_benchmark_batch_to_json_record_preserves_null_question_slots() -> None: - """Padded ``None`` question labels stay JSON null, not the string ``"None"``.""" - mr = MetricResult(name="oneig_alignment", params={}, result=1.0, higher_is_better=True) - outcome = BenchmarkVlmBatchOutcome( - result=mr, - prompts=["p"], - auxiliaries=[{"questions": {"1": "Are there boys?", "21": None}, "subset": "Anime_Stylization"}], - pred=torch.zeros(1, 3, 8, 8), - ) - rec = vlm_benchmark_batch_to_json_record( - outcome, - benchmark_key="OneIGAnimeStylization", - benchmark_name="OneIG Anime Stylization", - metric_name="oneig_alignment", - vlm_type="transformers", - model_name="m", - device="cpu", - ) - qs = rec["inputs"]["auxiliary_0"]["questions"] - assert qs["1"] == "Are there boys?" - assert qs["21"] is None - - -# --------------------------------------------------------------------------- -# pred_tensor_from_auxiliaries (test helper, wraps pil_rgb_from_aux_image_bytes) tests -# --------------------------------------------------------------------------- - - -def _make_jpeg_bytes(h: int = 32, w: int = 32) -> bytes: - """Return a tiny JPEG-encoded RGB image as bytes (test helper).""" - import io - - import numpy as np - from PIL import Image - - arr = (np.random.rand(h, w, 3) * 255).astype("uint8") - buf = io.BytesIO() - Image.fromarray(arr).save(buf, format="JPEG") - return buf.getvalue() - - -@pytest.mark.cpu -def test_pred_from_auxiliaries_uses_source_image_bytes() -> None: - """pred_tensor_from_auxiliaries decodes source_image_bytes into a float tensor in [0, 1].""" - src_bytes = _make_jpeg_bytes() - aux = [{"source_image_bytes": src_bytes, "category": "background_change"}] - pred = pred_tensor_from_auxiliaries(aux, size=64) - - assert pred.shape == (1, 3, 64, 64), f"Expected (1,3,64,64), got {pred.shape}" - assert pred.min() >= 0.0 and pred.max() <= 1.0, "Pixel values must be in [0, 1]" - - -@pytest.mark.cpu -def test_pred_from_auxiliaries_falls_back_to_noise_without_source_image() -> None: - """pred_tensor_from_auxiliaries returns random noise when no source_image_bytes is present.""" - aux = [{"category": "single_object"}] - pred = pred_tensor_from_auxiliaries(aux, size=32) - assert pred.shape == (1, 3, 32, 32) - assert pred.min() >= 0.0 and pred.max() <= 1.0 - - -@pytest.mark.cpu -def test_pred_from_auxiliaries_mixed_batch() -> None: - """Batch with one source image and one missing falls back per-item.""" - src_bytes = _make_jpeg_bytes() - aux = [ - {"source_image_bytes": src_bytes, "category": "color_alter"}, - {"category": "style_change"}, # no source image - ] - pred = pred_tensor_from_auxiliaries(aux, size=32) - assert pred.shape == (2, 3, 32, 32) - assert pred.min() >= 0.0 and pred.max() <= 1.0 - - -@pytest.mark.cpu -def test_pred_from_auxiliaries_generic_bytes_scan() -> None: - """pred_tensor_from_auxiliaries discovers image bytes under an unknown field name (generic scan).""" - src_bytes = _make_jpeg_bytes() - aux = [{"my_custom_image_bytes": src_bytes, "category": "motion_change"}] - pred = pred_tensor_from_auxiliaries(aux, size=32) - assert pred.shape == (1, 3, 32, 32) - assert pred.min() >= 0.0 and pred.max() <= 1.0 - - -@pytest.mark.cpu -def test_pred_from_auxiliaries_known_names_take_priority() -> None: - """Known field names are resolved before the generic bytes scan.""" - src_bytes_known = _make_jpeg_bytes(16, 16) - src_bytes_unknown = _make_jpeg_bytes(32, 32) - first_known = VLM_AUX_IMAGE_BYTES_KEY_ORDER[0] - aux = [{"other_bytes": src_bytes_unknown, first_known: src_bytes_known}] - pred = pred_tensor_from_auxiliaries(aux, size=16) - # Should use the known key (16x16 image → 16x16 crop); generic scan would pick 32x32 - assert pred.shape == (1, 3, 16, 16) - - -@pytest.mark.cpu -def test_pred_from_auxiliaries_require_source_image_raises_when_missing() -> None: - """require_source_image=True raises ValueError instead of silently returning noise.""" - aux = [{"category": "replace"}] # no image bytes - with pytest.raises(ValueError, match="require_source_image=True"): - pred_tensor_from_auxiliaries(aux, size=32, require_source_image=True) - + from transformers import AutoTokenizer -@pytest.mark.cpu -def test_pred_from_auxiliaries_require_source_image_succeeds_when_present() -> None: - """require_source_image=True succeeds and decodes bytes when source_image_bytes is present.""" - src_bytes = _make_jpeg_bytes() - aux = [{"source_image_bytes": src_bytes, "category": "replace"}] - pred = pred_tensor_from_auxiliaries(aux, size=32, require_source_image=True) - assert pred.shape == (1, 3, 32, 32) - assert pred.min() >= 0.0 and pred.max() <= 1.0 + tok = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolVLM-256M-Instruct") + yes_ids, no_ids = yes_no_first_token_id_groups(tok) + assert yes_ids + assert no_ids From de1cc64f0a9a28461668e86642f8492a9701349b Mon Sep 17 00:00:00 2001 From: David Berenstein Date: Thu, 4 Jun 2026 09:55:42 +0200 Subject: [PATCH 6/7] fix(ci): ruff on infra VLM test template Co-authored-by: Cursor --- scripts/test_vlm_base_infrastructure_infra.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/scripts/test_vlm_base_infrastructure_infra.py b/scripts/test_vlm_base_infrastructure_infra.py index b6ac9b1c..524cd49b 100644 --- a/scripts/test_vlm_base_infrastructure_infra.py +++ b/scripts/test_vlm_base_infrastructure_infra.py @@ -1,9 +1,22 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Tests for VLM base classes and vlm_utils (infrastructure PR only).""" from unittest.mock import MagicMock, patch import pytest -import torch from pruna.evaluation.metrics.vlm_base import BaseVLM, LitellmVLM, get_vlm from pruna.evaluation.metrics.vlm_utils import FloatOutput, get_score_from_response, yes_no_first_token_id_groups From 2dc8924e8ea04a0c15168157a468f15ba630d2c6 Mon Sep 17 00:00:00 2001 From: David Berenstein Date: Thu, 4 Jun 2026 10:18:07 +0200 Subject: [PATCH 7/7] chore: drop local-only scripts from PR scope Remove verify helper and duplicate infra test template from scripts/; tests live under tests/evaluation/ only. Co-authored-by: Cursor --- scripts/test_vlm_base_infrastructure_infra.py | 121 ----------------- scripts/verify-vlm-stack-ci.sh | 125 ------------------ 2 files changed, 246 deletions(-) delete mode 100644 scripts/test_vlm_base_infrastructure_infra.py delete mode 100755 scripts/verify-vlm-stack-ci.sh diff --git a/scripts/test_vlm_base_infrastructure_infra.py b/scripts/test_vlm_base_infrastructure_infra.py deleted file mode 100644 index 524cd49b..00000000 --- a/scripts/test_vlm_base_infrastructure_infra.py +++ /dev/null @@ -1,121 +0,0 @@ -# Copyright 2025 - Pruna AI GmbH. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for VLM base classes and vlm_utils (infrastructure PR only).""" - -from unittest.mock import MagicMock, patch - -import pytest - -from pruna.evaluation.metrics.vlm_base import BaseVLM, LitellmVLM, get_vlm -from pruna.evaluation.metrics.vlm_utils import FloatOutput, get_score_from_response, yes_no_first_token_id_groups - - -@pytest.mark.parametrize( - ("raw", "expected"), - [ - (FloatOutput(score=8.0), 0.8), - ({"score": 5.0}, 0.5), - ('{"score": 7.5}', 0.75), - ('{"score": 10}', 1.0), - ("8", 0.8), - ("Score: 7.5 out of 10", 0.75), - ("", 0.0), - ], -) -def test_get_score_from_response(raw: object, expected: float) -> None: - """``get_score_from_response`` maps pydantic, dict, JSON, and text to ``[0, 1]``.""" - assert get_score_from_response(raw) == pytest.approx(expected) - - -@pytest.mark.cpu -def test_get_vlm_returns_custom() -> None: - """get_vlm returns the provided VLM instance unchanged.""" - custom = MagicMock(spec=BaseVLM) - out = get_vlm(vlm=custom, vlm_type="litellm", model_name="gpt-4o") - assert out is custom - - -@pytest.mark.cpu -def test_yes_no_first_token_id_groups_disjoint() -> None: - """Prefix token ids for Yes vs No should not overlap (avoids double-counting).""" - pytest.importorskip("transformers") - from transformers import AutoTokenizer - - tok = AutoTokenizer.from_pretrained("gpt2") - yes_ids, no_ids = yes_no_first_token_id_groups(tok) - assert yes_ids and no_ids - assert not (set(yes_ids) & set(no_ids)) - - -@pytest.mark.cpu -def test_get_vlm_requires_model_name_without_vlm() -> None: - """get_vlm raises ValueError when no model_name is given and no vlm is provided.""" - with pytest.raises(ValueError, match="model_name"): - get_vlm(vlm=None, vlm_type="litellm") - - -@pytest.mark.cpu -def test_litellm_logprob_aggregation_sums_all_yes_tokens() -> None: - """LitellmVLM logprob scoring must sum all yes-prefix token probs, not return the first.""" - pytest.importorskip("litellm") - import math - - import numpy as np - from PIL import Image - - def make_top_logprob(token, logprob): - t = MagicMock() - t.token = token - t.logprob = logprob - return t - - first_tok = MagicMock() - first_tok.top_logprobs = [ - make_top_logprob("Yes", math.log(0.10)), - make_top_logprob(" yes", math.log(0.05)), - make_top_logprob("No", math.log(0.20)), - make_top_logprob(" no", math.log(0.10)), - make_top_logprob("maybe", math.log(0.55)), - ] - - mock_logprobs = MagicMock() - mock_logprobs.content = [first_tok] - - mock_choice = MagicMock() - mock_choice.logprobs = mock_logprobs - mock_choice.message.content = "Yes" - - mock_response = MagicMock() - mock_response.choices = [mock_choice] - - with patch("litellm.completion", return_value=mock_response): - vlm = LitellmVLM(model_name="openai/gpt-4o") - img = Image.fromarray(np.zeros((32, 32, 3), dtype="uint8")) - score = vlm._score_with_logprobs(img, "Is there a cat?", "Yes") - - assert 0.28 < score < 0.40, f"Expected ~0.333 (sum-normalized), got {score}" - - -@pytest.mark.cpu -@pytest.mark.slow -def test_yes_no_token_ids_smolvlm_nonempty() -> None: - """SmolVLM tokenizer yields non-empty yes/no prefix id groups.""" - pytest.importorskip("transformers") - from transformers import AutoTokenizer - - tok = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolVLM-256M-Instruct") - yes_ids, no_ids = yes_no_first_token_id_groups(tok) - assert yes_ids - assert no_ids diff --git a/scripts/verify-vlm-stack-ci.sh b/scripts/verify-vlm-stack-ci.sh deleted file mode 100755 index 9e997111..00000000 --- a/scripts/verify-vlm-stack-ci.sh +++ /dev/null @@ -1,125 +0,0 @@ -#!/usr/bin/env bash -# Run CI-equivalent lint + tests for each VLM stack branch before pushing. -# -# Usage: -# ./scripts/verify-vlm-stack-ci.sh -# ./scripts/verify-vlm-stack-ci.sh feat/vlm-pr-3b-oneig-alignment -# -# Lint: ruff + ty + docstring checks on evaluation.metrics (see tests.yaml). -# Tests: pytest with the same markers as CI base matrix (cpu, no_extras, …). - -set -euo pipefail - -ROOT="$(cd "$(dirname "$0")/.." && pwd)" -cd "$ROOT" - -MARK='cpu and not slow and not style and no_extras' -export PRUNA_CI_CPU_ONLY=1 - -INFRA_TEST_TEMPLATE="$ROOT/scripts/test_vlm_base_infrastructure_infra.py" - -branches=( - feat/vlm-pr-1-vendor - feat/vlm-pr-2-infrastructure - feat/vlm-pr-3a-qa-accuracy - feat/vlm-pr-3b-oneig-alignment - feat/vlm-pr-3c-text-score-pair - feat/vlm-pr-3d-oneig-reasoning - feat/vlm-pr-4a-vqa - feat/vlm-pr-4b-vie-score - feat/vlm-pr-4c-img-edit-score - feat/vlm-pr-5-e2e-tests -) - -if [[ $# -gt 0 ]]; then - branches=("$@") -fi - -# Branches before e2e use infrastructure-only VLM tests (no forward imports). -needs_infra_test_only() { - case "$1" in - feat/vlm-pr-2-infrastructure|feat/vlm-pr-3a-qa-accuracy|feat/vlm-pr-3b-oneig-alignment|feat/vlm-pr-3c-text-score-pair|feat/vlm-pr-3d-oneig-reasoning|feat/vlm-pr-4a-vqa|feat/vlm-pr-4b-vie-score|feat/vlm-pr-4c-img-edit-score) - return 0 - ;; - esac - return 1 -} - -tests_for_branch() { - local b=$1 - if needs_infra_test_only "$b" || [[ "$b" == "feat/vlm-pr-2-infrastructure" ]]; then - if [[ -f tests/evaluation/test_vlm_base_infrastructure.py ]]; then - echo "tests/evaluation/test_vlm_base_infrastructure.py" - fi - elif git cat-file -e "$b:tests/evaluation/test_vlm_base_infrastructure.py" 2>/dev/null; then - echo "tests/evaluation/test_vlm_base_infrastructure.py" - fi - if git cat-file -e "$b:tests/evaluation/test_text_metrics.py" 2>/dev/null; then - echo "tests/evaluation/test_text_metrics.py" - fi - if git cat-file -e "$b:tests/evaluation/_vlm_batch_snapshot_helpers.py" 2>/dev/null; then - echo "tests/evaluation/_vlm_batch_snapshot_helpers.py" - fi -} - -run_lint() { - echo "--- ruff (src/pruna) ---" - uv run ruff check src/pruna - - echo "--- ty (src/pruna) ---" - uv run ty check src/pruna - - echo "--- docstring style (evaluation.metrics) ---" - uv run pytest -m style -q tests/style/test_docstrings.py -k "evaluation.metrics" --maxfail=3 -} - -run_tests() { - local b=$1 - local -a paths=() - while IFS= read -r line; do - [[ -n "$line" ]] && paths+=("$line") - done < <(tests_for_branch "$b") - - if [[ ${#paths[@]} -eq 0 ]]; then - echo "(no VLM-specific tests on this branch; skipping pytest)" - return 0 - fi - - echo "--- pytest (${paths[*]}) ---" - uv run pytest -q --tb=line -m "$MARK" --maxfail=3 "${paths[@]}" -} - -orig_branch=$(git branch --show-current 2>/dev/null || echo main) -failed=() - -for b in "${branches[@]}"; do - echo "" - echo "========== $b ==========" - git checkout "$b" --quiet - - if needs_infra_test_only "$b" && [[ -f "$INFRA_TEST_TEMPLATE" ]]; then - cp "$INFRA_TEST_TEMPLATE" tests/evaluation/test_vlm_base_infrastructure.py - fi - - if ! uv sync --extra dev --quiet 2>/dev/null; then - uv sync --extra dev - fi - - if run_lint && run_tests "$b"; then - echo "PASS $b" - else - echo "FAIL $b" - failed+=("$b") - fi -done - -git checkout "$orig_branch" --quiet 2>/dev/null || true - -echo "" -if [[ ${#failed[@]} -eq 0 ]]; then - echo "All branches passed lint and tests." - exit 0 -fi -echo "Failed:" -printf ' - %s\n' "${failed[@]}" -exit 1