Skip to content

Commit af70639

Browse files
author
anandawolz
committed
fix(batch): prevent overwriting batch outputs
1 parent 2aa0e19 commit af70639

2 files changed

Lines changed: 25 additions & 17 deletions

File tree

cebra_lens/activations.py

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -35,12 +35,9 @@ def _cut_array(array: npt.NDArray,
3535
end = cut_indices[1]
3636
if start == 0 and end == 0:
3737
# If both start and end are 0, take the whole array
38-
sliced_array = array
39-
else:
40-
# Otherwise, slice the array
41-
sliced_array = array[:, start:end if end != 0 else start:]
42-
return sliced_array
43-
38+
return array
39+
# Otherwise, slice the array
40+
return array[:, :, start : (end if end != 0 else None)]
4441

4542
def get_cut_indices(
4643
model_: cebra.integrations.sklearn.cebra.CEBRA,
@@ -156,16 +153,28 @@ def get_activations_model(
156153
# remove all handles to avoid activation's problems
157154
for handle in handles:
158155
handle.remove()
159-
156+
160157
if hasattr(model, "pad_before_transform"):
161-
pad_before_transform = model.pad_before_transform
162-
158+
pad_before_transform = model.pad_before_transform
159+
163160
if pad_before_transform:
164-
# Padding logic: calculate the total reduction which happens based on the
165-
# kernel size per layer, divide the reduction per layer into 2 parts
166161
cut_indices = get_cut_indices(model_, layer_type, conv_layer_info)
167-
for i, (key, value) in enumerate(activations.items()):
168-
activations[key] = _cut_array(value, cut_indices[i])
162+
else:
163+
cut_indices = [(0,0)] * len(handles)
164+
165+
for i, (key, batch_list) in enumerate(list(activations.items())):
166+
if not isinstance(batch_list, list):
167+
continue
168+
sliced_chunks = [
169+
_cut_array(chunk, cut_indices[i])
170+
for chunk in batch_list
171+
]
172+
# now every chunk.shape == (1, channels, common_time)
173+
activations[key] = np.concatenate(sliced_chunks, axis=2)
174+
175+
for key, arr in list(activations.items()):
176+
if arr.ndim == 3 and arr.shape[0] == 1:
177+
activations[key] = arr[0]
169178

170179
return activations
171180

@@ -236,10 +245,10 @@ def _get_activation(activations_keys_prefix: str, activations: Dict):
236245
activations : Dict
237246
The dictionary where the activations will be stored. The key is the name of the layer, and the value is the activations.
238247
"""
239-
248+
activations.setdefault(activations_keys_prefix, [])
240249
def hook(model, input, output):
241-
activations[activations_keys_prefix] = output.detach().squeeze().numpy(
242-
)
250+
arr = output.detach().cpu().numpy()
251+
activations[activations_keys_prefix].append(arr)
243252

244253
return hook, activations
245254

cebra_lens/utils_wrapper.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ def transform(model, data, label, **transform_kwargs):
1313
"Model must be an instance of cebra.solver.UnifiedSolver",
1414
f"or cebra.integrations.sklearn.cebra.CEBRA, got {type(model)} instead.",
1515
)
16-
1716
return embedding
1817

1918

0 commit comments

Comments
 (0)