shapiq.vision Sub-package for Image Classification #549
* ADD: image imputer implementation for resnet with superpixel players and zero / mean color masking
* added image explainer and game execution to image mvp notebook
* Merge remote-tracking branch 'vision/vit-prototype'
* ADD: initialized package structure for shapiq.vision
* ADD: entry point to Image Explainer
* ADD: players and masking strategies copied from resnet prototype
* ADD: ImageImputer copied from resnet prototype
* ADD: ImageExplainer only for resnet, structure similar to TabularExplainer
* ADD: init file for vision subpackage
* ADD: Patchstrategy for ViT & value function in imputer
* ADD: default masking and player strategy decision in explainer
* ADD: bug fixes to run explainer for both resnet and vit
* ADD: testing notebook to visualize how to interact with the explainer
* ADD: moved prototype notebooks to specific folder
* ADD: cat image from current frontend testing
* REFACTOR: decouple vision package via ModelArchitectureStrategy
  Replaces model_type if/elif branches with ResNetArchitecture and ViTArchitecture. Player and masking strategies split into Pixel/Latent subtypes. Update test notebook with current implementation
* ADD:
  - dynamic dispatch of model architecture based on the given model
  - refactor some variable names, cleaned imputer
  - model architecture interaction
* REMOVE: debugging prints
* REFACTOR: testing notebook
* ADD: plotting function for heatmap
  Taking image (as numpy array), the explainer and label type as input and outputting one plot made up of 2 subplots: First plot is image with alpha overlay of first order interaction values. Second plot is actual values in bar chart.
* ADD: Heatmap only argument, so barchart not plotted
* Merge branch 'superpixel-improvement'
---------
Co-authored-by: Tamara Muras <Tamara.Muras>
Co-authored-by: Alexander Feix <alexander.feix03@gmail.com>
Co-authored-by: S2k-1 <219272227+S2k-1@users.noreply.github.com>
* REFACTOR: ResNet/Pixel -> CNN and ViT/latent -> Transformer, for player and masking strategies as well as architecture
* REFACTOR: split player definition and masking into respective classes for Transformers
* REFACTOR: rename masking strategy function, move logit call to architecture and adjust arguments for function call.
* ADD: torch import only where necessary and ensure annotations are evaluated lazily
* REFACTOR: move build pixel mask to patch strategy and introduce player mask property in architecture for visualizations
* REFACTOR: change order of player strategies in file
---------
Co-authored-by: Tamara Muras <Tamara.Muras>
Co-authored-by: S2k-1 <219272227+S2k-1@users.noreply.github.com>
…and internal handling between torch and numpy (#34)
* ADD: image conversion methods in utils
* REFACTOR: change internal handling to torch mainly, except for players
* REFACTOR: convert player masks to torch in architecture instead of masking
* ADD: ImageLike type to also support input of torch and PIL images
* Fix: wrong typing in architecture
* remove debug print and unnecessary method
* ADD: appropriate batching in the value function of imputer
* REFACTOR: comment on masking strategies
* comment on players
* REFACTOR: removed model auto dispatch and require to input model architecture
* ADD: improved docstring on imputer
* ADD: player documentation
---------
Co-authored-by: Tamara Muras <Tamara.Muras>
* ADD: imputer fit method and image property in imputer
* fix: explainer only updates imputer when x is not None being passed to explain
* remove unintended import
---------
Co-authored-by: Tamara Muras <Tamara.Muras>
…tom masks from Superpixels (#37)
* ADD: Custom Player strategy and gridstrategy for CNN architectures
* Refactor: exclude custom masks from superpixels and add to customplayer strategy
* Refactor: improve grid strategy to take patch or grid size instead of row and cols
---------
Co-authored-by: Tamara Muras <Tamara.Muras>
* ADD: vision tests
* ADD & FIX: added real model tests and fixed tensor channel bug
* ADD: unit test cases for custom players
* Fix: custom player mask not cast to bool
* ADD: unit tests for grid strategy
* ADD: tests for imputer fit method
* Fix: imputer not returning self
* Add: Adjust explainer tests due to changed implementation of explain function and renamed batch size arg in imputer
* fix: failing tests due to inconsistent renaming of batch size arg
---------
Co-authored-by: Tamara Muras <Tamara.Muras>
Co-authored-by: Alexander Feix <alexander.feix03@gmail.com>
* Remove: notebooks used for quick testing
* fix: apply pre-commit auto-fixes for vision package
* Refactor: improve code quality and add docstrings
* Refactor: fix all code quality issues from main files of vision package
  - did not improve plot functions
  - did not remove lazy import statements as recommended
* fix: explainer couldn't be initialized with transformer model
---------
Co-authored-by: Tamara Muras <Tamara.Muras>
Hello, thank you for your PR already! Please ping me here as soon as you want a PR review. :) Note, however, that for a review the CI pipeline should really be green first (all tests pass, the code-quality checks are okay, and the docs building pipeline compiles). Otherwise it's hard for me to give you good feedback. :) I can also approve the workflows from time to time if you ping me.
@mmschlk The ty/ruff errors are fixed. The CI workflow on our fork now only fails at the codecov upload (expected, because the token is missing).
…s for install (#40)
* ADD: example notebooks for quickstart and on defining players
* Fix: docstring inaccuracies for sphinx doc and finalize image explanation examples
* fix: code quality
* remove code blocks from example
* ADD: proper shapiq[vision] package to install and add tests to ensure import errors appear
* fix: test for import error when running in a row with the other tests
* fix: code quality
* fix: imputer torch import and CI pipeline shapiq import
* ADD: imputer to framework import testing
---------
Co-authored-by: Tamara Muras <Tamara.Muras>
We also added two examples to the documentation that showcase how to interact with the vision subpackage and how to define different player strategies. They are kept rather simple for now; we'll extend them once we have decided how to refactor the architectures (after the feedback). We also copied the approach used by the ShaplEIG and SPEX approximators for defining the optional dependencies of the vision package.
@mmschlk Could you approve the workflow, so we can verify the full pipeline runs successfully?
@Advueu963 could you please approve the workflow or ping @mmschlk
mmschlk
left a comment
Thank you already for your pull request and all of your work already on this matter!! :) I only had some minor comments (see individually). It's already pretty nice!
The only "major-ish" thing I have is also a bit of a question. Did we decide not to wire ImageExplainer into the Explainer dispatching logic? So, we said we do not expect users to do
explainer = Explainer(image_model, ...)
and expect explainer to be an ImageExplainer? Because I think this is not possible yet. I would, however, like to know whether it would be easy to allow users to call
explainer = Explainer(CNNArchitecture(model), ...)
and then get an ImageExplainer. In that case only these special architectures would work in the dispatching. Is this doable?
Thank you already!
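For illustration, the architecture-based dispatch asked about above could look roughly like the sketch below. It is not shapiq's actual dispatch code: every class here is a simplified stand-in that only borrows the names from this PR, and dispatching in `__new__` is an assumption made for the example.

```python
class ModelArchitectureStrategy:
    """Stand-in for the vision architecture wrappers (simplified, hypothetical)."""

    def __init__(self, model):
        self.model = model


class CNNArchitecture(ModelArchitectureStrategy):
    """Stand-in for the CNN wrapper introduced in this PR."""


class Explainer:
    """Stand-in base class that dispatches on the type of its first argument."""

    def __new__(cls, model, *args, **kwargs):
        # Redirect only when the generic base class is requested directly and
        # the "model" is one of the special architecture wrappers.
        if cls is Explainer and isinstance(model, ModelArchitectureStrategy):
            return super().__new__(ImageExplainer)
        return super().__new__(cls)

    def __init__(self, model, **kwargs):
        self.model = model


class ImageExplainer(Explainer):
    """Stand-in for the vision explainer."""


# A wrapped model yields an ImageExplainer, a plain object stays an Explainer.
explainer = Explainer(CNNArchitecture(model=object()))
plain = Explainer(object())
```

With this shape a raw image model is never inspected; only objects wrapped in an architecture class are routed to the vision explainer.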
    data: ImageLike,
    *,
    imputer: ImageImputer | None = None,
    index: ExplainerIndices = "k-SII",
We actually switched the default to "SV" with order 1 in v1.5.0.
def plot_image_attributions(
    self,
    image: np.ndarray,
    explainer: ImageExplainer,
    *,
    region_label: str = "Region",
    alpha: float = 0.5,
    cmap: Colormap | str | None = None,
    show: bool = True,
    heatmap_only: bool = True,
) -> tuple[Figure, Axes] | tuple[Figure, tuple[Axes, Axes]] | None:
    """Visualize first-order attributions as a heatmap overlaid on the original image.

    Args:
        image: Original image as a ``(H, W, C)`` numpy array.
        explainer: The ``ImageExplainer`` used to produce this object. Its imputer
            provides the pixel-space player masks.
        region_label: x-axis label for the bar chart, e.g. ``"Patch"`` or
            ``"Superpixel"``. Defaults to ``"Region"``.
        alpha: Transparency of the heatmap overlay. Defaults to ``0.5``.
        cmap: Matplotlib colormap or name. ``None`` uses shapiq's BLUE→white→RED
            diverging palette. Defaults to ``None``.
        show: Whether to display the plot. Defaults to ``True``.
        heatmap_only: Whether to show only the heatmap. Defaults to ``True``.

    Returns:
        If ``show`` is ``False`` and ``heatmap_only`` is ``True``, returns
        ``(figure, ax_heatmap)``. Otherwise returns ``(figure, (ax_heatmap, ax_bar))``.

    """
    from shapiq.plot.vision import image_attribution_plot

    return image_attribution_plot(
        self,
        image,
        explainer.imputer.player_masks,
        region_label=region_label,
        alpha=alpha,
        cmap=cmap,
        show=show,
        heatmap_only=heatmap_only,
    )
Decouple InteractionValues from the optional vision subpackage. InteractionValues is the central output type of the whole library and has to stay importable without torch/scikit-image. This method makes it depend on the optional vision subpackage (the ImageExplainer import up at line 32 even in the Type block), and unlike every other plot_* here it asks the caller to pass the explainer back in and then reaches into explainer.imputer.player_masks. Please move this out of the core type. We already have shapiq.plot.vision.image_attribution_plot; expose it there and have it take image + player_masks directly. The core class shouldn't know vision types exist.
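To make the "image + player_masks directly" idea concrete, here is a minimal sketch of the data flow such a standalone function needs. `attribution_heatmap` is a hypothetical helper written for this example, not the existing `image_attribution_plot`; it uses plain nested lists instead of numpy arrays and leaves out the actual plotting.

```python
def attribution_heatmap(player_masks, values):
    """Paint each player's first-order value onto the pixels that player owns.

    player_masks: one H x W grid of booleans per player.
    values: one first-order attribution per player, in the same order.
    """
    if len(player_masks) != len(values):
        raise ValueError("need exactly one value per player mask")
    height, width = len(player_masks[0]), len(player_masks[0][0])
    heatmap = [[0.0] * width for _ in range(height)]
    for mask, value in zip(player_masks, values):
        for row in range(height):
            for col in range(width):
                if mask[row][col]:
                    heatmap[row][col] = value
    return heatmap


# Two players on a 2x2 image: the left column and the right column.
left = [[True, False], [True, False]]
right = [[False, True], [False, True]]
heatmap = attribution_heatmap([left, right], [0.5, -0.2])
```

Nothing in this function touches an explainer, an imputer, or torch: the caller passes the masks in, which is the property the core output type needs to stay importable without the optional dependencies.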
    Returns:
        InteractionValues: The interaction values of the prediction.
    """
    budget: int = kwargs.get("budget", 64)
A few parameters are out of step with our other explainers and I'd like parity before we can merge:
- budget is read from kwargs here instead of being an explicit argument (compare TabularExplainer.explain_function(x, budget=...)). Please make it a real signature parameter so it shows up in the docs/IDE.
- There is no random_state. Isn't that also necessary here?
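A small sketch of what the requested signature could look like. `SketchImageExplainer` is a hypothetical stand-in, the keyword-only layout is an assumption, and the body merely samples random coalitions to show why a seed matters; only the default of 64 is taken from the snippet above.

```python
import inspect
import random


class SketchImageExplainer:
    """Hypothetical stand-in; only the shape of the signature matters here."""

    def explain_function(self, x, *, budget=64, random_state=None):
        """Sample `budget` random coalitions over the players in `x`."""
        rng = random.Random(random_state)  # seeded, so runs are reproducible
        n_players = len(x)
        return [
            tuple(rng.random() < 0.5 for _ in range(n_players))
            for _ in range(budget)
        ]


# Explicit parameters are visible to docs tooling and IDEs via introspection.
parameters = inspect.signature(SketchImageExplainer.explain_function).parameters

sketch = SketchImageExplainer()
first = sketch.explain_function(["p0", "p1", "p2"], budget=8, random_state=42)
second = sketch.explain_function(["p0", "p1", "p2"], budget=8, random_state=42)
```

With a value read from `**kwargs`, neither `budget` nor `random_state` would appear in `parameters`.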
ImageExplainerIndices = ExplainerIndices


class ImageExplainer(Explainer):
A few parameters are out of step with our other explainers and I'd like parity before we can merge:
- approximator (here in __init__): TabularExplainer lets users pass one; here it's hard-coded to "auto". Superpixel player counts get large, so users will want to swap it.
- class_index (here in __init__): the base class and TabularExplainer accept it; here the explained class is hard-coded to argmax inside the architecture, so a user of a multi-class model can't choose which class to explain. We do need that toggle, and I would like to be able to set it in __init__ as well.
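As a rough illustration of the class_index toggle, the helper below picks which model output gets explained. `select_output` is a hypothetical name, and treating `None` as "fall back to argmax" is an assumption chosen to preserve the current behaviour as the default.

```python
def select_output(logits, class_index=None):
    """Return the logit to explain for a single prediction.

    class_index=None keeps the argmax behaviour described in the review;
    an integer explains that class regardless of which one scores highest.
    """
    if class_index is None:
        class_index = max(range(len(logits)), key=logits.__getitem__)
    if not 0 <= class_index < len(logits):
        raise IndexError(f"class_index {class_index} is out of range")
    return logits[class_index]


logits = [0.1, 2.0, 0.5]
predicted = select_output(logits)              # argmax fallback
chosen = select_output(logits, class_index=2)  # user-selected class
```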
def default_player_strategy(self) -> PatchStrategy:
    """Return a patch player strategy with a 3x3 grid."""
    grid_size = self._model.config.image_size // self._model.config.patch_size
    return PatchStrategy(grid_size=grid_size, n_players=9)
This hard-codes a 3x3 grid, but PatchStrategy.__init__ requires grid_size % sqrt(n_players) == 0. For one of the most common ViTs, vit-base-patch16-224, the grid_size is 224/16 = 14 and 14 % 3 != 0, so this raises a ValueError on construction and the default path crashes for the standard model. The examples only work because they use patch32-384 (grid 12). Please make the default adapt to grid_size (pick a perfect square whose root divides the grid), and add a regression test against a 14x14 grid so this can't silently come back.
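One possible way to pick such an adaptive default is sketched below. `default_n_players` and its `max_players` cap are hypothetical names; the cap of 9 simply mirrors the current 3x3 default.

```python
import math


def default_n_players(grid_size, max_players=9):
    """Largest perfect square <= max_players whose root divides grid_size."""
    best_root = 1
    for root in range(1, math.isqrt(max_players) + 1):
        if grid_size % root == 0:
            best_root = root
    return best_root * best_root


players_14 = default_n_players(14)  # vit-base-patch16-224: 2x2 grid of players
players_12 = default_n_players(12)  # patch32-384: keeps the 3x3 grid
players_7 = default_n_players(7)    # prime grid: only the trivial 1x1 split fits
```

A prime grid size degenerates to a single player under this rule, so a real default would probably need a different fallback for that case (for example uneven patches).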
"""Vision-based explanation methods for image models."""

try:
    from .architecture import CNNArchitecture, ModelArchitectureStrategy, TransformerArchitecture
Heads-up, and this one isn't really your fault, as we already have the same pattern in ShaplEIG. This eager from .architecture import ... pulls in architecture.py (and its top-level import torch) at `import shapiq` time, so even a tabular-only user pays ~0.5s of torch import. The try/except guards graceful failure but not laziness. Please defer this with a module-level __getattr__ (PEP 562) so torch is only imported when ImageExplainer is actually used. I will soon open a matching issue against ShaplEIG so we can fix both consistently.
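The PEP 562 pattern can be sketched as follows. To keep the example runnable without torch, a throwaway module object stands in for the package and two stdlib modules stand in for the heavy submodules; the mapping, `lazy_pkg`, and `load_log` are all illustrative assumptions, not part of shapiq.

```python
import importlib
import types

# Public name -> module that defines it. In the real package this would map
# names such as "ImageExplainer" to submodules; stdlib modules stand in here.
_LAZY_ATTRIBUTES = {"Fraction": "fractions", "Decimal": "decimal"}

load_log = []  # records which names actually triggered the hook

lazy_pkg = types.ModuleType("lazy_pkg")  # stand-in for the package module


def _module_getattr(name):
    """PEP 562 hook: runs only when normal attribute lookup on the module fails."""
    if name not in _LAZY_ATTRIBUTES:
        raise AttributeError(f"module 'lazy_pkg' has no attribute {name!r}")
    module = importlib.import_module(_LAZY_ATTRIBUTES[name])
    load_log.append(name)
    value = getattr(module, name)
    setattr(lazy_pkg, name, value)  # cache, so the hook is not hit again
    return value


# Inside a real __init__.py this would simply be `def __getattr__(name): ...`.
lazy_pkg.__getattr__ = _module_getattr

assert load_log == []             # nothing is loaded at "package import" time
fraction_cls = lazy_pkg.Fraction  # first access triggers the import
fraction_cls = lazy_pkg.Fraction  # second access is served from the cache
```

The existing try/except for a missing optional dependency can move inside the hook, so the helpful error is raised on first use rather than at import time.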
    if self._normalize:
        self.normalization_value = self.empty_prediction

def fit(self, x: ImageLike) -> ImageImputer:
Our Imputer is a Game, and the approximator is built once for a fixed n_players. fit() recomputes self.n_features from the new image — and for SuperpixelStrategy the SLIC count varies image to image — but the approximator built back in explainer.py:106 is never rebuilt. So explaining a second image with one explainer desyncs the game and approximator (shape mismatch or silently wrong values). For this we need to rebuild the approximator when n_features changes. Right now reusing an explainer across images is unsafe.
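A minimal sketch of the rebuild-on-change guard described above. All three classes are hypothetical stand-ins that only track player counts; how the real approximator is constructed is not shown in this conversation, so that part is an assumption.

```python
class SketchApproximator:
    """Stand-in approximator that is bound to a fixed number of players."""

    def __init__(self, n):
        self.n = n


class SketchImputer:
    """Stand-in imputer whose player count depends on the fitted image."""

    def __init__(self):
        self.n_features = 0

    def fit(self, player_masks):
        self.n_features = len(player_masks)
        return self


class SketchExplainer:
    def __init__(self, imputer):
        self.imputer = imputer
        self.approximator = SketchApproximator(imputer.n_features)
        self.rebuilds = 0

    def explain(self, player_masks):
        self.imputer.fit(player_masks)
        # Keep game and approximator in sync: rebuild only when the count changed.
        if self.approximator.n != self.imputer.n_features:
            self.approximator = SketchApproximator(self.imputer.n_features)
            self.rebuilds += 1
        return self.approximator.n


sketch_explainer = SketchExplainer(SketchImputer())
n_first = sketch_explainer.explain(["mask"] * 3)   # e.g. 3 superpixels
n_second = sketch_explainer.explain(["mask"] * 5)  # next image yields 5
```

Explaining a third image that again has 5 players would reuse the existing approximator, so the check only costs something when the player count really changes.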
shapiq.vision Sub-package for Image Classification
Motivation and Context
shapiq currently has no first-class support for explaining vision models. The only image-handling code lives inside the benchmark games (shapiq_games.benchmark.local_xai.ImageClassifier), where masking logic is hard-coded per model, which is not suitable for end-user pipelines.
This PR introduces shapiq.vision, a new subpackage that plugs cleanly into the existing Explainer → Game → Imputer pipeline and supports both CNN and Vision Transformer (ViT) architectures. The design follows the pattern of shapiq.tree / shapiq.graph by separating concerns into dedicated modules:
- players.py: pluggable player-definition strategies (superpixel, fixed grid, custom masks, ViT patch tokens)
- masking.py: pluggable masking strategies (mean-color, zero/constant, token-based attention masking for ViTs)
- architecture.py: model-coupling layer that binds a player strategy and a masking strategy to a concrete model type (CNNArchitecture, TransformerArchitecture)
- imputer.py: ImageImputer(Imputer) that implements batched value_function evaluation over coalitions
- explainer.py: ImageExplainer(Explainer) that wires the imputer into the approximator pipeline
- utils.py: format-agnostic image conversion helpers (as_hwc_array, to_tensor_chw, etc.)
Public API Changes
No existing public APIs are modified or removed. We just implemented new ones for the shapiq.vision subpackage. See example:
How Has This Been Tested?
Introduced 135 unit tests.
Manually verified end to end with ResNet-18 (torchvision) and ViT (google/vit-base-patch32-384) on a sample of ImageNet images.
Checklist
CHANGELOG.md (if relevant for users).
Notes
Draft PR for LMU Practical: Toolbox for Machine Learning (Group Vision subpackage)
We have tested different types of CNNs and ViTs with the provided architecture setup (ViT & CNN). More detailed information can be found in S2k-1#26.
We tested 28 CNNs and 15 ViTs. Of those, 23 of the 28 CNNs work (we already have a fix for the remaining 5), and 2 of the 15 ViTs work. We have already planned how to make the architecture more robust so that more ViTs are supported.
Due to missing testing, not all masking strategies that we experimented with are merged into this PR version. We are working on that.