Skip to content

Commit ced5e79

Browse files
committed
Added Unittest + the option to disable the print of the Images
1 parent 6d793cf commit ced5e79

10 files changed

Lines changed: 108 additions & 113 deletions

File tree

22 Bytes
Binary file not shown.

PyTorchLayerViz/main.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -134,22 +134,22 @@ def plot_feature_maps(processed_feature_maps, layer_names):
134134

135135

136136
def get_feature_maps(
137-
model, layers_to_check, input_image_path, transform=None, sequential_order=True
137+
model, layers_to_check, input_image_path, transform=None, sequential_order=True, print_image=False
138138
):
139139
if transform is None:
140140
# Define the image transformations
141141
transform = transforms.Compose(
142142
[
143-
transforms.Resize((224, 224)), # Resize the image to 224x224 pixels
144-
transforms.ToTensor(), # Convert the image to a PyTorch tensor
145-
# transforms.Normalize(mean=0., std=1.) # Normalize the image tensor
143+
transforms.Resize((224, 224)),
144+
transforms.ToTensor(),
145+
# transforms.Normalize(mean=0., std=1.)
146146
]
147147
)
148148

149149
# Example usage
150-
input_image = Image.open(input_image_path) # add your image path
150+
input_image = Image.open(input_image_path)
151151
input_image = transform(input_image)
152-
input_image = input_image.unsqueeze(0) # Add a batch dimension
152+
input_image = input_image.unsqueeze(0)
153153

154154
if sequential_order:
155155
ordered_layers, ordered_layer_names = extract_layers_and_weights(
@@ -166,7 +166,8 @@ def get_feature_maps(
166166
)
167167

168168
processed_feature_maps = process_feature_maps(feature_maps)
169-
plot_feature_maps(processed_feature_maps, layer_names)
169+
if print_image:
170+
plot_feature_maps(processed_feature_maps, layer_names)
170171
images = []
171172
for feature_map in processed_feature_maps:
172173
if feature_map.ndim == 2: # Grayscale feature map

README.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ layers_to_check = [nn.Conv2d] # Define all Layers you want to pass your picture
5959

6060
input_image_path = 'pictures/hamburger.jpg' # Path to your example picture
6161

62-
numpyArr = get_feature_maps(model = model, layers_to_check = layers_to_check, input_image_path = input_image_path) # Call function from pytorchlayerviz
62+
numpyArr = get_feature_maps(model = model, layers_to_check = layers_to_check, input_image_path = input_image_path, print_image=True) # Call function from pytorchlayerviz
6363
```
6464

6565
### Parameters
@@ -69,6 +69,7 @@ numpyArr = get_feature_maps(model = model, layers_to_check = layers_to_check, in
6969
- **input_image_path (str)** – Path to the input image file. *Required*.
7070
- **transform (transforms.Compose, optional)** – A function/transform that takes in an image and returns a transformed version. Default is None. *Optional*.
7171
- **sequential_order (bool, optional)** – If True, the layers are visualized in the order they are defined in the model. If false it will first go through the first layer defined in the arrDefault is True. *Optional*.
72+
- **print_image (bool, optional)** – If True the Images are getting printed with matplotlib. Default is False. *Optional*.
7273

7374
**Return** The function 'get_feature_maps()` returns the pictures as NumPy Arrays
7475

@@ -103,7 +104,7 @@ pretrained_model = models.vgg16(pretrained=True)
103104
input_image_path = 'hamburger.jpg'
104105
layers_to_check= [nn.MaxPool2d]
105106

106-
numpyArr = get_feature_maps(model = pretrained_model, layers_to_check = layers_to_check, input_image_path = input_image_path, sequential_order = False)
107+
numpyArr = get_feature_maps(model = pretrained_model, layers_to_check = layers_to_check, input_image_path = input_image_path, sequential_order = False, print_image = True)
107108
```
108109

109110
### Output

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
with open("README.md", "r") as f:
44
LONG_DESCRIPTION = f.read()
55

6-
VERSION = '1.2.1'
6+
VERSION = '1.2.2'
77
DESCRIPTION = "PyTorchLayerViz is a Python library that allows you to visualize the weights and feature maps of a PyTorch model."
88

99
setup(
3.38 KB
Binary file not shown.
147 KB
Binary file not shown.
1.53 MB
Binary file not shown.
1.53 MB
Binary file not shown.

test/test.ipynb

Lines changed: 42 additions & 103 deletions
Large diffs are not rendered by default.

test/test_get_feature_maps.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
import unittest
2+
import matplotlib.pyplot as plt
3+
from PIL import Image
4+
import torch
5+
from torch import nn
6+
from torchvision import datasets, transforms, models
7+
from torchvision.transforms import ToTensor
8+
from PIL import Image
9+
import numpy as np
10+
import sys
11+
import os
12+
sys.path.append(os.path.abspath(os.path.join('..', 'PyTorchLayerViz')))
13+
14+
from main import get_feature_maps
15+
16+
pretrained_model = models.vgg16(pretrained=True)
17+
input_image_path = 'brain.tif'
18+
19+
torch.manual_seed(42)
20+
21+
transform = transforms.Compose([
22+
transforms.Resize((256, 256)),
23+
transforms.CenterCrop(224),
24+
transforms.ToTensor(),
25+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
26+
])
27+
28+
29+
30+
class TestGetFeatureMaps(unittest.TestCase):
31+
def test_maxPool(self):
32+
layers_to_check= [nn.MaxPool2d]
33+
numpyArr_from_function = get_feature_maps(model = pretrained_model, layers_to_check = layers_to_check, input_image_path = input_image_path)
34+
35+
numpyArr_from_file = np.load('output_images/test_maxPool.npy')
36+
37+
np.testing.assert_array_equal(numpyArr_from_function, numpyArr_from_file)
38+
def test_moreLayers(self):
39+
layers_to_check= [nn.MaxPool2d, nn.Conv2d]
40+
numpyArr_from_function = get_feature_maps(model = pretrained_model, layers_to_check = layers_to_check, input_image_path = input_image_path)
41+
42+
numpyArr_from_file = np.load('output_images/test_moreLayers.npy')
43+
44+
np.testing.assert_array_equal(numpyArr_from_function, numpyArr_from_file)
45+
def test_transform(self):
46+
layers_to_check= [nn.MaxPool2d, nn.Conv2d]
47+
numpyArr_from_function = get_feature_maps(model = pretrained_model, layers_to_check = layers_to_check, input_image_path = input_image_path, transform = transform)
48+
49+
numpyArr_from_file = np.load('output_images/test_transform.npy')
50+
51+
np.testing.assert_array_equal(numpyArr_from_function, numpyArr_from_file)
52+
53+
if __name__ == '__main__':
54+
unittest.main()

0 commit comments

Comments
 (0)