|
82 | 82 | - The ``captum.attr.visualization`` module (imported below as ``viz``) |
83 | 83 | provides helpful functions for visualizing attributions related to |
84 | 84 | images. |
85 | | -- **Captum Insights** is an easy-to-use API on top of Captum that |
86 | | - provides a visualization widget with ready-made visualizations for |
87 | | - image, text, and arbitrary model types. |
88 | 85 |
|
89 | | -Both of these visualization toolsets will be demonstrated in this |
90 | | -notebook. The first few examples will focus on computer vision use |
91 | | -cases, but the Captum Insights section at the end will demonstrate |
92 | | -visualization of attributions in a multi-model, visual |
93 | | -question-and-answer model. |
| 86 | +This visualization toolset will be demonstrated throughout this notebook. |
94 | 87 |
|
95 | 88 | Installation |
96 | 89 | ------------ |
97 | 90 |
|
98 | 91 | Before you get started, you need to have a Python environment with: |
99 | 92 |
|
100 | | -- Python version 3.6 or higher |
101 | | -- For the Captum Insights example, Flask 1.1 or higher and Flask-Compress |
102 | | - (the latest version is recommended) |
103 | | -- PyTorch version 1.2 or higher (the latest version is recommended) |
104 | | -- TorchVision version 0.6 or higher (the latest version is recommended) |
| 93 | +- Python version 3.9 or higher |
| 94 | +- PyTorch (the latest version is recommended) |
| 95 | +- TorchVision (the latest version is recommended) |
105 | 96 | - Captum (the latest version is recommended) |
106 | | -- Matplotlib version 3.3.4, since Captum currently uses a Matplotlib |
107 | | - function whose arguments have been renamed in later versions |
| 97 | +- Matplotlib (the latest version is recommended) |
108 | 98 |
|
109 | 99 | To install Captum in a virtual environment, use: |
110 | 100 |
|
111 | 101 | .. code-block:: sh |
112 | 102 |
|
113 | | - pip install torch torchvision captum matplotlib==3.3.4 Flask-Compress |
| 103 | + pip install torch torchvision captum matplotlib |
114 | 104 |
|
115 | 105 | Restart this notebook in the environment you set up, and you’re ready to |
116 | 106 | go! |
|
257 | 247 | attributions_ig = integrated_gradients.attribute(input_img, target=pred_label_idx, n_steps=200) |
258 | 248 |
|
259 | 249 | # Show the original image for comparison |
260 | | -_ = viz.visualize_image_attr(None, np.transpose(transformed_img.squeeze().cpu().detach().numpy(), (1,2,0)), |
| 250 | +_ = viz.visualize_image_attr(None, np.transpose(transformed_img.squeeze().cpu().detach().numpy(), (1,2,0)), |
261 | 251 | method="original_image", title="Original Image") |
262 | 252 |
|
263 | 253 | default_cmap = LinearSegmentedColormap.from_list('custom blue', |
|
385 | 375 | # Visualizations such as this can give you novel insights into how your |
386 | 376 | # hidden layers respond to your input. |
387 | 377 | # |
388 | | - |
389 | | - |
390 | | -########################################################################## |
391 | | -# Visualization with Captum Insights |
392 | | -# ---------------------------------- |
393 | | -# |
394 | | -# Captum Insights is an interpretability visualization widget built on top |
395 | | -# of Captum to facilitate model understanding. Captum Insights works |
396 | | -# across images, text, and other features to help users understand feature |
397 | | -# attribution. It allows you to visualize attribution for multiple |
398 | | -# input/output pairs, and provides visualization tools for image, text, |
399 | | -# and arbitrary data. |
400 | | -# |
401 | | -# In this section of the notebook, we’ll visualize multiple image |
402 | | -# classification inferences with Captum Insights. |
403 | | -# |
404 | | -# First, let’s gather some image and see what the model thinks of them. |
405 | | -# For variety, we’ll take our cat, a teapot, and a trilobite fossil: |
406 | | -# |
407 | | - |
408 | | -imgs = ['img/cat.jpg', 'img/teapot.jpg', 'img/trilobite.jpg'] |
409 | | - |
410 | | -for img in imgs: |
411 | | - img = Image.open(img) |
412 | | - transformed_img = transform(img) |
413 | | - input_img = transform_normalize(transformed_img) |
414 | | - input_img = input_img.unsqueeze(0) # the model requires a dummy batch dimension |
415 | | - |
416 | | - output = model(input_img) |
417 | | - output = F.softmax(output, dim=1) |
418 | | - prediction_score, pred_label_idx = torch.topk(output, 1) |
419 | | - pred_label_idx.squeeze_() |
420 | | - predicted_label = idx_to_labels[str(pred_label_idx.item())][1] |
421 | | - print('Predicted:', predicted_label, '/', pred_label_idx.item(), ' (', prediction_score.squeeze().item(), ')') |
422 | | - |
423 | | - |
424 | | -########################################################################## |
425 | | -# …and it looks like our model is identifying them all correctly - but of |
426 | | -# course, we want to dig deeper. For that we’ll use the Captum Insights |
427 | | -# widget, which we configure with an ``AttributionVisualizer`` object, |
428 | | -# imported below. The ``AttributionVisualizer`` expects batches of data, |
429 | | -# so we’ll bring in Captum’s ``Batch`` helper class. And we’ll be looking |
430 | | -# at images specifically, so well also import ``ImageFeature``. |
431 | | -# |
432 | | -# We configure the ``AttributionVisualizer`` with the following arguments: |
433 | | -# |
434 | | -# - An array of models to be examined (in our case, just the one) |
435 | | -# - A scoring function, which allows Captum Insights to pull out the |
436 | | -# top-k predictions from a model |
437 | | -# - An ordered, human-readable list of classes our model is trained on |
438 | | -# - A list of features to look for - in our case, an ``ImageFeature`` |
439 | | -# - A dataset, which is an iterable object returning batches of inputs |
440 | | -# and labels - just like you’d use for training |
441 | | -# |
442 | | - |
443 | | -from captum.insights import AttributionVisualizer, Batch |
444 | | -from captum.insights.attr_vis.features import ImageFeature |
445 | | - |
446 | | -# Baseline is all-zeros input - this may differ depending on your data |
447 | | -def baseline_func(input): |
448 | | - return input * 0 |
449 | | - |
450 | | -# merging our image transforms from above |
451 | | -def full_img_transform(input): |
452 | | - i = Image.open(input) |
453 | | - i = transform(i) |
454 | | - i = transform_normalize(i) |
455 | | - i = i.unsqueeze(0) |
456 | | - return i |
457 | | - |
458 | | - |
459 | | -input_imgs = torch.cat(list(map(lambda i: full_img_transform(i), imgs)), 0) |
460 | | - |
461 | | -visualizer = AttributionVisualizer( |
462 | | - models=[model], |
463 | | - score_func=lambda o: torch.nn.functional.softmax(o, 1), |
464 | | - classes=list(map(lambda k: idx_to_labels[k][1], idx_to_labels.keys())), |
465 | | - features=[ |
466 | | - ImageFeature( |
467 | | - "Photo", |
468 | | - baseline_transforms=[baseline_func], |
469 | | - input_transforms=[], |
470 | | - ) |
471 | | - ], |
472 | | - dataset=[Batch(input_imgs, labels=[282,849,69])] |
473 | | -) |
474 | | - |
475 | | - |
476 | | -######################################################################### |
477 | | -# Note that running the cell above didn’t take much time at all, unlike |
478 | | -# our attributions above. That’s because Captum Insights lets you |
479 | | -# configure different attribution algorithms in a visual widget, after |
480 | | -# which it will compute and display the attributions. *That* process will |
481 | | -# take a few minutes. |
482 | | -# |
483 | | -# Running the cell below will render the Captum Insights widget. You can |
484 | | -# then choose attributions methods and their arguments, filter model |
485 | | -# responses based on predicted class or prediction correctness, see the |
486 | | -# model’s predictions with associated probabilities, and view heatmaps of |
487 | | -# the attribution compared with the original image. |
488 | | -# |
489 | | - |
490 | | -visualizer.render() |