Skip to content

shapiq.vision Sub-package for Image Classification#549

Open
S2k-1 wants to merge 16 commits into
mmschlk:mainfrom
S2k-1:vision_main
Open

shapiq.vision Sub-package for Image Classification#549
S2k-1 wants to merge 16 commits into
mmschlk:mainfrom
S2k-1:vision_main

Conversation

@S2k-1

@S2k-1 S2k-1 commented Jun 12, 2026

Copy link
Copy Markdown

Motivation and Context

shapiq currently has no first-class support for explaining vision models. The only image-handling code lives inside the benchmark games (shapiq_games.benchmark.local_xai.ImageClassifier) where masking logic is hard-coded per model, which is not suitable for end-user pipelines.

This PR introduces shapiq.vision, a new subpackage that plugs cleanly into the existing Explainer → Game → Imputer pipeline and supports both CNN and Vision Transformer (ViT) architectures. The design follows the pattern of shapiq.tree / shapiq.graph by separating concerns into dedicated modules:

  • players.py — pluggable player-definition strategies (superpixel, fixed grid, custom masks, ViT patch tokens)
  • masking.py — pluggable masking strategies (mean-color, zero/constant, token-based attention masking for ViTs)
  • architecture.py — model-coupling layer that binds a player strategy and a masking strategy to a concrete model type (CNNArchitecture, TransformerArchitecture)
  • imputer.pyImageImputer(Imputer) that implements batched value_function evaluation over coalitions
  • explainer.pyImageExplainer(Explainer) that wires the imputer into the approximator pipeline
  • utils.py — format-agnostic image conversion helpers (as_hwc_array, to_tensor_chw, etc.)

Public API Changes

  • No Public API changes
  • Yes, Public API changes (Details below)

No existing public APIs are modified or removed. We just implemented new ones for the shapiq.vision subpackage. See example:

        from shapiq.vision.architecture import CNNArchitecture, TransformerArchitecture
        from shapiq.vision.explainer import ImageExplainer

        # --- CNN (ResNet-style) ---
        arch = CNNArchitecture(model=my_resnet)
        explainer = ImageExplainer(model_architecture=arch, data=my_image)
        iv = explainer.explain_function(x=None, budget=256)

        # --- ViT ---
        arch = TransformerArchitecture(model=my_vit, vit_processor=processor)
        explainer = ImageExplainer(model_architecture=arch, data=my_image,
                                   index="SII", max_order=2)
        iv = explainer.explain_function(x=None, budget=512)

How Has This Been Tested?

Introduced 135 unit tests.
Manually verified end to end with ResNet-18 (torchvision) and ViT (google/vit-base-patch32-384) on a sample on ImageNet images.


Checklist

  • The changes have been tested locally.
  • Documentation has been updated (if the public API or usage changes).
  • An entry has been added to CHANGELOG.md (if relevant for users).
  • The code follows the project's style guidelines.
  • I have considered the impact of these changes on the public API.

Notes

Draft PR for LMU Practical: Toolbox for Machine Learning (Group Vision subpackage)

We have tested different types of CNNs and ViTs for our provided architecture setup (ViT & CNN). More detailed information can be found in: S2k-1#26
We have tested 28 CNNs and 15 ViTs, out of those we achieved 23/28 for CNN (We already have the fix for the 5 missing ones) and for ViTs we achieved 2/15. We already planned on how to make the architecture more robust so we increase the amount of supported ViTs.

Due to missing testing, not all masking strategies that we experimented with are merged into this PR version. We are working on that.

t-muras and others added 8 commits June 5, 2026 21:06
* ADD: image imputer implementation for resnet with superpixel players and zero / mean color masking

* added image explainer and game execution to image mvp notebook

* Merge remote-tracking branch 'vision/vit-prototype'

* ADD: initialized package structure for shapiq.vision

* ADD: entry point to Image Explainer

* ADD: players and masking strategies copied from resnet prototype

* ADD: ImageImputer copied from resnet prototype

* ADD: ImageExplainer only for resnet, structur similar to TabularExplainer

* ADD: init file for vision subpackage

