Skip to content

Commit 07403fa

Browse files
thesteve0claudesekyondaMeta
authored
Issues 3859 (#3896)
Fixes #3859 ## Description Did the fixes from the issue and also updated the first attribution example to work with the updated captum libraries ## Checklist <!--- Make sure to add `x` to all items in the following checklist: --> - [x ] The issue that is being fixed is referred in the description (see above "Fixes #ISSUE_NUMBER") - [ x] Only one issue is addressed in this pull request - [ ] Labels from the issue that this PR is fixing are added to this pull request - [x ] No unnecessary issues are included into this pull request. cc @subramen --------- Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com> Co-authored-by: sekyondaMeta <127536312+sekyondaMeta@users.noreply.github.com>
1 parent bcb7e29 commit 07403fa

1 file changed

Lines changed: 7 additions & 120 deletions

File tree

beginner_source/introyt/captumyt.py

Lines changed: 7 additions & 120 deletions
Original file line numberDiff line numberDiff line change
@@ -82,35 +82,25 @@
8282
- The ``captum.attr.visualization`` module (imported below as ``viz``)
8383
provides helpful functions for visualizing attributions related to
8484
images.
85-
- **Captum Insights** is an easy-to-use API on top of Captum that
86-
provides a visualization widget with ready-made visualizations for
87-
image, text, and arbitrary model types.
8885
89-
Both of these visualization toolsets will be demonstrated in this
90-
notebook. The first few examples will focus on computer vision use
91-
cases, but the Captum Insights section at the end will demonstrate
92-
visualization of attributions in a multi-model, visual
93-
question-and-answer model.
86+
This visualization toolset will be demonstrated throughout this notebook.
9487
9588
Installation
9689
------------
9790
9891
Before you get started, you need to have a Python environment with:
9992
100-
- Python version 3.6 or higher
101-
- For the Captum Insights example, Flask 1.1 or higher and Flask-Compress
102-
(the latest version is recommended)
103-
- PyTorch version 1.2 or higher (the latest version is recommended)
104-
- TorchVision version 0.6 or higher (the latest version is recommended)
93+
- Python version 3.9 or higher
94+
- PyTorch (the latest version is recommended)
95+
- TorchVision (the latest version is recommended)
10596
- Captum (the latest version is recommended)
106-
- Matplotlib version 3.3.4, since Captum currently uses a Matplotlib
107-
function whose arguments have been renamed in later versions
97+
- Matplotlib (the latest version is recommended)
10898
10999
To install Captum in a virtual environment, use:
110100
111101
.. code-block:: sh
112102
113-
pip install torch torchvision captum matplotlib==3.3.4 Flask-Compress
103+
pip install torch torchvision captum matplotlib
114104
115105
Restart this notebook in the environment you set up, and you’re ready to
116106
go!
@@ -257,7 +247,7 @@
257247
attributions_ig = integrated_gradients.attribute(input_img, target=pred_label_idx, n_steps=200)
258248

259249
# Show the original image for comparison
260-
_ = viz.visualize_image_attr(None, np.transpose(transformed_img.squeeze().cpu().detach().numpy(), (1,2,0)),
250+
_ = viz.visualize_image_attr(None, np.transpose(transformed_img.squeeze().cpu().detach().numpy(), (1,2,0)),
261251
method="original_image", title="Original Image")
262252

263253
default_cmap = LinearSegmentedColormap.from_list('custom blue',
@@ -385,106 +375,3 @@
385375
# Visualizations such as this can give you novel insights into how your
386376
# hidden layers respond to your input.
387377
#
388-
389-
390-
##########################################################################
391-
# Visualization with Captum Insights
392-
# ----------------------------------
393-
#
394-
# Captum Insights is an interpretability visualization widget built on top
395-
# of Captum to facilitate model understanding. Captum Insights works
396-
# across images, text, and other features to help users understand feature
397-
# attribution. It allows you to visualize attribution for multiple
398-
# input/output pairs, and provides visualization tools for image, text,
399-
# and arbitrary data.
400-
#
401-
# In this section of the notebook, we’ll visualize multiple image
402-
# classification inferences with Captum Insights.
403-
#
404-
# First, let’s gather some image and see what the model thinks of them.
405-
# For variety, we’ll take our cat, a teapot, and a trilobite fossil:
406-
#
407-
408-
imgs = ['img/cat.jpg', 'img/teapot.jpg', 'img/trilobite.jpg']
409-
410-
for img in imgs:
411-
img = Image.open(img)
412-
transformed_img = transform(img)
413-
input_img = transform_normalize(transformed_img)
414-
input_img = input_img.unsqueeze(0) # the model requires a dummy batch dimension
415-
416-
output = model(input_img)
417-
output = F.softmax(output, dim=1)
418-
prediction_score, pred_label_idx = torch.topk(output, 1)
419-
pred_label_idx.squeeze_()
420-
predicted_label = idx_to_labels[str(pred_label_idx.item())][1]
421-
print('Predicted:', predicted_label, '/', pred_label_idx.item(), ' (', prediction_score.squeeze().item(), ')')
422-
423-
424-
##########################################################################
425-
# …and it looks like our model is identifying them all correctly - but of
426-
# course, we want to dig deeper. For that we’ll use the Captum Insights
427-
# widget, which we configure with an ``AttributionVisualizer`` object,
428-
# imported below. The ``AttributionVisualizer`` expects batches of data,
429-
# so we’ll bring in Captum’s ``Batch`` helper class. And we’ll be looking
430-
# at images specifically, so well also import ``ImageFeature``.
431-
#
432-
# We configure the ``AttributionVisualizer`` with the following arguments:
433-
#
434-
# - An array of models to be examined (in our case, just the one)
435-
# - A scoring function, which allows Captum Insights to pull out the
436-
# top-k predictions from a model
437-
# - An ordered, human-readable list of classes our model is trained on
438-
# - A list of features to look for - in our case, an ``ImageFeature``
439-
# - A dataset, which is an iterable object returning batches of inputs
440-
# and labels - just like you’d use for training
441-
#
442-
443-
from captum.insights import AttributionVisualizer, Batch
444-
from captum.insights.attr_vis.features import ImageFeature
445-
446-
# Baseline is all-zeros input - this may differ depending on your data
447-
def baseline_func(input):
448-
return input * 0
449-
450-
# merging our image transforms from above
451-
def full_img_transform(input):
452-
i = Image.open(input)
453-
i = transform(i)
454-
i = transform_normalize(i)
455-
i = i.unsqueeze(0)
456-
return i
457-
458-
459-
input_imgs = torch.cat(list(map(lambda i: full_img_transform(i), imgs)), 0)
460-
461-
visualizer = AttributionVisualizer(
462-
models=[model],
463-
score_func=lambda o: torch.nn.functional.softmax(o, 1),
464-
classes=list(map(lambda k: idx_to_labels[k][1], idx_to_labels.keys())),
465-
features=[
466-
ImageFeature(
467-
"Photo",
468-
baseline_transforms=[baseline_func],
469-
input_transforms=[],
470-
)
471-
],
472-
dataset=[Batch(input_imgs, labels=[282,849,69])]
473-
)
474-
475-
476-
#########################################################################
477-
# Note that running the cell above didn’t take much time at all, unlike
478-
# our attributions above. That’s because Captum Insights lets you
479-
# configure different attribution algorithms in a visual widget, after
480-
# which it will compute and display the attributions. *That* process will
481-
# take a few minutes.
482-
#
483-
# Running the cell below will render the Captum Insights widget. You can
484-
# then choose attributions methods and their arguments, filter model
485-
# responses based on predicted class or prediction correctness, see the
486-
# model’s predictions with associated probabilities, and view heatmaps of
487-
# the attribution compared with the original image.
488-
#
489-
490-
visualizer.render()

0 commit comments

Comments
 (0)