Skip to content

Vision-Language Model Imputer Module#543

Open
Gength wants to merge 20 commits into
mmschlk:mainfrom
yhuang122:main
Open

Vision-Language Model Imputer Module#543
Gength wants to merge 20 commits into
mmschlk:mainfrom
yhuang122:main

Conversation

@Gength

@Gength Gength commented Jun 10, 2026

Copy link
Copy Markdown

Description

This PR introduces a new shapiq.imputer.vision sub-package that extends shapiq's imputer framework with pluggable segmenters and maskers for vision-language model explanation.

Keywords: CLIP, SigLIP, VLM explanation, image occlusion, patch segmentation, SLIC superpixels

Motivation

shapiq currently supports tabular data imputation (MarginalImputer, BaselineImputer, GaussianImputer, etc.) and nearest-neighbour explainer games (NNExplainerGameBase). However, explaining vision-language models (VLMs) such as CLIP and SigLIP requires a different paradigm:

  • Tabular imputation replaces missing features with sampled/imputed values.
  • VLM explanation requires occluding spatial regions (image patches or superpixels) and text tokens, then measuring the change in model similarity.

This PR adds a modular vision pipeline that follows shapiq's existing Game contract while introducing two new abstractions — Segmenter and Masker — that can be mixed and matched.

Solution

A new sub-package shapiq/imputer/vision/ with:

Component Description Pluggable?
Segmenter (ABC) Defines which pixels/tokens belong to each player
Masker (ABC) Defines how occlusion is applied to model inputs
VisionImputer Orchestrates Segmenter → Masker → Model forward pass
VisionImputerFactory Auto-detects model type, assembles components
VisionLanguageGame Thin shapiq.Game adapter for approximators
PatchSegmenter Rigid grid aligned with ViT patch embeddings
SLICSegmenter Perceptual superpixels (skimage SLIC) for CNNs
CustomSegmenter User-provided binary masks as players
VisionMeanMasker Multiplicative zero-out (mean fill)
VisionBlurMasker Gaussian blur occlusion (CPU skimage)
TextAttentionMasker Attention mask replacement for text tokens
CrossModalMeanMasker Composite: VisionMean + TextAttention
CrossModalBlurMasker Composite: VisionBlur + TextAttention

Related Work

This follows the same architectural pattern as shapiq.explainer.nn.games.NNExplainerGameBase(Game) — a domain-specific Game subclass that does not inherit from Imputer. Like the NN explainer games, VisionLanguageGame(Game) can be used directly with any shapiq approximator.


Changes

New files

shapiq/imputer/vision/
├── __init__.py                  # Public API exports
├── base.py                      # Data types: SpatialLayout, PhysicalMask, ProcessorOutput
├── imputer.py                   # VisionImputer (orchestration)
├── factory.py                   # VisionImputerFactory (assembly)
├── game.py                      # VisionLanguageGame (Game adapter)
├── segmenters/
│   ├── base.py                  # Segmenter(ABC) + SegmenterConfig + per-strategy params
│   ├── __init__.py              # Registry
│   ├── patch.py                 # PatchSegmenter (ViT grid)
│   ├── slic.py                  # SLICSegmenter (superpixels)
│   └── custom.py                # CustomSegmenter (user-provided masks)
└── maskers/
    ├── base.py                  # Masker(ABC) + MaskerConfig + per-strategy params
    ├── __init__.py              # Registry
    ├── vision_mean.py           # VisionMeanMasker (zero-out)
    ├── vision_blur.py           # VisionBlurMasker (Gaussian blur)
    ├── text_attention.py        # TextAttentionMasker (attention swap)
    ├── crossmodal_mean.py       # CrossModalMeanMasker (composite)
    └── crossmodal_blur.py       # CrossModalBlurMasker (composite)

src/shapiq/explainer/vision.py   # VisionExplainer (Explainer integration)

VisionExplainer

New explainer integration at src/shapiq/explainer/vision.py:

  • VisionExplainer(Explainer) — wraps VisionImputerFactoryVisionLanguageGame → approximator
  • Auto-dispatch: shapiq.Explainer(model, data=image, text=..., processor=...) automatically routes HF VLMs to VisionExplainer
  • Registered in ExplainerTypes ("vision") and get_explainers()

