Skip to content

Commit 89b9ac7

Browse files
committed
Fix tests and add more
1 parent c3d11e2 commit 89b9ac7

12 files changed

Lines changed: 194 additions & 50 deletions

File tree

cebra_lens/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
# example of structure so that you can directly use the functions get_layer_activations instead of having to do CEBRA_Lens.activations.get_layer_activations
22
from .activations import *
3-
from .quantification import *
4-
from .quantification.decoding import *
3+
from .quantification.decoder import *
54
from .quantification.distance import *
65
from .quantification.cka_metric import *
76
from .quantification.rdm_metric import *

cebra_lens/activations.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,8 @@ def get_cut_indices(
8080
cut_indices.append((0, 0))
8181
elif layer_type == None:
8282
raise NotImplementedError(
83-
"Padding handling not implemented for 'all'.")
83+
"Padding handling not implemented to handle activations for all layer types.",
84+
"Set layer_type to nn.Conv1d to use the default padding handling.")
8485
else:
8586
# need to analyze the padding from the last output of Conv1 and apply the same cut
8687
raise NotImplementedError(
@@ -94,7 +95,7 @@ def get_activations_model(
9495
session_id: int = -1,
9596
name: str = "single",
9697
instance: int = 0,
97-
layer_type: Type[nn.Module] = None,
98+
layer_type: Type[nn.Module] = nn.Conv1d,
9899
) -> Dict[str, npt.NDArray]:
99100
"""
100101
Extracts activations from a single model layer.
@@ -112,7 +113,8 @@ def get_activations_model(
112113
instance : int
113114
The instance number for the model, used to differentiate between models from the same model category.
114115
layer_type : Type[nn.Module]
115-
The type of layer to extract activations from. Defaults to None, meaning extracts activations from all layers.
116+
The type of layer to extract activations from. None means it extracts activations from all layers.
117+
Default is nn.Conv1d, which is the most common layer type used in CEBRA models.
116118
117119
Returns:
118120
--------

cebra_lens/quantification/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,6 @@
22
from .rdm_metric import *
33
from .misc import *
44
from .distance import *
5-
from .decoding import *
5+
from .decoder import *
66
from .base import *
77
from .tsne import *

cebra_lens/quantification/cka_metric.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,10 @@
88
from tqdm import tqdm
99
import numpy as np
1010
from .base import _BaseMetric
11-
from ..matplotlib import *
11+
import cebra_lens.matplotlib as cebra_lens_matplotlib
1212
from typing import Optional, List, Dict, Tuple
1313
import numpy.typing as npt
14+
import matplotlib
1415

1516

1617
class CKA(_BaseMetric):
@@ -188,7 +189,8 @@ def _compute_per_layer(
188189
cka_matrix = np.zeros((len(embeddings_1), len(embeddings_1[0])))
189190
for j in tqdm(range(len(embeddings_1))):
190191
if flag:
191-
# the situation when there multiple models inside model labels and the same number of models inside each label
192+
# the situation when there multiple models inside model labels and the same number of
193+
# models inside each label
192194
cka_matrix[j, :] = self._compute_cka(embeddings_1[j],
193195
embeddings_2[j])
194196
else:
@@ -207,7 +209,8 @@ def compute(self, activations: Dict[str, npt.NDArray],
207209
Parameters:
208210
-----------
209211
activations : Dict[str, npt.NDArray]
210-
A dictionary where keys are strings which represent the model label and values are 2d lists with the corresponding activations per layer.
212+
A dictionary where keys are strings which represent the model label and values are 2d lists
213+
with the corresponding activations per layer.
211214
212215
comparison : Tuple[str, str]
213216
A tuple containing the model labels to compare.
@@ -227,7 +230,8 @@ def compute(self, activations: Dict[str, npt.NDArray],
227230

228231
if len(activations_1) != len(activations_2):
229232
# if the number of models in a label is different from the other model label
230-
# choose embeddings_1 for the one with more models, and then embeddings_2 just compare with the first model
233+
# choose embeddings_1 for the one with more models, and then embeddings_2 just compare with
234+
# the first model
231235
if len(activations_1) > len(activations_2):
232236
embeddings_1 = activations_1
233237
embeddings_2 = activations_2[0]
@@ -293,7 +297,7 @@ def plot(
293297
matplotlib.axes.Axes
294298
The axes on which the heatmap is plotted.
295299
"""
296-
return plot_cka_heatmaps(
300+
return cebra_lens_matplotlib.plot_cka_heatmaps(
297301
cka_matrices,
298302
annot,
299303
show_cbar,
Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,13 @@
55
from ..utils_hpc import decoding_pos_dir
66
from ..activations import get_activations_model
77
from .base import _BaseMetric
8-
from ..matplotlib import *
8+
import cebra_lens.matplotlib as cebra_lens_matplotlib
99
import numpy.typing as npt
10-
from typing import Dict, Type, Tuple
10+
from typing import Dict, Type, Tuple, Optional
1111
import torch.nn as nn
1212
import sklearn.metrics
1313
import torch as pt
14+
import matplotlib
1415

1516

1617
def decoding(
@@ -118,7 +119,7 @@ def __init__(
118119
test_label: npt.NDArray,
119120
session_id: int = 0,
120121
dataset_label: str = None,
121-
layer_type: Optional[Type[nn.Module]] = None,
122+
layer_type: Optional[Type[nn.Module]] = nn.Conv1d,
122123
output_only: bool = True,
123124
):
124125

@@ -315,7 +316,7 @@ def compute(
315316
def __name__(self):
316317
return "decode_by_layer"
317318

318-
def set_output_only(self, output_only):
319+
def set_output_only(self, output_only: bool) -> None:
319320
"""
320321
Set the output_only parameter to True or False. If True, it will compute the decoding scores for the output embeddings of the model, otherwise it will compute the decoding scores for the activations of the model.
321322
@@ -369,8 +370,8 @@ def plot(
369370
)
370371

371372
if self.output_only:
372-
return plot_decoding(results_dict, palette, self.dataset_label,
373+
return cebra_lens_matplotlib.plot_decoding(results_dict, palette, self.dataset_label,
373374
label, plot_error, ax)
374375
else:
375-
return plot_layer_decoding(results_dict, title, self.dataset_label,
376+
return cebra_lens_matplotlib.plot_layer_decoding(results_dict, title, self.dataset_label,
376377
label, plot_error, figsize)

cebra_lens/quantification/rdm_metric.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
"""All the functions relative to the Representation Dissimilarity Matrix (RDM) calculation"""
22

3+
from typing import Dict, List, Optional
34
import numpy as np
45
from scipy.linalg import block_diag
56
from typing import List, Optional, Tuple, Union
67
from scipy.spatial.distance import correlation, pdist, squareform
78
from .misc import discrete_binning, continuous_binning
89
import torch
10+
import matplotlib
911
from .base import _BaseMetric
10-
from ..matplotlib import *
12+
import cebra_lens.matplotlib as cebra_lens_matplotlib
1113
import numpy.typing as npt
1214

1315

@@ -21,8 +23,8 @@ class RDM(_BaseMetric):
2123
The data array of shape (num_samples, num_features).
2224
label : torch.Tensor
2325
The array of labels corresponding to the data.
24-
discrete : bool, optional
25-
Whether the labels are discrete or continuous. If None, it will be determined based on the dataset_label.
26+
is_discrete_labels : bool, optional
27+
Whether the labels are discrete or continuous. By default, it is False, meaning the labels are continuous.
2628
dataset_label : str, optional
2729
The dataset type, either 'visual' or 'HPC'. Default is 'visual'.
2830
metric : str, optional
@@ -37,7 +39,7 @@ def __init__(
3739
self,
3840
data: torch.Tensor,
3941
label: torch.Tensor,
40-
is_discrete_labels: bool = None,
42+
is_discrete_labels: bool = False,
4143
dataset_label: str = None,
4244
metric: str = "correlation",
4345
bool_oracle: bool = True,
@@ -254,9 +256,9 @@ def plot(
254256
The figure containing the plotted RDMs.
255257
"""
256258
if self.bool_oracle:
257-
return plot_rdm_correlation(rdms)
259+
return cebra_lens_matplotlib.plot_rdm_correlation(rdms)
258260
else:
259-
return plot_rdm_all(
261+
return cebra_lens_matplotlib.plot_rdm_all(
260262
rdms=rdms,
261263
labels=self.label,
262264
num_bins=self.num_bins,

cebra_lens/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import numpy.typing as npt
77
from tqdm import tqdm
88
from torch import nn
9-
from .quantification.decoding import Decoding
9+
from .quantification.decoder import Decoding
1010
from .quantification.rdm_metric import RDM
1111
from .quantification.cka_metric import CKA
1212
from .quantification.tsne import Tsne

tests/test_activations.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
import pytest
21
import torch
2+
import pytest
33
import numpy as np
44
from collections import namedtuple
55
from unittest.mock import MagicMock
@@ -22,7 +22,7 @@ def test_cut_array_with_cut():
2222
np.testing.assert_array_equal(result, np.array([[2, 3, 4]]))
2323

2424

25-
def test_get_cut_indices_conv1d():
25+
def test_get_cut_indices():
2626
Offset = namedtuple("Offset", ["left", "right"])
2727

2828
# Mock the model's get_offset behavior
@@ -32,6 +32,9 @@ def test_get_cut_indices_conv1d():
3232
result = get_cut_indices(model_mock, torch.nn.Conv1d, [3, 3])
3333
assert isinstance(result, list)
3434
assert all(isinstance(x, tuple) and len(x) == 2 for x in result)
35+
36+
with pytest.raises(NotImplementedError, match="Padding handling not implemented*"):
37+
get_cut_indices(model_mock, None, [3, 3])
3538

3639

3740
def make_mock_cebra_model():

tests/test_cka.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
import pytest
2+
import numpy as np
3+
import torch
4+
from unittest.mock import patch, MagicMock
5+
import cebra_lens
6+
7+
@pytest.fixture
8+
def dummy_comparisons():
9+
return [("A", "B")]
10+
11+
@pytest.fixture
12+
def dummy_cka(dummy_comparisons):
13+
return cebra_lens.quantification.cka_metric.CKA(comparisons=dummy_comparisons)
14+
15+
@pytest.fixture
16+
def dummy_activations():
17+
# Simulate Conv1D and Linear layer activations as 2D arrays (samples, features)
18+
batch_size = 10
19+
conv_channels = 4
20+
conv_length = 8
21+
linear_features = 5
22+
23+
# Conv1D output: (batch_size, conv_channels, conv_length) -> flatten to (batch_size, conv_channels * conv_length)
24+
conv1d_activations_A = np.random.rand(batch_size, conv_channels, conv_length).reshape(batch_size, -1)
25+
linear_activations_A = np.random.rand(batch_size, linear_features)
26+
conv1d_activations_B = np.random.rand(batch_size, conv_channels, conv_length).reshape(batch_size, -1)
27+
linear_activations_B = np.random.rand(batch_size, linear_features)
28+
29+
# Each group has a list of 2D arrays (one per layer)
30+
return {
31+
"A": np.array([[np.random.rand(5, 10), np.random.rand(5, 10)],
32+
[np.random.rand(5, 10), np.random.rand(5, 10)]]),
33+
"B": np.array([[np.random.rand(5, 10), np.random.rand(5, 10)],
34+
[np.random.rand(5, 10), np.random.rand(5, 10)]]),
35+
}
36+
37+
def test_center_gram_symmetry(dummy_cka):
38+
mat = np.eye(5)
39+
centered = dummy_cka.center_gram(mat)
40+
assert np.allclose(centered, centered.T)
41+
42+
def test_center_gram_unbiased(dummy_cka):
43+
mat = np.eye(5)
44+
centered = dummy_cka.center_gram(mat, unbiased=True)
45+
assert np.allclose(centered, centered.T)
46+
47+
def test_gram_linear(dummy_cka):
48+
x = np.random.rand(10, 5)
49+
gram = dummy_cka.gram_linear(x)
50+
assert gram.shape == (10, 10)
51+
assert np.allclose(gram, gram.T)
52+
53+
def test_cka_value(dummy_cka):
54+
x = np.random.rand(10, 5)
55+
y = np.random.rand(10, 5)
56+
gram_x = dummy_cka.gram_linear(x)
57+
gram_y = dummy_cka.gram_linear(y)
58+
val = dummy_cka.cka(gram_x, gram_y)
59+
assert isinstance(val, float) or isinstance(val, np.floating)
60+
61+
def test_compute_cka_shape(dummy_cka):
62+
emb1 = [np.random.rand(5, 10), np.random.rand(5, 10)]
63+
emb2 = [np.random.rand(5, 10), np.random.rand(5, 10)]
64+
result = dummy_cka._compute_cka(emb1, emb2)
65+
assert result.shape == (1, 2)
66+
67+
def test_compute_per_layer_shape(dummy_cka):
68+
emb1 = [ [np.random.rand(5, 10), np.random.rand(5, 10)] for _ in range(3) ]
69+
emb2 = [ [np.random.rand(5, 10), np.random.rand(5, 10)] for _ in range(3) ]
70+
result = dummy_cka._compute_per_layer(emb1, emb2, flag=True)
71+
assert result.shape == (3, 2)
72+
73+
def test_compute(dummy_cka, dummy_activations):
74+
result = dummy_cka.compute(dummy_activations, ("A", "B"))
75+
assert isinstance(result, np.ndarray)
76+
77+
78+
def test_compute_intra_label(dummy_cka, dummy_activations):
79+
result = dummy_cka.compute(dummy_activations, ("A", "A"))
80+
assert isinstance(result, np.ndarray)
81+
82+
@patch("cebra_lens.matplotlib.plot_cka_heatmaps")
83+
def test_plot_calls_heatmap(mock_plot, dummy_cka):
84+
cka_matrices = {"A": np.random.rand(2, 2)}
85+
dummy_cka.plot(cka_matrices, annot=True)
86+
assert mock_plot.called

0 commit comments

Comments
 (0)