Skip to content

Commit 412b801

Browse files
committed
Update Captum tutorial dependencies for Captum 0.8.0
Pin Captum to 0.8.0, the last release retaining Captum Insights, and update the tutorial requirements accordingly. Captum 0.8.0 requires Python >=3.9 and torch >=1.10, so the tutorial now reflects those minimums and uses torchvision >=0.11.0 as the matching torchvision line for PyTorch 1.10. Remove the old matplotlib==3.3.4 requirement. Captum 0.8.0 uses the newer grid(visible=False) Matplotlib API, so the old workaround for the grid(b=False) argument rename is no longer needed. Update the original-image visualization example to pass a dummy attribution array instead of None. In Captum 0.8.0, visualize_image_attr normalizes attr before reaching the code that handles method="original_image", even though that code only displays the original image.
1 parent 5efb99a commit 412b801

1 file changed

Lines changed: 17 additions & 12 deletions

File tree

beginner_source/introyt/captumyt.py

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848
the review is an example of feature attribution.
4949
- **Layer Attribution** examines the activity of a model’s hidden layer
5050
subsequent to a particular input. Examining the spatially-mapped
51-
output of a convolutional layer in response to an input image in an
51+
output of a convolutional layer in response to an input image is an
5252
example of layer attribution.
5353
- **Neuron Attribution** is analagous to layer attribution, but focuses
5454
on the activity of a single neuron.
@@ -97,14 +97,13 @@
9797
9898
Before you get started, you need to have a Python environment with:
9999
100-
- Python version 3.6 or higher
100+
- Python version 3.9.0 or higher
101101
- For the Captum Insights example, Flask 1.1 or higher and Flask-Compress
102-
(the latest version is recommended)
103-
- PyTorch version 1.2 or higher (the latest version is recommended)
104-
- TorchVision version 0.6 or higher (the latest version is recommended)
105-
- Captum (the latest version is recommended)
106-
- Matplotlib version 3.3.4, since Captum currently uses a Matplotlib
107-
function whose arguments have been renamed in later versions
102+
- PyTorch version 1.10.0 or higher
103+
- TorchVision version 0.11.0 or higher
104+
- Captum version 0.8.0, as Captum Insights was retired after this version
105+
and is no longer supported
106+
- Matplotlib version 3.5.0 or higher
108107
109108
To install Captum in an Anaconda or pip virtual environment, use the
110109
appropriate command for your environment below:
@@ -113,13 +112,13 @@
113112
114113
.. code-block:: sh
115114
116-
conda install pytorch torchvision captum flask-compress matplotlib=3.3.4 -c pytorch
115+
conda install "pytorch>=1.10.0" "torchvision>=0.11.0" "captum=0.8.0" "matplotlib>=3.5.0" flask-compress ipywidgets -c pytorch
117116
118117
With ``pip``:
119118
120119
.. code-block:: sh
121120
122-
pip install torch torchvision captum matplotlib==3.3.4 Flask-Compress
121+
pip install "torch>=1.10.0" "torchvision>=0.11.0" "captum==0.8.0" "matplotlib>=3.5.0" flask-compress ipywidgets
123122
124123
Restart this notebook in the environment you set up, and you’re ready to
125124
go!
@@ -266,8 +265,10 @@
266265
attributions_ig = integrated_gradients.attribute(input_img, target=pred_label_idx, n_steps=200)
267266

268267
# Show the original image for comparison
269-
_ = viz.visualize_image_attr(None, np.transpose(transformed_img.squeeze().cpu().detach().numpy(), (1,2,0)),
270-
method="original_image", title="Original Image")
268+
_ = viz.visualize_image_attr(np.ones_like(np.transpose(transformed_img.squeeze().cpu().detach().numpy(), (1,2,0))),
269+
np.transpose(transformed_img.squeeze().cpu().detach().numpy(), (1,2,0)),
270+
method="original_image",
271+
title="Original Image")
271272

272273
default_cmap = LinearSegmentedColormap.from_list('custom blue',
273274
[(0, '#ffffff'),
@@ -495,5 +496,9 @@ def full_img_transform(input):
495496
# model’s predictions with associated probabilities, and view heatmaps of
496497
# the attribution compared with the original image.
497498
#
499+
# If you are not using a Jupyter Notebook, you can use
500+
# ``visualizer.serve(debug=True)`` instead. This starts a local web server
501+
# and prints a URL that you can open in your browser.
502+
#
498503

499504
visualizer.render()

0 commit comments

Comments
 (0)