11"""Functions to retrieve and handle layer activations"""
22
3+ from typing import Dict , List , Optional , Tuple , Type
4+
35import cebra
4- import torch
5- import torch .nn as nn
6+ import matplotlib .pyplot as plt
67import numpy as np
78import numpy .typing as npt
8- from typing import Tuple , Dict , List , Type , Optional
9- from .matplotlib import plot_activations
10- import matplotlib .pyplot as plt
9+ import torch
10+ import torch .nn as nn
11+
12+ from .utils_plot import plot_activations
1113
1214
13- def _cut_array (
14- array : npt .NDArray , cut_indices : Tuple [np .int64 , np .int64 ]
15- ) -> npt .NDArray :
15+ def _cut_array (array : npt .NDArray ,
16+ cut_indices : Tuple [np .int64 , np .int64 ]) -> npt .NDArray :
1617 """
1718 Slices the input array based on the provided cut indices.
1819 This is used to remove the padding from activations in `get_activations_model`.
@@ -36,7 +37,7 @@ def _cut_array(
3637 sliced_array = array
3738 else :
3839 # Otherwise, slice the array
39- sliced_array = array [:, start : end if end != 0 else start :]
40+ sliced_array = array [:, start : end if end != 0 else start :]
4041 return sliced_array
4142
4243
@@ -80,10 +81,13 @@ def get_cut_indices(
8081 # add for output layer
8182 cut_indices .append ((0 , 0 ))
8283 elif layer_type == None :
83- raise NotImplementedError ("Padding handling not implemented for 'all'." )
84+ raise NotImplementedError (
85+ "Padding handling not implemented to handle activations for all layer types." ,
86+ "Set layer_type to nn.Conv1d to use the default padding handling." )
8487 else :
8588 # need to analyze the padding from the last output of Conv1 and apply the same cut
86- raise NotImplementedError (f"Padding handling not implemented for { layer_type } ." )
89+ raise NotImplementedError (
90+ f"Padding handling not implemented for { layer_type } ." )
8791 return cut_indices
8892
8993
@@ -93,7 +97,7 @@ def get_activations_model(
9397 session_id : int = - 1 ,
9498 name : str = "single" ,
9599 instance : int = 0 ,
96- layer_type : Type [nn .Module ] = None ,
100+ layer_type : Type [nn .Module ] = nn . Conv1d ,
97101) -> Dict [str , npt .NDArray ]:
98102 """
99103 Extracts activations from a single model layer.
@@ -111,7 +115,8 @@ def get_activations_model(
111115 instance : int
112116 The instance number for the model, used to differentiate between models from the same model category.
113117 layer_type : Type[nn.Module]
114- The type of layer to extract activations from. Defaults to None, meaning extracts activations from all layers.
118+ The type of layer to extract activations from. None means it extracts activations from all layers.
119+ Default is nn.Conv1d, which is the most common layer type used in CEBRA models.
115120
116121 Returns:
117122 --------
@@ -125,26 +130,25 @@ def get_activations_model(
125130 activations = {}
126131 transform_kwargs = {}
127132 if model .solver_name_ in [
128- "multi-session" ,
129- "multi-session-aux" ,
130- "multiobjective-solver" ,
133+ "multi-session" ,
134+ "multi-session-aux" ,
135+ "multiobjective-solver" ,
131136 ]:
132137
133138 model_ = model .model_ [session_id ]
134139 transform_kwargs .update ({"session_id" : session_id })
135140
136141 elif model .solver_name_ in [
137- "single-session" ,
138- "single-session-aux" ,
139- "single-session-hybrid" ,
140- "single-session-full" ,
142+ "single-session" ,
143+ "single-session-aux" ,
144+ "single-session-hybrid" ,
145+ "single-session-full" ,
141146 ]:
142147 model_ = model .model_
143148
144149 else :
145150 raise NotImplementedError (
146- f"Solver { model .solver_name_ } is not yet implemented."
147- )
151+ f"Solver { model .solver_name_ } is not yet implemented." )
148152
149153 activations , handles , conv_layer_info = _attach_hooks (
150154 activations = activations ,
@@ -209,14 +213,14 @@ def process_activations(
209213 name = model_name ,
210214 instance = i ,
211215 layer_type = layer_type ,
212- )
213- )
216+ ))
214217
215218 return activations
216219
217220
218221# Function to create a hook that stores the activations in the dictionary
219222def _get_activation (name : str , activations : Dict ):
223+
220224 def hook (model , input , output ):
221225 activations [name ] = output .detach ().squeeze ().numpy ()
222226
@@ -262,8 +266,7 @@ def _attach_hooks(
262266 # attach hook to the layer_type and to the output layer
263267 if isinstance (model .net [i ], layer_type ) or i == len (model .net ) - 1 :
264268 hook , activations = _get_activation (
265- f"{ name } _{ instance } _layer_{ num_layer } " , activations
266- )
269+ f"{ name } _{ instance } _layer_{ num_layer } " , activations )
267270 if isinstance (model .net [i ], layer_type ):
268271 conv_layer_info .append (model .net [i ].kernel_size [0 ])
269272 handle = model .net [i ].register_forward_hook (hook )
@@ -298,8 +301,7 @@ def _attach_hooks(
298301
299302 else :
300303 hook , activations = _get_activation (
301- f"{ name } _{ instance } _layer_{ num_layer } " , activations
302- )
304+ f"{ name } _{ instance } _layer_{ num_layer } " , activations )
303305
304306 handle = model .net [i ].register_forward_hook (hook )
305307 handles .append (handle )
@@ -309,8 +311,7 @@ def _attach_hooks(
309311
310312
311313def aggregate_activations (
312- activations : Dict [str , npt .NDArray ],
313- ) -> Dict [str , npt .NDArray ]:
314+ activations : Dict [str , npt .NDArray ], ) -> Dict [str , npt .NDArray ]:
314315 """
315316 Aggregates activations by model identifier aka. instance.
316317 This function takes a dictionary of activations where the keys are strings containing model identifiers and layer information,
@@ -387,8 +388,7 @@ def get_activations(
387388 activations = activations or {}
388389
389390 aggregated_activations = aggregate_activations (
390- process_activations (models , data , session_id , activations , layer_type )
391- )
391+ process_activations (models , data , session_id , activations , layer_type ))
392392
393393 activations_dict = {}
394394 for key , value in aggregated_activations .items ():
0 commit comments