Skip to content

Commit d6a02e2

Browse files
committed
Fix imports with matplotlib
1 parent 8abc266 commit d6a02e2

12 files changed

Lines changed: 26 additions & 25 deletions

File tree

cebra_lens/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
# example of structure so that you can directly use the functions get_layer_activations instead of having to do CEBRA_Lens.activations.get_layer_activations
22
from .activations import *
3-
from .matplotlib import *
43
from .quantification.cka_metric import *
54
from .quantification.decoder import *
65
from .quantification.distance import *
@@ -9,6 +8,7 @@
98
from .utils import *
109
from .utils_allen import *
1110
from .utils_hpc import *
11+
from .utils_plot import *
1212

1313
# selects what files can be imported when doing from CEBRA_Lens import * --> keep env clean
1414
# __all__ = ['get_layer_activations']

cebra_lens/activations.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import torch
1010
import torch.nn as nn
1111

12-
from .matplotlib import plot_activations
12+
from .utils_plot import plot_activations
1313

1414

1515
def _cut_array(array: npt.NDArray,

cebra_lens/quantification/cka_metric.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import numpy.typing as npt
1313
from tqdm import tqdm
1414

15-
import cebra_lens.matplotlib as cebra_lens_matplotlib
15+
from cebra_lens import utils_plot
1616

1717
from .base import _BaseMetric
1818

@@ -300,7 +300,7 @@ def plot(
300300
matplotlib.axes.Axes
301301
The axes on which the heatmap is plotted.
302302
"""
303-
return cebra_lens_matplotlib.plot_cka_heatmaps(
303+
return utils_plot.plot_cka_heatmaps(
304304
cka_matrices,
305305
annot,
306306
show_cbar,

cebra_lens/quantification/decoder.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import torch as pt
1010
import torch.nn as nn
1111

12-
import cebra_lens.matplotlib as cebra_lens_matplotlib
12+
from cebra_lens import utils_plot
1313

1414
from ..activations import get_activations_model
1515
from ..utils_allen import decoding_frames
@@ -373,10 +373,10 @@ def plot(
373373
)
374374

375375
if self.output_only:
376-
return cebra_lens_matplotlib.plot_decoding(results_dict, palette,
377-
self.dataset_label,
378-
label, plot_error, ax)
376+
return utils_plot.plot_decoding(results_dict, palette,
377+
self.dataset_label, label,
378+
plot_error, ax)
379379
else:
380-
return cebra_lens_matplotlib.plot_layer_decoding(
381-
results_dict, title, self.dataset_label, label, plot_error,
382-
figsize)
380+
return utils_plot.plot_layer_decoding(results_dict, title,
381+
self.dataset_label, label,
382+
plot_error, figsize)

cebra_lens/quantification/distance.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77
from scipy.spatial.distance import cdist, pdist
88
from sklearn.preprocessing import StandardScaler
99

10-
from ..matplotlib import *
1110
from ..utils import extract_label
11+
from ..utils_plot import *
1212
from .base import _BaseMetric
1313
from .misc import continuous_binning, discrete_binning, repetition_binning
1414

cebra_lens/quantification/rdm_metric.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from scipy.linalg import block_diag
1010
from scipy.spatial.distance import correlation, pdist, squareform
1111

12-
import cebra_lens.matplotlib as cebra_lens_matplotlib
12+
from cebra_lens import utils_plot
1313

1414
from .base import _BaseMetric
1515
from .misc import continuous_binning, discrete_binning
@@ -258,9 +258,9 @@ def plot(
258258
The figure containing the plotted RDMs.
259259
"""
260260
if self.bool_oracle:
261-
return cebra_lens_matplotlib.plot_rdm_correlation(rdms)
261+
return utils_plot.plot_rdm_correlation(rdms)
262262
else:
263-
return cebra_lens_matplotlib.plot_rdm_all(
263+
return utils_plot.plot_rdm_all(
264264
rdms=rdms,
265265
labels=self.label,
266266
num_bins=self.num_bins,

cebra_lens/quantification/tsne.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,12 @@
22

33
from typing import List, Optional, Union
44

5+
import matplotlib
56
import numpy as np
67
import numpy.typing as npt
78
from sklearn.manifold import TSNE
89

9-
from ..matplotlib import *
10+
from ..utils_plot import *
1011
from .base import _BaseMetric
1112

1213

demos/metric_template.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
import numpy as np
2-
from ..cebra_lens.quantification.base import _BaseMetric
3-
from ..cebra_lens.matplotlib import *
2+
from cebra_lens.quantification.base import _BaseMetric
3+
from cebra_lens.utils_plot import *
44
from typing import List, Optional, Union
55
import numpy.typing as npt
6-
6+
import matplotlib
77

88
class NewMetric(_BaseMetric):
99
"""
@@ -86,7 +86,7 @@ def plot(
8686
The figure containing the NewMetric plot.
8787
"""
8888

89-
#Need to define the plot_newMetric function in the matplotlib.py
89+
#Need to define the plot_newMetric function in the utils_plot.py
9090
return plot_newMetric(
9191
embeddings,
9292
labels,

tests/test_cka.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def test_compute_intra_label(dummy_cka, dummy_activations):
101101
assert isinstance(result, np.ndarray)
102102

103103

104-
@patch("cebra_lens.matplotlib.plot_cka_heatmaps")
104+
@patch("cebra_lens.utils_plot.plot_cka_heatmaps")
105105
def test_plot_calls_heatmap(mock_plot, dummy_cka):
106106
cka_matrices = {"A": np.random.rand(2, 2)}
107107
dummy_cka.plot(cka_matrices, annot=True)

0 commit comments

Comments
 (0)