* ADD: Patchstrategy for ViT & value function in imputer

* ADD: default masking and player strategy decision in explainer

* ADD: bug fixes to run explainer for both resnet and vit

* ADD: testing notebook to visualize how to interact with the explainer

* ADD: moved prototype notebooks to specific folder

* ADD: cat image from current frontend testing

* REFACTOR: decouple vision package via ModelArchitectureStrategy

Replaces model_type if/elif branches with ResNetArchitecture and ViTArchitecture.
Player and masking strategies split into Pixel/Latent subtypes.
Update test notebook with current implementation

* ADD:
- dynamic dispatch of model architecture based on the given model
- refactor some variable names, cleaned imputer - model architecture interaction

* REMOVE: debugging prints

* REFACTOR: testing notebook

* ADD: plotting function for heatmap

Taking image (as numpy array), the explainer and label type as input and outputting one plot made up of 2 subplots:
First plot is image with alpha overlay of first order interaction values. Second plot is actual values in bar chart.

* ADD: Heatmap only argument, so barchart not plotted

* Merge branch 'superpixel-improvement'

---------

Co-authored-by: Tamara Muras <Tamara.Muras>
Co-authored-by: Alexander Feix <alexander.feix03@gmail.com>
Co-authored-by: S2k-1 <219272227+S2k-1@users.noreply.github.com>
* REFACTOR: Restnet/Pixel -> CNN and ViT/latent -> Transformer. For player and masking strategies aswell as architecture

* REFACTOR: split player definition and masking into respective classes for Transformers

* REFACTOR: rename masking strategy function, move logit call to architecture and adjust arguments for function call.

* ADD: torch import only where necessary and ensure annotations are evaluated lazy

* REFACTOR: move build pixel mask to patch strategy and introduce player mask property in architecture for visualizations

* REFACTOR: change order of player strategies in file

---------

Co-authored-by: Tamara Muras <Tamara.Muras>
Co-authored-by: S2k-1 <219272227+S2k-1@users.noreply.github.com>
…and internal handling between torch and numpy (#34)

* ADD: image conversion methods in utils

* REFACTOR: change internal handling to torch mainly, except for players

* REFACTOR: convert player masks to torch in architecture instead of masking

* ADD: ImageLike typy to also support input of torch and pil images

* Fix: wrong typing in architecture

* remove debug print and unnecessary method

* ADD: appropirate batching in the value function of imputer

* REFACTOR: comment on masking strategies

* comment on players

* REFACTOR: removed model auto dispatch and require to input model architecture

* ADD: improved docstring on imputer

* ADD: player documentation

---------

Co-authored-by: Tamara Muras <Tamara.Muras>
* ADD: imputer fit method and image property in imputer

* fix: explainer only updates imputer when x is not None being passed to explain

* remove unintended import

---------

Co-authored-by: Tamara Muras <Tamara.Muras>
…tom masks from Superpixels (#37)

* ADD: Custom Player strategy and gridstrategy for CNN architectures

* Refactor: exclude custom masks from superpixels and add to customplayer strategy

* Refactor: improve grid strategy to take patch or grid size instead of row and cols

---------

Co-authored-by: Tamara Muras <Tamara.Muras>
* ADD: vision tests

* ADD & FIX: added readl model tests and fixed tensor channel bug

* ADD: unit test cases for custom players

* Fix: custom player mask not casted to bool

* ADD: unit tests for grid strategy

* ADD: tests for imputer fit method

* Fix: imputer not returning self

* Add: Adjust explainer tests due to changed implementation of explain function and renamed batch size arg in imputer

* fix: failing tests due to inconsistent renaming of batch size arg

---------

Co-authored-by: Tamara Muras <Tamara.Muras>
Co-authored-by: Alexander Feix <alexander.feix03@gmail.com>
* Remove: notebooks used for quick testing

* fix: apply pre-commit auto-fixes for vision package

* Refactor: improve code quality and add docstrings

* Refactor: fix all code quality issues from main files of vision package

- did not improve plot functions
- did not remove lazy import statements as recommended

* fix: explainer couldnt be initialized with transformer model

---------

Co-authored-by: Tamara Muras <Tamara.Muras>
@mmschlk

mmschlk commented Jun 17, 2026

Copy link
Copy Markdown
Owner

Hello, thank you for your PR already! Please ping me here as soon as you want a PR review. :)

Note however, that for a review, the CI pipeline should really be green first (all tests pass, the code-quality checks are okay, and the docs building pipeline compiles). Otherwise it's hard for me to give you good feedback. :)

I can also approve the workflows from time to time if you ping me.

@S2k-1

S2k-1 commented Jun 18, 2026

Copy link
Copy Markdown
Author

@mmschlk ty/ruff errors fixed. CI workflow on our fork only got errors for codecov upload. (expected due to missing token)
Please approve the workflow here :)

