Skip to content

Commit 677f510

Browse files
committed
Bug Fix + Added option to get Label from Layer
1 parent c3edce6 commit 677f510

11 files changed

Lines changed: 110 additions & 68 deletions

File tree

PyTorchLayerViz.egg-info/PKG-INFO

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
Metadata-Version: 2.1
22
Name: PyTorchLayerViz
3-
Version: 1.2.1
3+
Version: 1.2.4
44
Summary: PyTorchLayerViz is a Python library that allows you to visualize the weights and feature maps of a PyTorch model.
55
Author: Simone Panico
66
Author-email: simone.panico@icloud.com
@@ -73,7 +73,7 @@ layers_to_check = [nn.Conv2d] # Define all Layers you want to pass your picture
7373

7474
input_image_path = 'pictures/hamburger.jpg' # Path to your example picture
7575

76-
numpyArr = get_feature_maps(model = model, layers_to_check = layers_to_check, input_image_path = input_image_path) # Call function from pytorchlayerviz
76+
numpyArr = get_feature_maps(model = model, layers_to_check = layers_to_check, input_image_path = input_image_path, print_image=True) # Call function from pytorchlayerviz
7777
```
7878

7979
### Parameters
@@ -83,6 +83,7 @@ numpyArr = get_feature_maps(model = model, layers_to_check = layers_to_check, in
8383
- **input_image_path (str)** – Path to the input image file. *Required*.
8484
- **transform (transforms.Compose, optional)** – A function/transform that takes in an image and returns a transformed version. Default is None. *Optional*.
8585
- **sequential_order (bool, optional)** – If True, the layers are visualized in the order they are defined in the model. If false it will first go through the first layer defined in the arrDefault is True. *Optional*.
86+
- **print_image (bool, optional)** – If True the Images are getting printed with matplotlib. Default is False. *Optional*.
8687

8788
**Return** The function 'get_feature_maps()` returns the pictures as NumPy Arrays
8889

@@ -117,7 +118,7 @@ pretrained_model = models.vgg16(pretrained=True)
117118
input_image_path = 'hamburger.jpg'
118119
layers_to_check= [nn.MaxPool2d]
119120