Tests

tests/shapiq/tests_unit/tests_imputer/test_vision_imputer.py
tests/shapiq/tests_unit/tests_imputer/test_vision_explainer.py

99 unit tests covering:

Category Tests Coverage
Data types 10 Config, Layout, Mask, ProcessorOutput
PatchSegmenter 7 Layout, masks, CLIP BOS/EOS, all/zero occlusion
SLICSegmenter 3 Requires image, real image layout, mask generation
VisionMeanMasker 4 Apply, zero-out, pass-through, preserves text
TextAttentionMasker 2 Swap mask, preserves image
CrossModalMeanMasker 1 Composite apply
CustomSegmenter 5 Requires masks, invalid dims, layout, mask union, text pass-through
VisionBlurMasker 4 Blur apply, sigma configurable, pass-through, text preservation
CrossModalBlurMasker 2 Composite blur+text, registry key
Registry 7 Get/unknown for both registries
ExactComputer correctness 5 SV order-1, k-SII order-2, all-ones/zeros, exact SV values
Player×Masker matrix 10 forward_1d shape (4) + InteractionValues (4) + blur combos (2)
VisionExplainer 4 SV, k-SII order-2, baseline, game property
Explainer auto-dispatch 1 HF VLM → VisionExplainer routing
Text mask formats 3 CLIP, SigLIP, SigLIP2 padding
VisionImputer forward 6 1D multi/single batch, crossmodal, properties
VisionLanguageGame 5 Init, normalization, delegation, properties
Factory metadata 5 Model type, ViT/CNN dims, text count
Factory build 5 Default, AMP, forward, empty coalition, SLIC build

Dependencies

Package Status Notes
torch Already in shapiq (optional) Used by shapiq.explainer.nn
transformers New dependency HuggingFace model loading
scikit-image New dependency SLIC superpixels, Gaussian blur
Pillow Already in shapiq Image loading

Usage

from shapiq.imputer.vision import VisionImputerFactory, VisionLanguageGame
from transformers import CLIPProcessor, CLIPModel
from PIL import Image

# 1. Load model & data
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
image = Image.open("dog.jpg").convert("RGB")
text = "a black dog next to a yellow hydrant"

# 2. Build imputer (default: PatchSegmenter + CrossModalMeanMasker)
factory = VisionImputerFactory()
imputer = factory.build(model, processor, image, text)

# 3. Wrap as shapiq Game
game = VisionLanguageGame(imputer, batch_size=64)

# 4. Use with any shapiq approximator
from shapiq import KernelSHAPIQ
approximator = KernelSHAPIQ(n=game.n_players, max_order=2, index="k-SII")
interaction_values = approximator(budget=2**12, game=game)
# Custom segmenter (SLIC superpixels for CNN backbones)
from shapiq.imputer.vision import SegmenterConfig, SlicParams
seg_cfg = SegmenterConfig(
    strategy="slic",
    slic=SlicParams(n_segments=60, compactness=10.0),
)
imputer = factory.build(model, processor, image, text, segmenter_config=seg_cfg)
# Custom masker (Gaussian blur instead of zero-out)
from shapiq.imputer.vision import MaskerConfig
msk_cfg = MaskerConfig(strategy="crossmodal_blur")
imputer = factory.build(model, processor, image, text, masker_config=msk_cfg)
# Custom segmenter: user-provided binary masks
import numpy as np
from shapiq.imputer.vision import SegmenterConfig

# Create 3 player masks on a 224×224 image
masks = np.zeros((3, 224, 224), dtype=bool)
masks[0, :75, :] = True    # top region
masks[1, 75:150, :] = True  # middle region
masks[2, 150:, :] = True    # bottom region

seg_cfg = SegmenterConfig(strategy="custom_segmenter")
imputer = factory.build(
    model, processor, image, text,
    segmenter_config=seg_cfg,
    masks=masks,  # passed to CustomSegmenter via segmenter_kwargs
)
# VisionExplainer (auto-dispatch)
from shapiq import Explainer

