Skip to content

Commit bc93def

Browse files
committed
CUDA Device is now Supported
1 parent 5f9c957 commit bc93def

2 files changed

Lines changed: 12 additions & 10 deletions

File tree

PyTorchLayerViz/main.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -143,20 +143,23 @@ def get_feature_maps(
143143
# Define the image transformations
144144
transform = transforms.Compose(
145145
[
146-
transforms.Resize((224, 224)),
147-
transforms.ToTensor(),
148-
# transforms.Normalize(mean=0., std=1.)
146+
transforms.Resize((224, 224)),
147+
transforms.ToTensor(),
148+
# transforms.Normalize(mean=0., std=1.)
149149
]
150150
)
151151

152152
# Load and preprocess the image
153-
input_image = Image.open(input_image_path)
153+
input_image = Image.open(input_image_path)
154154
input_image = transform(input_image)
155-
input_image = input_image.unsqueeze(0)
155+
input_image = input_image.unsqueeze(0)
156+
157+
# **Move the input to the same device as the model**
158+
device = next(model.parameters()).device
159+
input_image = input_image.to(device)
156160

157161
activations = {}
158162
# Register hooks
159-
160163
def get_activation(name):
161164
def hook(model, input, output):
162165
activations[name] = output.detach()
@@ -182,7 +185,6 @@ def hook(model, input, output):
182185
if feature_map.ndim == 2: # Grayscale feature map
183186
colored_feature_map = apply_colormap(feature_map)
184187
images.append((colored_feature_map, layer_name))
185-
# RGB feature map
186-
elif feature_map.ndim == 3 and feature_map.shape[0] == 3:
188+
elif feature_map.ndim == 3 and feature_map.shape[0] == 3: # RGB feature map
187189
images.append((np.transpose(feature_map, (1, 2, 0)), layer_name))
188-
return images
190+
return images

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
with open("README.md", "r") as f:
44
LONG_DESCRIPTION = f.read()
55

6-
VERSION = '1.2.4'
6+
VERSION = '1.2.5'
77
DESCRIPTION = "PyTorchLayerViz is a Python library that allows you to visualize the weights and feature maps of a PyTorch model."
88

99
setup(

0 commit comments

Comments
 (0)