120-
numpyArr = get_feature_maps(model = pretrained_model, layers_to_check = layers_to_check, input_image_path = input_image_path, sequential_order = False)
121+
numpyArr = get_feature_maps(model = pretrained_model, layers_to_check = layers_to_check, input_image_path = input_image_path, sequential_order = False, print_image = True)
121122
```
122123

123124
### Output

PyTorchLayerViz.egg-info/SOURCES.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,5 @@ PyTorchLayerViz.egg-info/SOURCES.txt
88
PyTorchLayerViz.egg-info/commands.txt
99
PyTorchLayerViz.egg-info/dependency_links.txt
1010
PyTorchLayerViz.egg-info/requires.txt
11-
PyTorchLayerViz.egg-info/top_level.txt
11+
PyTorchLayerViz.egg-info/top_level.txt
12+
test/test_get_feature_maps.py
462 Bytes
Binary file not shown.

PyTorchLayerViz/main.py

Lines changed: 36 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
import matplotlib.pyplot as plt
77

88
##########################################################
9+
10+
911
def extract_layers_and_weights(model, layers_to_check, sequential_order=True):
1012
# Initialize dictionaries and lists to store layers and their weights
1113
layer_weights = {}
@@ -126,54 +128,61 @@ def plot_feature_maps(processed_feature_maps, layer_names):
126128
ax.set_title(layer_names[i].split("(")[0], fontsize=30)
127129
plot_index += 1
128130
else:
129-
print(f"Skipping feature map at index {i} with shape: {feature_map.shape}")
131+
print(f"Skipping feature map at index {
132+
i} with shape: {feature_map.shape}")
130133
plt.show()
131134

132135

133136
##########################################################
134137

135138

136139
def get_feature_maps(
137-
model, layers_to_check, input_image_path, transform=None, sequential_order=True, print_image=False
140+
model, layers_to_check, input_image_path, transform=None, print_image=False
138141
):
139142
if transform is None:
140143
# Define the image transformations
141144
transform = transforms.Compose(
142145
[
143-
transforms.Resize((224, 224)),
144-
transforms.ToTensor(),
145-
# transforms.Normalize(mean=0., std=1.)
146+
transforms.Resize((224, 224)),
147+
transforms.ToTensor(),
148+
# transforms.Normalize(mean=0., std=1.)
146149
]
147150
)
148151

149-
# Example usage
150-
input_image = Image.open(input_image_path)
152+
# Load and preprocess the image
153+
input_image = Image.open(input_image_path)
151154
input_image = transform(input_image)
152-
input_image = input_image.unsqueeze(0)
155+
input_image = input_image.unsqueeze(0)
153156

154-
if sequential_order:
155-
ordered_layers, ordered_layer_names = extract_layers_and_weights(
156-
model, layers_to_check, sequential_order
157-
)
158-
feature_maps, layer_names = extract_feature_maps(ordered_layers, input_image)
159-
else:
160-
layer_weights, layers, layer_counters = extract_layers_and_weights(
161-
model, layers_to_check, sequential_order
162-
)
163-
feature_maps, layer_names = extract_feature_maps(
164-
[layer for layer_list in layers.values() for layer in layer_list],
165-
input_image,
166-
)
157+
activations = {}
158+
# Register hooks
159+
160+
def get_activation(name):
161+
def hook(model, input, output):
162+
activations[name] = output.detach()
163+
return hook
164+
165+
# Register hooks on the layers of interest
166+
for name, module in model.named_modules():
167+
if isinstance(module, tuple(layers_to_check)):
168+
module.register_forward_hook(get_activation(name))
169+
170+
# Pass the input through the model
171+
_ = model(input_image)
172+
173+
# Now activations dictionary has the outputs
174+
feature_maps = list(activations.values())
175+
layer_names = list(activations.keys())
167176

168177
processed_feature_maps = process_feature_maps(feature_maps)
169178
if print_image:
170179
plot_feature_maps(processed_feature_maps, layer_names)
171180
images = []
172-
for feature_map in processed_feature_maps:
181+
for feature_map, layer_name in zip(processed_feature_maps, layer_names):
173182
if feature_map.ndim == 2: # Grayscale feature map
174183
colored_feature_map = apply_colormap(feature_map)
175-
images.append(colored_feature_map)
176-
elif feature_map.ndim == 3 and feature_map.shape[0] == 3: # RGB feature map
177-
images.append(np.transpose(feature_map, (1, 2, 0)))
178-
179-
return images
184+
images.append((colored_feature_map, layer_name))
185+
# RGB feature map
186+
elif feature_map.ndim == 3 and feature_map.shape[0] == 3:
187+
images.append((np.transpose(feature_map, (1, 2, 0)), layer_name))
188+
return images

build/lib/PyTorchLayerViz/main.py

Lines changed: 38 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
import matplotlib.pyplot as plt
77

88
##########################################################
9+
10+
911
def extract_layers_and_weights(model, layers_to_check, sequential_order=True):
1012
# Initialize dictionaries and lists to store layers and their weights
1113
layer_weights = {}
@@ -126,53 +128,61 @@ def plot_feature_maps(processed_feature_maps, layer_names):
126128
ax.set_title(layer_names[i].split("(")[0], fontsize=30)
127129
plot_index += 1
128130
else:
129-
print(f"Skipping feature map at index {i} with shape: {feature_map.shape}")
131+
print(f"Skipping feature map at index {
132+
i} with shape: {feature_map.shape}")
130133
plt.show()
131134

132135

133136
##########################################################
134137

135138

136139
def get_feature_maps(
137-
model, layers_to_check, input_image_path, transform=None, sequential_order=True
140+
model, layers_to_check, input_image_path, transform=None, print_image=False
138141
):
139142
if transform is None:
140143
# Define the image transformations
141144
transform = transforms.Compose(
142145
[
143-
transforms.Resize((224, 224)), # Resize the image to 224x224 pixels
144-
transforms.ToTensor(), # Convert the image to a PyTorch tensor
145-
# transforms.Normalize(mean=0., std=1.) # Normalize the image tensor
146+
transforms.Resize((224, 224)),
147+
transforms.ToTensor(),
148+
# transforms.Normalize(mean=0., std=1.)
146149
]
147150
)
148151

149-
# Example usage
150-
input_image = Image.open(input_image_path) # add your image path
152+
# Load and preprocess the image
153+
input_image = Image.open(input_image_path)
151154
input_image = transform(input_image)
152-
input_image = input_image.unsqueeze(0) # Add a batch dimension
155+
input_image = input_image.unsqueeze(0)
153156

154-
if sequential_order:
155-
ordered_layers, ordered_layer_names = extract_layers_and_weights(
156-
model, layers_to_check, sequential_order
157-
)
158-
feature_maps, layer_names = extract_feature_maps(ordered_layers, input_image)
159-
else:
160-
layer_weights, layers, layer_counters = extract_layers_and_weights(
161-
model, layers_to_check, sequential_order
162-
)
163-
feature_maps, layer_names = extract_feature_maps(
164-
[layer for layer_list in layers.values() for layer in layer_list],
165-
input_image,
166-
)
157+
activations = {}
158+
# Register hooks
159+
160+
def get_activation(name):
161+
def hook(model, input, output):
162+
activations[name] = output.detach()
163+
return hook
164+
165+
# Register hooks on the layers of interest
166+
for name, module in model.named_modules():
167+
if isinstance(module, tuple(layers_to_check)):
168+
module.register_forward_hook(get_activation(name))
169+
170+
# Pass the input through the model
171+
_ = model(input_image)
172+
173+
# Now activations dictionary has the outputs
174+
feature_maps = list(activations.values())
175+
layer_names = list(activations.keys())
167176

168177
processed_feature_maps = process_feature_maps(feature_maps)
169-
plot_feature_maps(processed_feature_maps, layer_names)
178+
if print_image:
179+
plot_feature_maps(processed_feature_maps, layer_names)
170180
images = []
171-
for feature_map in processed_feature_maps:
181+
for feature_map, layer_name in zip(processed_feature_maps, layer_names):
172182
if feature_map.ndim == 2: # Grayscale feature map
173183
colored_feature_map = apply_colormap(feature_map)
174-
images.append(colored_feature_map)
175-
elif feature_map.ndim == 3 and feature_map.shape[0] == 3: # RGB feature map
176-
images.append(np.transpose(feature_map, (1, 2, 0)))
177-
178-
return images
184+
images.append((colored_feature_map, layer_name))
185+
# RGB feature map
186+
elif feature_map.ndim == 3 and feature_map.shape[0] == 3:
187+
images.append((np.transpose(feature_map, (1, 2, 0)), layer_name))
188+
return images
5.98 KB
Binary file not shown.
6.09 KB
Binary file not shown.

dist/pytorchlayerviz-1.2.3.tar.gz

5.97 KB
Binary file not shown.

dist/pytorchlayerviz-1.2.4.tar.gz

6.08 KB
Binary file not shown.

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
with open("README.md", "r") as f:
44
LONG_DESCRIPTION = f.read()
55

6-
VERSION = '1.2.2'
6+
VERSION = '1.2.4'
77
DESCRIPTION = "PyTorchLayerViz is a Python library that allows you to visualize the weights and feature maps of a PyTorch model."
88

99
setup(

0 commit comments

Comments
 (0)