explainer = Explainer(
    model=clip_model,
    data=image,
    text="a black dog next to a yellow hydrant",
    processor=processor,
)
iv = explainer.explain(budget=2048)
# Implementing a custom segmenter
from shapiq.imputer.vision import Segmenter, SegmenterConfig, SpatialLayout, PhysicalMask

class MySegmenter(Segmenter):
    def __init__(self, config: SegmenterConfig):
        super().__init__(config)
        self._layout = SpatialLayout(...)
    def get_layout(self) -> SpatialLayout:
        return self._layout
    def generate_masks(self, coalitions_image, coalitions_text, device=None) -> PhysicalMask:
        ...

Testing

# From the shapiq repository root
cd shapiq
pip install -e .[dev]

# Run only vision imputer tests
pytest tests/shapiq/tests_unit/tests_imputer/test_vision_imputer.py -v

# Run all imputer tests (including existing tabular imputers)
pytest tests/shapiq/tests_unit/tests_imputer/ -v

Expected output: 99 passed for the vision + explainer test suite.


Design Rationale

Why a new sub-package instead of extending Imputer?

The existing Imputer base class has a tabular-specific constructor signature (model, data, x, sample_size, categorical_features). A VLM imputer needs a HuggingFace model + processor, PIL.Image, and pluggable Segmenter/Masker components. Forcing these into the existing Imputer hierarchy would require breaking backward compatibility and would create a semantic mismatch between "imputing missing values" (tabular) and "occluding features" (vision).

Following the precedent of NNExplainerGameBase(Game) under explainer/nn/games/, the vision module introduces its own Game subclass that delegates all masking/batching logic to a VisionImputer orchestration layer.

CPU Planning, GPU Execution

Segmenters compute pixel-to-player mappings on CPU (e.g., skimage.segmentation.slic) and upload the result to GPU once. Mask application runs entirely on GPU via tensor ops. This keeps the expensive per-coalition work on GPU while respecting skimage's CPU-only API.

Open for extension

  • New segmenters: implement Segmenter.get_layout() + generate_masks(), decorate with @register_segmenter("name").
  • New maskers: implement Masker.apply(), decorate with @register_masker("name").
  • New backbones: the factory auto-detects model type from model.name_or_path; new model families require adding a detection branch in _infer_model_type.

Checklist

  • New code follows the project's coding style
  • Tests added for all new functionality
  • Existing tests continue to pass (137 passed, 3 skipped for all imputer tests; skipped tests are tabpfn optional dependency, unrelated to vision module)
  • Example notebook (docs/examples/vision_language_clip.ipynb) added
  • All public API has docstrings
  • __all__ defined in __init__.py

yhuang122 and others added 6 commits April 25, 2026 08:56
… tests

- Add CustomSegmenter — user-provided binary masks as players
  (registered as "custom_segmenter", file: segmenters/custom.py)
- Add VisionExplainer(Explainer) with auto-dispatch for HF VLMs
  (file: explainer/vision.py, registered in ExplainerTypes)
- Add ExactComputer correctness test with manually verified SV values
- Add Player×Masker matrix test (4 combos × forward_1d + InteractionValues)
- Add blur masker combos for both patch and custom segmenters
- Add VisionExplainer integration + auto-dispatch tests

99 tests passing (78 existing + 21 new)
Split the former vision/base.py into three files by responsibility:

- vision/base.py         — data transfer protocol only
                          (SpatialLayout, PhysicalMask, ProcessorOutput)
- vision/segmenters/base.py — Segmenter(ABC), SegmenterConfig,
                          per-strategy param dataclasses (PatchParams,
                          SlicParams, GradientGuidedParams, CustomSegmenterParams)
- vision/maskers/base.py    — Masker(ABC), MaskerConfig,
                          per-strategy param dataclasses (CrossModalMeanParams,
                          CrossModalBlurParams, VisionMeanParams,
                          VisionBlurParams, TextAttentionParams)

All cross-references within the vision package updated to relative imports.
99 tests pass.
clean unrelated files

restore to origin

restore to origin
@mmschlk

mmschlk commented Jun 15, 2026

Copy link
Copy Markdown
Owner

