99import torch
1010import torch .nn as nn
1111
12+ from cebra_lens import utils_wrapper
13+
1214from .utils_plot import plot_activations
1315
1416
@@ -83,7 +85,8 @@ def get_cut_indices(
8385 elif layer_type == None :
8486 raise NotImplementedError (
8587 "Padding handling not implemented to handle activations for all layer types." ,
86- "Set layer_type to nn.Conv1d to use the default padding handling." )
88+ "Set layer_type to nn.Conv1d to use the default padding handling." ,
89+ )
8790 else :
8891 # need to analyze the padding from the last output of Conv1 and apply the same cut
8992 raise NotImplementedError (
@@ -94,8 +97,10 @@ def get_cut_indices(
9497def get_activations_model (
9598 model : cebra .integrations .sklearn .cebra .CEBRA ,
9699 data : torch .Tensor ,
97- session_id : int = - 1 ,
98- name : str = "single" ,
100+ labels : Optional [torch .Tensor ] = None ,
101+ session_id : int = None ,
102+ pad_before_transform : bool = True ,
103+ activations_keys_prefix : str = "model" ,
99104 instance : int = 0 ,
100105 layer_type : Type [nn .Module ] = nn .Conv1d ,
101106) -> Dict [str , npt .NDArray ]:
@@ -129,42 +134,41 @@ def get_activations_model(
129134
130135 activations = {}
131136 transform_kwargs = {}
132- if model .solver_name_ in [
133- "multi-session" ,
134- "multi-session-aux" ,
135- "multiobjective-solver" ,
136- ]:
137-
138- model_ = model .model_ [session_id ]
139- transform_kwargs .update ({"session_id" : session_id })
140-
141- elif model .solver_name_ in [
142- "single-session" ,
143- "single-session-aux" ,
144- "single-session-hybrid" ,
145- "single-session-full" ,
146- ]:
147- model_ = model .model_
148137
138+ if isinstance (model , cebra .integrations .sklearn .cebra .CEBRA ):
139+ model_ = model .solver_ ._get_model (session_id = session_id )
140+ elif isinstance (model , cebra .solver .Solver ):
141+ model_ = model ._get_model (session_id = session_id )
149142 else :
150- raise NotImplementedError (
151- f"Solver { model .solver_name_ } is not yet implemented." )
143+ raise ValueError (
144+ "Model must be an instance of cebra.integrations.sklearn.cebra.CEBRA "
145+ f"or cebra.solver.Solver, got { type (model )} instead." , )
146+
147+ transform_kwargs .update ({"session_id" : session_id })
152148
153149 activations , handles , conv_layer_info = _attach_hooks (
154150 activations = activations ,
155151 model = model_ ,
156- name = name ,
152+ activations_keys_prefix = activations_keys_prefix ,
157153 instance = instance ,
158154 layer_type = layer_type ,
159155 )
160- _ = model .transform (data , ** transform_kwargs )
156+
157+ _ = utils_wrapper .transform (model = model ,
158+ data = data ,
159+ label = labels ,
160+ ** transform_kwargs )
161161
162162 # remove all handles to avoid activation's problems
163163 for handle in handles :
164164 handle .remove ()
165165
166- if model .pad_before_transform :
167- # Padding logic: calculate the total reduction which happens based on the kernel size per layer, divide the reduction per layer into 2 parts
166+ if hasattr (model , "pad_before_transform" ):
167+ pad_before_transform = model .pad_before_transform
168+
169+ if pad_before_transform :
170+ # Padding logic: calculate the total reduction which happens based on the
171+ # kernel size per layer, divide the reduction per layer into 2 parts
168172 cut_indices = get_cut_indices (model_ , layer_type , conv_layer_info )
169173 for i , (key , value ) in enumerate (activations .items ()):
170174 activations [key ] = _cut_array (value , cut_indices [i ])
@@ -175,7 +179,9 @@ def get_activations_model(
175179def process_activations (
176180 models : Dict [str , List [cebra .integrations .sklearn .cebra .CEBRA ]],
177181 data : torch .Tensor ,
178- session_id : int ,
182+ labels : Optional [torch .Tensor ] = None ,
183+ session_id : int = None ,
184+ pad_before_transform : bool = True ,
179185 activations : Dict [str , npt .NDArray ] = {},
180186 layer_type : Type [nn .Module ] = None ,
181187) -> Dict [str , npt .NDArray ]:
@@ -209,8 +215,10 @@ def process_activations(
209215 get_activations_model (
210216 model = model ,
211217 data = data ,
218+ labels = labels ,
212219 session_id = session_id ,
213- name = model_name ,
220+ pad_before_transform = pad_before_transform ,
221+ activations_keys_prefix = model_name ,
214222 instance = i ,
215223 layer_type = layer_type ,
216224 ))
@@ -219,18 +227,37 @@ def process_activations(
219227
220228
221229# Function to create a hook that stores the activations in the dictionary
222- def _get_activation (name : str , activations : Dict ):
230+ def _get_activation (activations_keys_prefix : str , activations : Dict ):
231+ """Creates a forward hook to capture activations from a model layer.
232+
233+ This function returns a hook that captures the output of a model layer during the forward pass and stores it in a dictionary.
234+
235+ Parameters:
236+ -----------
237+ activations_keys_prefix : str
238+ The prefix to use for the activation key, corresopnding to the type of model (eg. "single", "multi").
239+ activations : Dict
240+ A dictionary to store the activations. The key will be the name of the layer, and the value will be the activations.
241+
242+ Returns:
243+ --------
244+ hook : function
245+ A forward hook function that captures the activations.
246+ activations : Dict
247+ The dictionary where the activations will be stored. The key is the name of the layer, and the value is the activations.
248+ """
223249
224250 def hook (model , input , output ):
225- activations [name ] = output .detach ().squeeze ().numpy ()
251+ activations [activations_keys_prefix ] = output .detach ().squeeze ().numpy (
252+ )
226253
227254 return hook , activations
228255
229256
230257def _attach_hooks (
231258 activations : Dict [str , npt .NDArray ],
232259 model : cebra .integrations .sklearn .cebra .CEBRA ,
233- name : str ,
260+ activations_keys_prefix : str ,
234261 instance : int ,
235262 layer_type : Type [nn .Module ] = None ,
236263) -> Dict [str , npt .NDArray ]: # only attaches hooks on convolutional layers
@@ -244,8 +271,10 @@ def _attach_hooks(
244271 A dictionary to store the activations. Please refer to ``activations`` returned by ``get_activations_model``.
245272 model : cebra.integrations.sklearn.cebra.CEBRA
246273 The model to which hooks will be attached.
247- name : str
248- A base name for the activation keys (e.g., "single", "multi").
274+ activations_keys_prefix : str
275+ A base name for the activation keys (e.g., "single", "multi") so that the keys are
276+ unique for each model instance. The keys will be in the format
277+ '{activations_keys_prefix}_{instance}_layer_{num_layer}'.
249278 instance : int
250279 The instance number for the model, used to differentiate between models from the same model category.
251280 layer_type : Type[nn.Module]
@@ -266,7 +295,8 @@ def _attach_hooks(
266295 # attach hook to the layer_type and to the output layer
267296 if isinstance (model .net [i ], layer_type ) or i == len (model .net ) - 1 :
268297 hook , activations = _get_activation (
269- f"{ name } _{ instance } _layer_{ num_layer } " , activations )
298+ f"{ activations_keys_prefix } _{ instance } _layer_{ num_layer } " ,
299+ activations )
270300 if isinstance (model .net [i ], layer_type ):
271301 conv_layer_info .append (model .net [i ].kernel_size [0 ])
272302 handle = model .net [i ].register_forward_hook (hook )
@@ -277,7 +307,7 @@ def _attach_hooks(
277307 for submodule in model .net [i ].modules ():
278308 if isinstance (submodule , layer_type ):
279309 hook , activations = _get_activation (
280- f"{ name } _{ instance } _layer_{ num_layer } " ,
310+ f"{ activations_keys_prefix } _{ instance } _layer_{ num_layer } " ,
281311 activations ,
282312 )
283313 conv_layer_info .append (submodule .kernel_size [0 ])
@@ -292,7 +322,7 @@ def _attach_hooks(
292322 if bool (model .net [i ]._modules ):
293323 for submodule in model .net [i ].modules ():
294324 hook , activations = _get_activation (
295- f"{ name } _{ instance } _layer_{ num_layer } " ,
325+ f"{ activations_keys_prefix } _{ instance } _layer_{ num_layer } " ,
296326 activations ,
297327 )
298328 handle = submodule .register_forward_hook (hook )
@@ -301,7 +331,8 @@ def _attach_hooks(
301331
302332 else :
303333 hook , activations = _get_activation (
304- f"{ name } _{ instance } _layer_{ num_layer } " , activations )
334+ f"{ activations_keys_prefix } _{ instance } _layer_{ num_layer } " ,
335+ activations )
305336
306337 handle = model .net [i ].register_forward_hook (hook )
307338 handles .append (handle )
@@ -356,9 +387,11 @@ def aggregate_activations(
356387def get_activations (
357388 models : Dict [str , List [cebra .integrations .sklearn .cebra .CEBRA ]],
358389 data : torch .Tensor ,
359- session_id : int ,
390+ labels : Optional [torch .Tensor ] = None ,
391+ session_id : int = None ,
392+ pad_before_transform : bool = True ,
360393 activations : Optional [Dict [str , npt .NDArray ]] = None ,
361- layer_type : Optional [Type [nn .Module ]] = None ,
394+ layer_type : Optional [Type [nn .Module ]] = nn . Conv1d ,
362395) -> Dict [str , npt .NDArray ]:
363396 """
364397 Extracts and organizes activations from models.
@@ -388,7 +421,15 @@ def get_activations(
388421 activations = activations or {}
389422
390423 aggregated_activations = aggregate_activations (
391- process_activations (models , data , session_id , activations , layer_type ))
424+ process_activations (
425+ models = models ,
426+ data = data ,
427+ labels = labels ,
428+ session_id = session_id ,
429+ pad_before_transform = pad_before_transform ,
430+ activations = activations ,
431+ layer_type = layer_type ,
432+ ))
392433
393434 activations_dict = {}
394435 for key , value in aggregated_activations .items ():
0 commit comments