…s for install (#40)

* ADD: example notebooks for quickstart and on defining players

* Fix: docstring inaccuracies for sphinx doc and finalize image explanation examples

* fix: code quality

* remove code blocks from example

* ADD: proper shapiq[vision] package to install and add tests to ensure import errors appear

* fix: test for import error when running in a row with the other tests

* fix: code quality

* fix: imputer torch import and CI pipeline shapiq import

* ADD: imputer to framework import testing

---------

Co-authored-by: Tamara Muras <Tamara.Muras>
@t-muras

t-muras commented Jun 21, 2026

Copy link
Copy Markdown

We also added two examples to the documentation to showcase how to interact with the vision subpackage and how to define different player strategies. These are kept rather simple until now, we'll extend them after we decided how to refactor the architectures (after the feedback).

We also now copied the behavior from ShaplEIG and SPEX approximators in defining optional dependencies for the vision package.

@t-muras

t-muras commented Jun 26, 2026

Copy link
Copy Markdown

@mmschlk Could you approve the workflow, so we can verify the full pipeline runs successfully?

@S2k-1

S2k-1 commented Jun 29, 2026

Copy link
Copy Markdown
Author

@Advueu963 could you please approve the workflow or ping @mmschlk

@Advueu963 Advueu963 requested a review from mmschlk June 29, 2026 16:21
@codecov

codecov Bot commented Jun 29, 2026

Copy link
Copy Markdown

@mmschlk mmschlk left a comment

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you already for your pull request and all of your work already on this matter!! :) I only had some minor comments (see individually). It's already pretty nice!

The only "major-ish" thing I have is also a bit of a question. Did we decided not to wirre ImageExplainer into the Explainer dispatching logic? So, we said we do not expect users to do

explainer = Explainer(image_model, ...)

and expect explainer to be ImageExplainer? Because I think this is not possible, yet. I would however, like to know weather it would be easy to allow users to call

explainer = Explainer(CNNArchitecture(model), ...)

and then get ImageExplainer. But then only these special architectures would work in the dispatching. Is this doable?

Thank you already!

data: ImageLike,
*,
imputer: ImageImputer | None = None,
index: ExplainerIndices = "k-SII",

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we actually went to "SV" as default in v1.5.0 and order 1.

Comment on lines +1067 to +1109
def plot_image_attributions(
self,
image: np.ndarray,
explainer: ImageExplainer,
*,
region_label: str = "Region",
alpha: float = 0.5,
cmap: Colormap | str | None = None,
show: bool = True,
heatmap_only: bool = True,
) -> tuple[Figure, Axes] | tuple[Figure, tuple[Axes, Axes]] | None:
"""Visualize first-order attributions as a heatmap overlaid on the original image.

Args:
image: Original image as a ``(H, W, C)`` numpy array.
explainer: The ``ImageExplainer`` used to produce this object. Its imputer
provides the pixel-space player masks.
region_label: x-axis label for the bar chart, e.g. ``"Patch"`` or
``"Superpixel"``. Defaults to ``"Region"``.
alpha: Transparency of the heatmap overlay. Defaults to ``0.5``.
cmap: Matplotlib colormap or name. ``None`` uses shapiq's BLUE→white→RED
diverging palette. Defaults to ``None``.
show: Whether to display the plot. Defaults to ``True``.
heatmap_only: Whether to show only the heatmap. Defaults to ``True``.

Returns:
If ``show`` is ``False`` and ``heatmap_only`` is ``True``, returns
``(figure, ax_heatmap)``. Otherwise returns ``(figure, (ax_heatmap, ax_bar))``.

"""
from shapiq.plot.vision import image_attribution_plot

