Skip to content

Commit 8abc266

Browse files
committed
Run formatter
1 parent 89b9ac7 commit 8abc266

20 files changed

Lines changed: 186 additions & 125 deletions

cebra_lens/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
# example of structure so that you can directly use the functions get_layer_activations instead of having to do CEBRA_Lens.activations.get_layer_activations
22
from .activations import *
3+
from .matplotlib import *
4+
from .quantification.cka_metric import *
35
from .quantification.decoder import *
46
from .quantification.distance import *
5-
from .quantification.cka_metric import *
67
from .quantification.rdm_metric import *
78
from .quantification.tsne import *
8-
from .matplotlib import *
9+
from .utils import *
910
from .utils_allen import *
1011
from .utils_hpc import *
11-
from .utils import *
1212

1313
# selects what files can be imported when doing from CEBRA_Lens import * --> keep env clean
1414
# __all__ = ['get_layer_activations']

cebra_lens/activations.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
"""Functions to retrieve and handle layer activations"""
22

3+
from typing import Dict, List, Optional, Tuple, Type
4+
35
import cebra
4-
import torch
5-
import torch.nn as nn
6+
import matplotlib.pyplot as plt
67
import numpy as np
78
import numpy.typing as npt
8-
from typing import Tuple, Dict, List, Type, Optional
9+
import torch
10+
import torch.nn as nn
11+
912
from .matplotlib import plot_activations
10-
import matplotlib.pyplot as plt
1113

1214