Please ping me here as soon as you want a PR review. Note however, that for a review, the CI pipeline needs to be green first (all tests pass, the code-quality checks are okay, and the docs building pipeline compiles). :)

Gength and others added 3 commits June 15, 2026 15:16
Fix all ruff lint and ty type-checker errors across the vision module
to satisfy pre-commit CI requirements (ruff: 203→0, ty: 0 errors).

Changes:
- Add missing docstrings (D102/D107/D205/D417) to all public methods
- Add missing type annotations (ANN*) and replace Any with concrete types
- Convert relative imports to absolute (TID252)
- Move type-only imports to TYPE_CHECKING guards (TC001/TC002)
- Replace exception string literals with msg variable (EM101/EM102/TRY003)
- Fix dict() -> {} literals (C408), redundant assignments (RET504)
- Fix ambiguous x characters in docstrings/tests (RUF001/RUF002)
- Make bool params keyword-only (FBT001/FBT002)
- Add __all__ and noqa for registry __init__.py pattern (E402/F401)
- Fix ty errors: kwargs type annotations, import guard patterns
- Fix test files: replace x with x, unused variable
- Fix notebook: replace x with x, dict() -> {}, add per-file-ignores
The HuggingFace VLM check (hasattr(model, 'vision_model')) was
overwriting already-matched model types because Mock() returns
True for all hasattr checks. Added _model_type == 'tabular' guard
so VLM detection only fires when no prior check matched.
@Gength

Gength commented Jun 15, 2026

Copy link
Copy Markdown
Author

Please ping me here as soon as you want a PR review. Note however, that for a review, the CI pipeline needs to be green first (all tests pass, the code-quality checks are okay, and the docs building pipeline compiles). :)

Hi @mmschlk ,

I have fixed all the issues, and the CI pipeline is now green except for the "upload coverage to codecov" step. This remaining failure appears to be a configuration or token issue with the Codecov service rather than a code quality or test failure.

I am looking forward to your review of this PR!

@mmschlk mmschlk self-requested a review June 17, 2026 06:40

@mmschlk mmschlk left a comment

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you already for your very detailed work. Overall I am already quite happy with the extend and breath you did for bringing VLMs to shapiq. I commented on a couple of nitpick comments. Then I also have two more elaborate comments as well, which I will detail here further.

First, the break of the API design (also see the big comment in the Explainer's init). Currently your Explainer cannot be called with different x (image + text) inputs inside its explain function but all the instance relevant information has to be passed at init time. This is not consistent with shapiq's core API. Explainer inits carry the information "how explanation will be done in this setting" and explain_functions bring the instance related information, which x is to be explained. Then you can use the same explainer actually to explain mutliple x after each other. This is currently not possible, but would be nice to achieve.

Second, currently practitioners and novice users cannot will not really understand what the Epxlainer is actually about (also see the notebook comments). For this we need to have proper examples showcasing how to use the explainer and why the API (and its different choices e.g. for maskers matter and what they change). For this a set of example scripts alongside the existing examples would be very welcome. Note however, that these examples may not run too long on the doc building runners and can also be manually turned off to not always run automatically if you want. So please provide more examples.

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, we do not allow .ipynb files anymore inside the docs or the examples folder. So you cannot have notebook examples but only real script files. You can see how this is done with scripts here.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

3 demo files are added to examples/vision/ now

Comment thread src/shapiq/explainer/vision.py Outdated
"""Valid index types for the VisionExplainer."""


class VisionExplainer(Explainer):

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do not like the name VisionExplainer for this explainer class as it only really deals with VisionLanguage models. I would rather refactor it to be named as such: VisionLanguageExplainer.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

VisionExplainer is renamed as VisionLanguageExplainer in new commit.

Comment thread src/shapiq/explainer/utils.py Outdated
and not (model_class or "").startswith("torch.nn.")
and _model_type == "tabular"
):
_model_type = "vision"

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

similar comment as the explainer name. "vision" is to generic for the thing you built regarding vision language models. Please address the naming refactor consistently across the PR from there.