return image_attribution_plot(
self,
image,
explainer.imputer.player_masks,
region_label=region_label,
alpha=alpha,
cmap=cmap,
show=show,
heatmap_only=heatmap_only,
)

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Decouple InteractionValues from the optional vision subpackage. InteractionValues is the central output type of the whole library and has to stay importable without torch/scikit-image. This method makes it depend on the optional vision subpackage (the ImageExplainer import up at line 32 even in the Type block), and unlike every other plot_* here it asks the caller to pass the explainer back in and then reaches into explainer.imputer.player_masks. Please move this out of the core type. We already have shapiq.plot.vision.image_attribution_plot; expose it there and have it take image + player_masks directly. The core class shouldn't know vision types exist.

Returns:
InteractionValues: The interaction values of the prediction.
"""
budget: int = kwargs.get("budget", 64)

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A few parameters are out of step with our other explainers and I'd like parity before we can merge:

  • budget is read from kwargs here instead of being an explicit argument (compare TabularExplainer.explain_function(x, budget=...)). Please make it a real signature parameter so it shows up in the docs/IDE.
  • no random_state is this not also necessary here?

ImageExplainerIndices = ExplainerIndices


class ImageExplainer(Explainer):

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A few parameters are out of step with our other explainers and I'd like parity before we can merge:

  • approximator (here in init): TabularExplainer lets users pass one; here it's hard-coded to "auto". Superpixel player counts get large, so users will want to swap it.
  • class_index (here in init): base and TabularExplainer accept it; here the explained class is hard-coded to argmax inside the architecture, so a multi-class model can't choose what to explain. We need that toggle however. I would like to be able to do this in the init as well.

def default_player_strategy(self) -> PatchStrategy:
"""Return a patch player strategy with a 3x3 grid."""
grid_size = self._model.config.image_size // self._model.config.patch_size
return PatchStrategy(grid_size=grid_size, n_players=9)

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This hard-codes a 3x3 grid, but PatchStrategy.init requires grid_size % sqrt(n_players) == 0. For one of the most common ViT, vit-base-patch16-224, the grid_size is 224/16 = 14 and 14 % 3 != 0, so this raises ValueError on construction and the default path crashes for the standard model. The examples only work because they use patch32-384 (grid 12). Please make the default adapt to grid_size (pick a perfect square whose root divides the grid), and add a regression test against a 14x14 grid so this can't silently come back.

"""Vision-based explanation methods for image models."""

try:
from .architecture import CNNArchitecture, ModelArchitectureStrategy, TransformerArchitecture

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Heads-up, and this one isn't really your fault as we already have the same pattern in ShaplEIG. This eager from .architecture import ... pulls architecture.py (and its top-level import torch) at import-shapiq time, so even a tabular-only user pays ~0.5s of torch import. The try/except guards graceful failure but not laziness. Please defer this with a module-level getattr (PEP 562) so torch is only imported when ImageExplainer is actually used. Soon I will open a matching issue against ShaplEIG so we fix both consistently then.

if self._normalize:
self.normalization_value = self.empty_prediction

def fit(self, x: ImageLike) -> ImageImputer:

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Our Imputer is a Game, and the approximator is built once for a fixed n_players. fit() recomputes self.n_features from the new image — and for SuperpixelStrategy the SLIC count varies image to image — but the approximator built back in explainer.py:106 is never rebuilt. So explaining a second image with one explainer desyncs the game and approximator (shape mismatch or silently wrong values). For this we need to rebuild the approximator when n_features changes. Right now reusing an explainer across images is unsafe.

@mmschlk mmschlk marked this pull request as ready for review June 30, 2026 16:16
@mmschlk mmschlk changed the title [Draft] shapiq.vision subpackage shapiq.vision Sub-package for Image Classification Jul 2, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

Status: No status

Development

Successfully merging this pull request may close these issues.

4 participants