Skip to content

Commit 382ac10

Browse files
committed
Adapt so that runs for unified CEBRA and simplification
1 parent d6a02e2 commit 382ac10

19 files changed

Lines changed: 1277 additions & 566 deletions

cebra_lens/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from .utils_allen import *
1010
from .utils_hpc import *
1111
from .utils_plot import *
12+
from .utils_wrapper import *
1213

1314
# selects what files can be imported when doing from CEBRA_Lens import * --> keep env clean
1415
# __all__ = ['get_layer_activations']

cebra_lens/activations.py

Lines changed: 80 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
import torch
1010
import torch.nn as nn
1111

12+
from cebra_lens import utils_wrapper
13+
1214
from .utils_plot import plot_activations
1315

1416

@@ -83,7 +85,8 @@ def get_cut_indices(
8385
elif layer_type == None:
8486
raise NotImplementedError(
8587
"Padding handling not implemented to handle activations for all layer types.",
86-
"Set layer_type to nn.Conv1d to use the default padding handling.")
88+
"Set layer_type to nn.Conv1d to use the default padding handling.",
89+
)
8790
else:
8891
# need to analyze the padding from the last output of Conv1 and apply the same cut
8992
raise NotImplementedError(
@@ -94,8 +97,10 @@ def get_cut_indices(
9497
def get_activations_model(
9598
model: cebra.integrations.sklearn.cebra.CEBRA,
9699
data: torch.Tensor,
97-
session_id: int = -1,
98-
name: str = "single",
100+
labels: Optional[torch.Tensor] = None,
101+
session_id: int = None,
102+
pad_before_transform: bool = True,
103+
activations_keys_prefix: str = "model",
99104
instance: int = 0,
100105
layer_type: Type[nn.Module] = nn.Conv1d,
101106
) -> Dict[str, npt.NDArray]:
@@ -129,42 +134,41 @@ def get_activations_model(
129134

130135
activations = {}
131136
transform_kwargs = {}
132-
if model.solver_name_ in [
133-
"multi-session",
134-
"multi-session-aux",
135-
"multiobjective-solver",
136-
]:
137-
138-
model_ = model.model_[session_id]
139-
transform_kwargs.update({"session_id": session_id})
140-
141-
elif model.solver_name_ in [
142-
"single-session",
143-
"single-session-aux",
144-
"single-session-hybrid",
145-
"single-session-full",
146-
]:
147-
model_ = model.model_
148137

138+
if isinstance(model, cebra.integrations.sklearn.cebra.CEBRA):
139+
model_ = model.solver_._get_model(session_id=session_id)
140+
elif isinstance(model, cebra.solver.Solver):
141+
model_ = model._get_model(session_id=session_id)
149142
else:
150-
raise NotImplementedError(
151-
f"Solver {model.solver_name_} is not yet implemented.")
143+
raise ValueError(
144+
"Model must be an instance of cebra.integrations.sklearn.cebra.CEBRA "
145+
f"or cebra.solver.Solver, got {type(model)} instead.", )
146+
147+
transform_kwargs.update({"session_id": session_id})
152148

153149
activations, handles, conv_layer_info = _attach_hooks(
154150
activations=activations,
155151
model=model_,
156-
name=name,
152+
activations_keys_prefix=activations_keys_prefix,
157153
instance=instance,
158154
layer_type=layer_type,
159155
)
160-
_ = model.transform(data, **transform_kwargs)
156+
157+
_ = utils_wrapper.transform(model=model,
158+
data=data,
159+
label=labels,
160+
**transform_kwargs)
161161

162162
# remove all handles to avoid activation's problems
163163
for handle in handles:
164164
handle.remove()
165165

166-
if model.pad_before_transform:
167-
# Padding logic: calculate the total reduction which happens based on the kernel size per layer, divide the reduction per layer into 2 parts
166+
if hasattr(model, "pad_before_transform"):
167+
pad_before_transform = model.pad_before_transform
168+
169+
if pad_before_transform:
170+
# Padding logic: calculate the total reduction which happens based on the
171+
# kernel size per layer, divide the reduction per layer into 2 parts
168172
cut_indices = get_cut_indices(model_, layer_type, conv_layer_info)
169173
for i, (key, value) in enumerate(activations.items()):
170174
activations[key] = _cut_array(value, cut_indices[i])
@@ -175,7 +179,9 @@ def get_activations_model(
175179
def process_activations(
176180
models: Dict[str, List[cebra.integrations.sklearn.cebra.CEBRA]],
177181
data: torch.Tensor,
178-
session_id: int,
182+
labels: Optional[torch.Tensor] = None,
183+
session_id: int = None,
184+
pad_before_transform: bool = True,
179185
activations: Dict[str, npt.NDArray] = {},
180186
layer_type: Type[nn.Module] = None,
181187
) -> Dict[str, npt.NDArray]:
@@ -209,8 +215,10 @@ def process_activations(
209215
get_activations_model(
210216
model=model,
211217
data=data,
218+
labels=labels,
212219
session_id=session_id,
213-
name=model_name,
220+
pad_before_transform=pad_before_transform,
221+
activations_keys_prefix=model_name,
214222
instance=i,
215223
layer_type=layer_type,
216224
))
@@ -219,18 +227,37 @@ def process_activations(
219227

220228

221229
# Function to create a hook that stores the activations in the dictionary
222-
def _get_activation(name: str, activations: Dict):
230+
def _get_activation(activations_keys_prefix: str, activations: Dict):
231+
"""Creates a forward hook to capture activations from a model layer.
232+
233+
This function returns a hook that captures the output of a model layer during the forward pass and stores it in a dictionary.
234+
235+
Parameters:
236+
-----------
237+
activations_keys_prefix : str
238+
The prefix to use for the activation key, corresopnding to the type of model (eg. "single", "multi").
239+
activations : Dict
240+
A dictionary to store the activations. The key will be the name of the layer, and the value will be the activations.
241+
242+
Returns:
243+
--------
244+
hook : function
245+
A forward hook function that captures the activations.
246+
activations : Dict
247+
The dictionary where the activations will be stored. The key is the name of the layer, and the value is the activations.
248+
"""
223249

224250
def hook(model, input, output):
225-
activations[name] = output.detach().squeeze().numpy()
251+
activations[activations_keys_prefix] = output.detach().squeeze().numpy(
252+
)
226253

227254
return hook, activations
228255

229256

230257
def _attach_hooks(
231258
activations: Dict[str, npt.NDArray],
232259
model: cebra.integrations.sklearn.cebra.CEBRA,
233-
name: str,
260+
activations_keys_prefix: str,
234261
instance: int,
235262
layer_type: Type[nn.Module] = None,
236263
) -> Dict[str, npt.NDArray]: # only attaches hooks on convolutional layers
@@ -244,8 +271,10 @@ def _attach_hooks(
244271
A dictionary to store the activations. Please refer to ``activations`` returned by ``get_activations_model``.
245272
model : cebra.integrations.sklearn.cebra.CEBRA
246273
The model to which hooks will be attached.
247-
name : str
248-
A base name for the activation keys (e.g., "single", "multi").
274+
activations_keys_prefix : str
275+
A base name for the activation keys (e.g., "single", "multi") so that the keys are
276+
unique for each model instance. The keys will be in the format
277+
'{activations_keys_prefix}_{instance}_layer_{num_layer}'.
249278
instance : int
250279
The instance number for the model, used to differentiate between models from the same model category.
251280
layer_type : Type[nn.Module]
@@ -266,7 +295,8 @@ def _attach_hooks(
266295
# attach hook to the layer_type and to the output layer
267296
if isinstance(model.net[i], layer_type) or i == len(model.net) - 1:
268297
hook, activations = _get_activation(
269-
f"{name}_{instance}_layer_{num_layer}", activations)
298+
f"{activations_keys_prefix}_{instance}_layer_{num_layer}",
299+
activations)
270300
if isinstance(model.net[i], layer_type):
271301
conv_layer_info.append(model.net[i].kernel_size[0])
272302
handle = model.net[i].register_forward_hook(hook)
@@ -277,7 +307,7 @@ def _attach_hooks(
277307
for submodule in model.net[i].modules():
278308
if isinstance(submodule, layer_type):
279309
hook, activations = _get_activation(
280-
f"{name}_{instance}_layer_{num_layer}",
310+
f"{activations_keys_prefix}_{instance}_layer_{num_layer}",
281311
activations,
282312
)
283313
conv_layer_info.append(submodule.kernel_size[0])
@@ -292,7 +322,7 @@ def _attach_hooks(
292322
if bool(model.net[i]._modules):
293323
for submodule in model.net[i].modules():
294324
hook, activations = _get_activation(
295-
f"{name}_{instance}_layer_{num_layer}",
325+
f"{activations_keys_prefix}_{instance}_layer_{num_layer}",
296326
activations,
297327
)
298328
handle = submodule.register_forward_hook(hook)
@@ -301,7 +331,8 @@ def _attach_hooks(
301331

302332
else:
303333
hook, activations = _get_activation(
304-
f"{name}_{instance}_layer_{num_layer}", activations)
334+
f"{activations_keys_prefix}_{instance}_layer_{num_layer}",
335+
activations)
305336

306337
handle = model.net[i].register_forward_hook(hook)
307338
handles.append(handle)
@@ -356,9 +387,11 @@ def aggregate_activations(
356387
def get_activations(
357388
models: Dict[str, List[cebra.integrations.sklearn.cebra.CEBRA]],
358389
data: torch.Tensor,
359-
session_id: int,
390+
labels: Optional[torch.Tensor] = None,
391+
session_id: int = None,
392+
pad_before_transform: bool = True,
360393
activations: Optional[Dict[str, npt.NDArray]] = None,
361-
layer_type: Optional[Type[nn.Module]] = None,
394+
layer_type: Optional[Type[nn.Module]] = nn.Conv1d,
362395
) -> Dict[str, npt.NDArray]:
363396
"""
364397
Extracts and organizes activations from models.
@@ -388,7 +421,15 @@ def get_activations(
388421
activations = activations or {}
389422

390423
aggregated_activations = aggregate_activations(
391-
process_activations(models, data, session_id, activations, layer_type))
424+
process_activations(
425+
models=models,
426+
data=data,
427+
labels=labels,
428+
session_id=session_id,
429+
pad_before_transform=pad_before_transform,
430+
activations=activations,
431+
layer_type=layer_type,
432+
))
392433

393434
activations_dict = {}
394435
for key, value in aggregated_activations.items():

cebra_lens/quantification/cka_metric.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,7 @@ def compute(self, activations: Dict[str, npt.NDArray],
212212
Parameters:
213213
-----------
214214
activations : Dict[str, npt.NDArray]
215-
A dictionary where keys are strings which represent the model label and values are 2d lists
215+
A dictionary where keys are strings which represent the model label and values are 2d lists
216216
with the corresponding activations per layer.
217217
218218
comparison : Tuple[str, str]

0 commit comments

Comments
 (0)