@@ -35,12 +35,9 @@ def _cut_array(array: npt.NDArray,
3535 end = cut_indices [1 ]
3636 if start == 0 and end == 0 :
3737 # If both start and end are 0, take the whole array
38- sliced_array = array
39- else :
40- # Otherwise, slice the array
41- sliced_array = array [:, start :end if end != 0 else start :]
42- return sliced_array
43-
38+ return array
39+ # Otherwise, slice the array
40+ return array [:, :, start : (end if end != 0 else None )]
4441
4542def get_cut_indices (
4643 model_ : cebra .integrations .sklearn .cebra .CEBRA ,
@@ -156,16 +153,28 @@ def get_activations_model(
156153 # remove all handles to avoid activation's problems
157154 for handle in handles :
158155 handle .remove ()
159-
156+
160157 if hasattr (model , "pad_before_transform" ):
161- pad_before_transform = model .pad_before_transform
162-
158+ pad_before_transform = model .pad_before_transform
159+
163160 if pad_before_transform :
164- # Padding logic: calculate the total reduction which happens based on the
165- # kernel size per layer, divide the reduction per layer into 2 parts
166161 cut_indices = get_cut_indices (model_ , layer_type , conv_layer_info )
167- for i , (key , value ) in enumerate (activations .items ()):
168- activations [key ] = _cut_array (value , cut_indices [i ])
162+ else :
163+ cut_indices = [(0 ,0 )] * len (handles )
164+
165+ for i , (key , batch_list ) in enumerate (list (activations .items ())):
166+ if not isinstance (batch_list , list ):
167+ continue
168+ sliced_chunks = [
169+ _cut_array (chunk , cut_indices [i ])
170+ for chunk in batch_list
171+ ]
172+ # now every chunk.shape == (1, channels, common_time)
173+ activations [key ] = np .concatenate (sliced_chunks , axis = 2 )
174+
175+ for key , arr in list (activations .items ()):
176+ if arr .ndim == 3 and arr .shape [0 ] == 1 :
177+ activations [key ] = arr [0 ]
169178
170179 return activations
171180
@@ -236,10 +245,10 @@ def _get_activation(activations_keys_prefix: str, activations: Dict):
236245 activations : Dict
237246 The dictionary where the activations will be stored. The key is the name of the layer, and the value is the activations.
238247 """
239-
248+ activations . setdefault ( activations_keys_prefix , [])
240249 def hook (model , input , output ):
241- activations [ activations_keys_prefix ] = output .detach ().squeeze ().numpy (
242- )
250+ arr = output .detach ().cpu ().numpy ()
251+ activations [ activations_keys_prefix ]. append ( arr )
243252
244253 return hook , activations
245254
0 commit comments