1315
def _cut_array(array: npt.NDArray,

cebra_lens/matplotlib.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
"""Matplotlib interface to CEBRA-Lens."""
22

3+
import random
34
from abc import *
4-
from typing import Optional, Tuple, List, Dict, Union
5-
import seaborn as sns
5+
from typing import Dict, List, Optional, Tuple, Union
6+
67
import matplotlib.axes
78
import matplotlib.pyplot as plt
89
import numpy as np
9-
import torch
1010
import numpy.typing as npt
11-
import random
11+
import seaborn as sns
12+
import torch
1213

1314

1415
class _BasePlot:
Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1+
from .base import *
12
from .cka_metric import *
2-
from .rdm_metric import *
3-
from .misc import *
4-
from .distance import *
53
from .decoder import *
6-
from .base import *
4+
from .distance import *
5+
from .misc import *
6+
from .rdm_metric import *
77
from .tsne import *

cebra_lens/quantification/base.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
1-
from tqdm import tqdm
2-
import numpy as np
31
import pickle
42
import types
5-
from typing import List, Union, Dict
63
from abc import *
74
from pathlib import Path
5+
from typing import Dict, List, Union
6+
7+
import numpy as np
88
import numpy.typing as npt
9+
from tqdm import tqdm
910

1011

1112
class _BaseMetric:

cebra_lens/quantification/cka_metric.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,16 @@
55
66
"""
77

8-
from tqdm import tqdm
8+
from typing import Dict, List, Optional, Tuple
9+
10+
import matplotlib
911
import numpy as np
10-
from .base import _BaseMetric
11-
import cebra_lens.matplotlib as cebra_lens_matplotlib
12-
from typing import Optional, List, Dict, Tuple
1312
import numpy.typing as npt
14-
import matplotlib
13+
from tqdm import tqdm
14+
15+
import cebra_lens.matplotlib as cebra_lens_matplotlib
16+
17+
from .base import _BaseMetric
1518

1619

1720
class CKA(_BaseMetric):
@@ -189,7 +192,7 @@ def _compute_per_layer(
189192
cka_matrix = np.zeros((len(embeddings_1), len(embeddings_1[0])))
190193
for j in tqdm(range(len(embeddings_1))):
191194
if flag:
192-
# the situation when there multiple models inside model labels and the same number of
195+
# the situation when there multiple models inside model labels and the same number of
193196
# models inside each label
194197
cka_matrix[j, :] = self._compute_cka(embeddings_1[j],
195198
embeddings_2[j])
@@ -230,7 +233,7 @@ def compute(self, activations: Dict[str, npt.NDArray],
230233

231234
if len(activations_1) != len(activations_2):
232235
# if the number of models in a label is different from the other model label
233-
# choose embeddings_1 for the one with more models, and then embeddings_2 just compare with
236+
# choose embeddings_1 for the one with more models, and then embeddings_2 just compare with
234237
# the first model
235238
if len(activations_1) > len(activations_2):
236239
embeddings_1 = activations_1

cebra_lens/quantification/decoder.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,20 @@
1+
from typing import Dict, Optional, Tuple, Type
2+
13
import cebra
2-
import torch
4+
import matplotlib
35
import numpy as np
4-
from ..utils_allen import decoding_frames
5-
from ..utils_hpc import decoding_pos_dir
6-
from ..activations import get_activations_model
7-
from .base import _BaseMetric
8-
import cebra_lens.matplotlib as cebra_lens_matplotlib
96
import numpy.typing as npt
10-
from typing import Dict, Type, Tuple, Optional
11-
import torch.nn as nn
127
import sklearn.metrics
8+
import torch
139
import torch as pt
14-
import matplotlib
10+
import torch.nn as nn
11+
12+
import cebra_lens.matplotlib as cebra_lens_matplotlib
13+
14+
from ..activations import get_activations_model
15+
from ..utils_allen import decoding_frames
16+
from ..utils_hpc import decoding_pos_dir
17+
from .base import _BaseMetric
1518

1619

1720
def decoding(
@@ -370,8 +373,10 @@ def plot(
370373
)
371374

372375
if self.output_only:
373-
return cebra_lens_matplotlib.plot_decoding(results_dict, palette, self.dataset_label,
374-
label, plot_error, ax)
376+
return cebra_lens_matplotlib.plot_decoding(results_dict, palette,
377+
self.dataset_label,
378+
label, plot_error, ax)
375379
else:
376-
return cebra_lens_matplotlib.plot_layer_decoding(results_dict, title, self.dataset_label,
377-
label, plot_error, figsize)
380+
return cebra_lens_matplotlib.plot_layer_decoding(
381+
results_dict, title, self.dataset_label, label, plot_error,
382+
figsize)

cebra_lens/quantification/distance.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
11
"file containing all the functions relative to distance computing"
22

3+
from typing import Dict, List, Optional, Tuple, Union
4+
35
import numpy as np
6+
import numpy.typing as npt
47
from scipy.spatial.distance import cdist, pdist
58
from sklearn.preprocessing import StandardScaler
6-
from typing import List, Optional, Tuple, Union, Dict
7-
from .misc import discrete_binning, repetition_binning, continuous_binning
8-
from .base import _BaseMetric
9+
910
from ..matplotlib import *
10-
import numpy.typing as npt
1111
from ..utils import extract_label
12+
from .base import _BaseMetric
13+
from .misc import continuous_binning, discrete_binning, repetition_binning
1214

1315

1416
class DistanceMetric:

cebra_lens/quantification/misc.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
"""misc functions like normalization and possibly others"""
22

3+
import warnings
34
from random import sample
5+
from typing import List
6+
47
import numpy as np
5-
import torch
68
import numpy.typing as npt
7-
from typing import List
8-
import warnings
9+
import torch
910

1011

1112
def normalize_minmax(rdm: npt.NDArray) -> npt.NDArray:

cebra_lens/quantification/rdm_metric.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,18 @@
11
"""All the functions relative to the Representation Dissimilarity Matrix (RDM) calculation"""
22

3-
from typing import Dict, List, Optional
3+
from typing import Dict, List, Optional, Tuple, Union
4+
5+
import matplotlib
46
import numpy as np
7+
import numpy.typing as npt
8+
import torch
59
from scipy.linalg import block_diag
6-
from typing import List, Optional, Tuple, Union
710
from scipy.spatial.distance import correlation, pdist, squareform
8-
from .misc import discrete_binning, continuous_binning
9-
import torch
10-
import matplotlib
11-
from .base import _BaseMetric
11+
1212
import cebra_lens.matplotlib as cebra_lens_matplotlib
13-
import numpy.typing as npt
13+
14+
from .base import _BaseMetric
15+
from .misc import continuous_binning, discrete_binning
1416

1517

1618
class RDM(_BaseMetric):

0 commit comments

Comments
 (0)