@@ -143,20 +143,23 @@ def get_feature_maps(
143143 # Define the image transformations
144144 transform = transforms .Compose (
145145 [
146- transforms .Resize ((224 , 224 )),
147- transforms .ToTensor (),
148- # transforms.Normalize(mean=0., std=1.)
146+ transforms .Resize ((224 , 224 )),
147+ transforms .ToTensor (),
148+ # transforms.Normalize(mean=0., std=1.)
149149 ]
150150 )
151151
152152 # Load and preprocess the image
153- input_image = Image .open (input_image_path )
153+ input_image = Image .open (input_image_path )
154154 input_image = transform (input_image )
155- input_image = input_image .unsqueeze (0 )
155+ input_image = input_image .unsqueeze (0 )
156+
157+ # **Move the input to the same device as the model**
158+ device = next (model .parameters ()).device
159+ input_image = input_image .to (device )
156160
157161 activations = {}
158162 # Register hooks
159-
160163 def get_activation (name ):
161164 def hook (model , input , output ):
162165 activations [name ] = output .detach ()
@@ -182,7 +185,6 @@ def hook(model, input, output):
182185 if feature_map .ndim == 2 : # Grayscale feature map
183186 colored_feature_map = apply_colormap (feature_map )
184187 images .append ((colored_feature_map , layer_name ))
185- # RGB feature map
186- elif feature_map .ndim == 3 and feature_map .shape [0 ] == 3 :
188+ elif feature_map .ndim == 3 and feature_map .shape [0 ] == 3 : # RGB feature map
187189 images .append ((np .transpose (feature_map , (1 , 2 , 0 )), layer_name ))
188- return images
190+ return images
0 commit comments