-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathGradCAM.py
More file actions
205 lines (162 loc) · 8.25 KB
/
GradCAM.py
File metadata and controls
205 lines (162 loc) · 8.25 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
# -*- coding: utf-8 -*-
"""
Created on Sun Apr 7 16:06:49 2024
@author: tc922
"""
from torch import nn
import torch
from timm import create_model
import numpy as np
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import math
class GradCAM:
def __init__(self,model,
img_path:str,
layer_idx: int = 2,
input_shape = (224,224),
model_type: str = 'Normal',
relu_threshold = 8
):
"""
params:
layer_idx: the index of the layer where you want to visulize (it count
from classifier to input layer).
input_shape: the image shape to put into the model
model: the model you want to visualize
img_path: the path of the tested image
model_type: some special model need addtional method to process the activations
in order to get Grad-CAM. Currently, there's only function to handel vision
transformer.
relu_threshold: when the heatmap has too many/small highlight point, you can ajdust by reduce/increase this number
"""
self.model = model
self.layer_idx = -layer_idx
self.img_path = img_path
self.input_shape = input_shape
self.model_type = model_type
self.relu_threshold = relu_threshold
print('The model types you can select from are \'Normal\', \'ViT\'')
def __call__(self):
model = self.model
extractor = nn.Sequential(*list(model.children())[:self.layer_idx]) #truncate the model from where your specified idx
classifier = nn.Sequential(*list(model.children())[self.layer_idx:])
#print(extractor)
img = Image.open(self.img_path).convert('RGB').resize(self.input_shape)
transform = transforms.Compose([transforms.ToTensor()])
img = transform(img)
img = torch.unsqueeze(img, 0) #preprocess the image to tensor: (1,C,H,W)
img.requires_grad = True
result_ = model(img)
result = result_.argmax() #get the prediction category
class_Idx = result #the grad-cam will visualize the prediction towards a specific category. (Here is the model's prediction)
self.result_class = class_Idx #make the result accessible from outside
softmax = nn.Softmax(dim = -1) #convert the logits into probabiliries
self.confidence = round(float(softmax(result_)[0][class_Idx])*100,2) #make the model's confidence to the prediction accessible
activations = extractor(img) #get the activations (output features) from target layer
activations.retain_grad() #pytorch will automatically free the parameters' gradients that are not provided by user
#so here it needs to be specified to keep the gradient of activation w.r.t prediction logits
print(f'Activation Shape:{activations.shape}')
prediction_logits = classifier(activations) #the activation is fed into rest layers to get the prediction tensor
print(f'Prediction_logits Shape:{prediction_logits.shape}')
print(f'The Grad-CAM will be plotted based on model prediction result: {class_Idx}')
if self.model_type == 'Normal':
prediction_logits = prediction_logits[:,class_Idx] #only use the specific class to back propagate
elif self.model_type == 'ViT':
prediction_logits = prediction_logits[:,:,class_Idx]
grad_output = torch.ones_like(prediction_logits)
print(grad_output.shape)
prediction_logits.backward(gradient = grad_output) #according to pytorch, backward() should specify
#with a tensor which its length is the same as backward tensor
#when the tensor contains more than one number
d_act = activations.grad #get the gradient of activation from target layer w.r.t. the specified category
if self.model_type == 'Normal':
d_act = d_act.permute(0,2,3,1) #(1,C,H,W) -> (1,H,W,C)
activations = activations.permute(0,2,3,1)
elif self.model_type == 'ViT':
d_act = self.output_decompose_vit_grad_cam(d_act)
activations = self.output_decompose_vit_grad_cam(activations[:,:,:])
print(f'gradient shape (predictioin logti(s) w.r.t. feature logits: ){d_act.shape}')
pooled_grads = torch.mean(d_act,dim = (0,1,2)) #according to the paper, the pooling happens all axis except the channel dim
heatmap = activations.detach().numpy()[0] #for tensors where its requires_grad = True, need detach() function to convert to ndarray
pooled_grads = pooled_grads.numpy()
#back propagate
for i in range(d_act.shape[-1]):
heatmap[:,:,i] *= pooled_grads[i]
print(f'level.1 pooling on heatmap: {heatmap.shape}')
heatmap = np.mean(heatmap, axis = -1) #here the heatmap shapes as (H,W,1)
print(f'level.2 pooling on heatmap: {heatmap.shape}')
print(f'Maximum pixel value of heatmap is {heatmap.max()}')
threshold = heatmap.max()/self.relu_threshold
#threshold = 0
heatmap_filtered = np.maximum(heatmap,threshold)
heatmap_filtered[heatmap_filtered == threshold] = threshold * 0.55
heatmap_filtered[heatmap_filtered < threshold/2] = 0
#print(f'Minimum value of heatmap{heatmap_filtered.min()}')
self.heatmap = np.uint8(255*heatmap_filtered/np.max(heatmap)) #keep the logits that are greater than specific value
#in the paper, that is to say, only keep the positive influence with the specific class.
#then normalize the heatmap and recale its value range from 0 to 255.
def origin_cam_visualization(self,save_path = None):
#display the orignal size heatmap (H,W,1)
plt.rcParams.update({'font.size': 14})
plt.matshow(self.heatmap)
plt.title('Original generated heatmap')
plt.show()
if save_path != None:
plt.savefig(save_path)
def imposing_visualization(self,save_path = None):
alpha = 0.5 #how much CAM will overlap on original image
plt.figure(figsize = (20,20))
plt.rcParams.update({'font.size': 18})
jet = cm.get_cmap('jet') #create the color map object
jet_colors = jet(np.arange(256))[:,:3]
print(f'jet_color shape: {jet_colors.shape}')
jet_colors = (jet_colors*256).astype(np.uint8) #generate a color (RGB) image which has small H and W
# and maps the intensity to color from red to blue
#jet_heatmap = (jet_colors[self.heatmap] * alpha).astype(np.uint8)
#print(jet_heatmap.shape)
#jet_heatmap = Image.fromarray(jet_heatmap).resize(self.input_shape)
self.heatmap = Image.fromarray(self.heatmap).resize(self.input_shape)
jet_heatmap = (jet_colors[np.uint8(self.heatmap)] ).astype(np.uint8)
img = Image.open(self.img_path).convert('RGB').resize(self.input_shape)
jet_heatmap = np.asarray(jet_heatmap)
img_cam = np.uint8(np.asarray(img)*(1-alpha)*1.2) + np.uint8(np.asarray(jet_heatmap* alpha))
plt.subplot(2,2,1)
plt.xticks([])
plt.yticks([])
plt.imshow(img)
plt.title('Original Image')
plt.subplot(2,2,2)
plt.xticks([])
plt.yticks([])
plt.imshow(img_cam) #print the overlapped image (origin + cam)
plt.title('Overlapped Colormap Image')
plt.subplot(2,2,3)
plt.xticks([])
plt.yticks([])
plt.imshow(self.heatmap)
plt.title('Heatmap (2-D Magnitude)')
plt.subplot(2,2,4)
plt.xticks([])
plt.yticks([])
plt.imshow(jet_heatmap/255) #print the cam
plt.title('Projected Colormap (3-D)')
plt.subplots_adjust(left=None, bottom=None, right=None, top=None, wspace=-0.05, hspace=0.1)
if save_path != None:
plt.savefig(save_path,bbox_inches = 'tight', pad_inches = 0.3)
name = save_path.split('.')[0]
img.save(name+'-origin'+'.png')
Image.fromarray(img_cam).save(name+'-overlapped'+'.png')
self.heatmap.save(name+'-heatmap'+'.png')
Image.fromarray(jet_heatmap).save(name+'-colormap'+'.png')
def output_decompose_vit_grad_cam(self, vit_input):
#decompose on vit's sequence dimension
_, HW, C = vit_input.shape
print(HW)
HW = int(math.sqrt(HW))
print(HW)
vit_output = torch.reshape(vit_input,(1,HW,HW,C))
#vit_output = vit_output.permute(0,3,1,2)
return vit_output