Comment thread src/shapiq/explainer/vision.py Outdated
@@ -0,0 +1,239 @@
"""Vision Explainer for shapiq.

The :class:`VisionExplainer` explains vision-language model predictions using

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have already marked this in the notebook file. It would be nice to also get a couple of example scripts (in the correct file structure as python executable scripts) that showcase how to do the explanations for laypeople. So this is rather a bigger comment, but the codebase should also be documented well for someone who does not really know how VisionLanguage models may be explained. A couple of examples 2-3 would be very welcome on this issue.

Comment thread src/shapiq/explainer/vision.py Outdated
interaction_values.baseline_value = self.baseline_value
return interaction_values

# ─── Internal helpers ─────────────────────────────────────────────────

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do not like these AI delimiters.

Comment thread src/shapiq/explainer/utils.py Outdated

ExplainerTypes = Literal[
"tabular", "tree", "tabpfn", "game", "product_kernel", "knn", "wknn", "tnn"
"tabular", "tree", "tabpfn", "game", "product_kernel", "knn", "wknn", "tnn", "vision"

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

similar comment with the naming

Operates exclusively on pixel_values. Must never touch input_ids or
attention_mask.

.. note::

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does this mean?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's implementation task allocation, I will remove this comment.

Comment thread src/shapiq/explainer/vision.py Outdated
max_order: int = 1,
random_state: int | None = None,
verbose: bool = False,
use_amp: bool = False,

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

inside shapiq we currently do not support amp in any other places, which is why I would like to remove it here as well.

Comment thread src/shapiq/explainer/vision.py Outdated
model: Any, # noqa: ANN401
data: Any = None, # noqa: ANN401
*,
text: str = "",

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do not like defaults like this. The text being empty should actually not be possible to be provided as default. If a default is non-sensical then it should not have a default. As this is a Vision Language game it needs an input text and and input image both of which need to be provided. So for consistency this should not be provided too. However, there actually lies a rather more important issue now that is actually breaking the current shapiq API:

Explainer instantiations are not expected to carry information about the local explanation that you can do with them but rather carry the boiler plate setup code that is governing how "explaining" with this explainer will work (like setting the masking strategy or the processor steps, or the interaction index we are interested in together with the approximators that are available.

The actual "explanation point" is then always provided once the user's call the explainers.explain function. This is provided via the x parameter. Which in this case would then carry the image and the text. Of course, the overall setup you had can change in your VLM case compared for example to explaining tabular models (with different text lengths you will have different n_players at each explain time). So you would need to reinstantiate your approximators each time for each new x again.

This issue also can be seen in the explain_function doctoring as:

x: Ignored for vision models (the image is fixed at construction time). Kept for API compatibility.

You should not adhere to API compatibility by offering dead parameters, but actually be conform to the API.



@dataclass
class CrossModalMeanParams:

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Most of these data classes are empty and carry no meaning. I was not able to see where they are all used? Maybe I missed something, but double check if this strucutre is necessary as it is here.

Gength and others added 6 commits July 1, 2026 01:10
…r fix

CHANGES:
- VisionExplainer renamed to VisionLanguageExplainer
- image/text moved from __init__ to explain(x={"image": ..., "text": ...})
- SegmenterConfig / MaskerConfig replaced per-strategy fields
- (patch/slic/custom_segmenter/crossmodal_mean, ...) with single params field

Features:
- VisionLanguageExplainer: new API conforming to shapiq convention;
- game + approximator rebuilt per explain() call (handles varying n_players)
- Example scripts: plot_vision_language_clip.py (CLIP Patch + SLIC),
- plot_vision_explainer_custom.py (custom masks + blur masker)
- safe_processor_call: fallback for Fast processors (Transformers 4.51+)

Cleanup:
- Remove .ipynb notebook, replace with .py Sphinx-Gallery scripts
- Remove notebook-only ruff rules from pyproject.toml
- Remove use_amp parameter from explainer, imputer, and factory
- Collapse 6 empty params dataclasses into single EmptyParams alias
@Gength Gength changed the title Pull Request: Vision-Language Model Imputation Module Vision-Language Model Imputer Module Jul 3, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

Status: No status

Development

Successfully merging this pull request may close these issues.

4 participants