Skip to content

Commit 498398b

Browse files
GitHub Actions (#40)
* Create publish-book.yml - build docs * Create python-package.yml * Add Makefile, run formatters and codespell, add more git actions (#46) * Add makefile * Update requirements.txt * Fix tests, run formater and codespell checks * Add gitactions for formatting and codespell * Update title of visual notebook * Add the shell in yml git action file * Remove yapf on notebooks * Fix tests and add more * Run formatter * Fix imports with matplotlib * Remove outdated scripts (#48) * Add API to the jupyter-book (#49) * Add API to the jupyter-book * Update intro.md * Update pyproject.toml --------- Co-authored-by: Mackenzie Mathis <mackenzie.mathis@epfl.ch> --------- Co-authored-by: Mackenzie Mathis <mackenzie.mathis@epfl.ch> --------- Co-authored-by: Célia Benquet <32598028+CeliaBenquet@users.noreply.github.com>
1 parent 116f42c commit 498398b

40 files changed

Lines changed: 1111 additions & 1632 deletions

.github/workflows/publish-book.yml

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
name: publish-book
2+
3+
on:
4+
push:
5+
branches:
6+
- main
7+
8+
jobs:
9+
deploy-book:
10+
runs-on: ubuntu-latest
11+
steps:
12+
- uses: actions/checkout@v4
13+
14+
- name: Set up Python 3.10
15+
uses: actions/setup-python@v4
16+
with:
17+
python-version: "3.10"
18+
19+
- name: Install dependencies
20+
run: |
21+
python -m pip install --upgrade pip
22+
python -m pip install .[docs]
23+
pip install jupyter-book sphinxcontrib-mermaid
24+
25+
- name: Build the book
26+
run: |
27+
jupyter-book build .
28+
29+
- name: GitHub Pages action
30+
uses: peaceiris/actions-gh-pages@v3.9.3
31+
with:
32+
github_token: ${{ secrets.GITHUB_TOKEN }}
33+
publish_dir: ./_build/html
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
name: python-package
2+
3+
on:
4+
push:
5+
branches: [ main ]
6+
pull_request:
7+
branches: [ main ]
8+
9+
jobs:
10+
build:
11+
runs-on: ${{ matrix.os }}
12+
13+
strategy:
14+
fail-fast: false
15+
matrix:
16+
os: [ubuntu-latest, macos-14, windows-latest]
17+
python-version: ["3.10"]
18+
include:
19+
- os: ubuntu-latest
20+
path: ~/.cache/pip
21+
- os: macos-14
22+
path: ~/Library/Caches/pip
23+
- os: windows-latest
24+
path: ~\AppData\Local\pip\Cache
25+
26+
steps:
27+
- name: Checkout code
28+
uses: actions/checkout@v3
29+
30+
- name: Set up Python
31+
uses: conda-incubator/setup-miniconda@v3
32+
with:
33+
channels: conda-forge,defaults
34+
channel-priority: strict
35+
python-version: ${{ matrix.python-version }}
36+
37+
- name: Install PyTables through Conda
38+
shell: bash -el {0}
39+
run: |
40+
conda install pytables==3.8.0 "numpy<2"
41+
42+
- name: Install dependencies
43+
shell: bash -el {0}
44+
run: |
45+
python -m pip install --upgrade pip setuptools wheel
46+
pip install -r requirements.txt
47+
48+
- name: Run the formatter
49+
shell: bash -el {0}
50+
run: |
51+
make format
52+
53+
- name: Run the spelling detector
54+
shell: bash -el {0}
55+
run: |
56+
make codespell
57+
58+
- name: Check the documentation coverage
59+
shell: bash -el {0}
60+
run: |
61+
make interrogate
62+
63+
- name: Run all pytest tests
64+
shell: bash -el {0}
65+
run: |
66+
pip install pytest
67+
pytest tests/

Makefile

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
CEBRA_LENS_VERSION := 0.1.0.dev0
2+
3+
dist:
4+
python3 -m pip install virtualenv
5+
python3 -m pip install --upgrade build twine
6+
python3 -m build --wheel --sdist
7+
8+
build: dist
9+
10+
test:
11+
python -m pytest --ff tests
12+
13+
interrogate:
14+
interrogate \
15+
--ignore-property-decorators \
16+
--ignore-init-method \
17+
--verbose \
18+
--ignore-semiprivate \
19+
--ignore-private \
20+
--ignore-magic \
21+
--omit-covered-files \
22+
-f 80 \
23+
cebra_lens
24+
25+
docs:
26+
export PYTHONPATH=$(pwd)
27+
jupyter-book build docs
28+
29+
docs-touch:
30+
find docs/docs -iname '*.md' -exec touch {} \;
31+
jupyter-book build docs/docs
32+
33+
docs-strict:
34+
jupyter-book build docs --keep-going --strict
35+
36+
# Serve the docs
37+
serve_docs:
38+
python -m http.server 8080 --bind 127.0.0.1 -d docs/_build/html
39+
40+
# Serve the entire page
41+
serve_page:
42+
python -m http.server 8080 --bind 127.0.0.1 -d docs/_build/html
43+
44+
# Format code in the main package and docs
45+
format:
46+
yapf -i -p -r cebra_lens
47+
yapf -i -p -r tests
48+
isort cebra_lens/
49+
isort tests/
50+
51+
codespell:
52+
codespell cebra_lens/ tests/ docs/docs/*.md -L "nce, nd"
53+
54+
55+
.PHONY: docs docs-touch docs-strict serve_docs serve_page

cebra_lens/__init__.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,14 @@
11
# example of structure so that you can directly use the functions get_layer_activations instead of having to do CEBRA_Lens.activations.get_layer_activations
22
from .activations import *
3-
from .quantification import *
4-
from .quantification.decoding import *
5-
from .quantification.distance import *
63
from .quantification.cka_metric import *
4+
from .quantification.decoder import *
5+
from .quantification.distance import *
76
from .quantification.rdm_metric import *
87
from .quantification.tsne import *
9-
from .matplotlib import *
8+
from .utils import *
109
from .utils_allen import *
1110
from .utils_hpc import *
12-
from .utils import *
11+
from .utils_plot import *
1312

1413
# selects what files can be imported when doing from CEBRA_Lens import * --> keep env clean
1514
# __all__ = ['get_layer_activations']

cebra_lens/activations.py

Lines changed: 32 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,19 @@
11
"""Functions to retrieve and handle layer activations"""
22

3+
from typing import Dict, List, Optional, Tuple, Type
4+
35
import cebra
4-
import torch
5-
import torch.nn as nn
6+
import matplotlib.pyplot as plt
67
import numpy as np
78
import numpy.typing as npt
8-
from typing import Tuple, Dict, List, Type, Optional
9-
from .matplotlib import plot_activations
10-
import matplotlib.pyplot as plt
9+
import torch
10+
import torch.nn as nn
11+
12+
from .utils_plot import plot_activations
1113

1214

13-
def _cut_array(
14-
array: npt.NDArray, cut_indices: Tuple[np.int64, np.int64]
15-
) -> npt.NDArray:
15+
def _cut_array(array: npt.NDArray,
16+
cut_indices: Tuple[np.int64, np.int64]) -> npt.NDArray:
1617
"""
1718
Slices the input array based on the provided cut indices.
1819
This is used to remove the padding from activations in `get_activations_model`.
@@ -36,7 +37,7 @@ def _cut_array(
3637
sliced_array = array
3738
else:
3839
# Otherwise, slice the array
39-
sliced_array = array[:, start : end if end != 0 else start :]
40+
sliced_array = array[:, start:end if end != 0 else start:]
4041
return sliced_array
4142

4243

@@ -80,10 +81,13 @@ def get_cut_indices(
8081
# add for output layer
8182
cut_indices.append((0, 0))
8283
elif layer_type == None:
83-
raise NotImplementedError("Padding handling not implemented for 'all'.")
84+
raise NotImplementedError(
85+
"Padding handling not implemented to handle activations for all layer types.",
86+
"Set layer_type to nn.Conv1d to use the default padding handling.")
8487
else:
8588
# need to analyze the padding from the last output of Conv1 and apply the same cut
86-
raise NotImplementedError(f"Padding handling not implemented for {layer_type}.")
89+
raise NotImplementedError(
90+
f"Padding handling not implemented for {layer_type}.")
8791
return cut_indices
8892

8993

@@ -93,7 +97,7 @@ def get_activations_model(
9397
session_id: int = -1,
9498
name: str = "single",
9599
instance: int = 0,
96-
layer_type: Type[nn.Module] = None,
100+
layer_type: Type[nn.Module] = nn.Conv1d,
97101
) -> Dict[str, npt.NDArray]:
98102
"""
99103
Extracts activations from a single model layer.
@@ -111,7 +115,8 @@ def get_activations_model(
111115
instance : int
112116
The instance number for the model, used to differentiate between models from the same model category.
113117
layer_type : Type[nn.Module]
114-
The type of layer to extract activations from. Defaults to None, meaning extracts activations from all layers.
118+
The type of layer to extract activations from. None means it extracts activations from all layers.
119+
Default is nn.Conv1d, which is the most common layer type used in CEBRA models.
115120
116121
Returns:
117122
--------
@@ -125,26 +130,25 @@ def get_activations_model(
125130
activations = {}
126131
transform_kwargs = {}
127132
if model.solver_name_ in [
128-
"multi-session",
129-
"multi-session-aux",
130-
"multiobjective-solver",
133+
"multi-session",
134+
"multi-session-aux",
135+
"multiobjective-solver",
131136
]:
132137

133138
model_ = model.model_[session_id]
134139
transform_kwargs.update({"session_id": session_id})
135140

136141
elif model.solver_name_ in [
137-
"single-session",
138-
"single-session-aux",
139-
"single-session-hybrid",
140-
"single-session-full",
142+
"single-session",
143+
"single-session-aux",
144+
"single-session-hybrid",
145+
"single-session-full",
141146
]:
142147
model_ = model.model_
143148

144149
else:
145150
raise NotImplementedError(
146-
f"Solver {model.solver_name_} is not yet implemented."
147-
)
151+
f"Solver {model.solver_name_} is not yet implemented.")
148152

149153
activations, handles, conv_layer_info = _attach_hooks(
150154
activations=activations,
@@ -209,14 +213,14 @@ def process_activations(
209213
name=model_name,
210214
instance=i,
211215
layer_type=layer_type,
212-
)
213-
)
216+
))
214217

215218
return activations
216219

217220

218221
# Function to create a hook that stores the activations in the dictionary
219222
def _get_activation(name: str, activations: Dict):
223+
220224
def hook(model, input, output):
221225
activations[name] = output.detach().squeeze().numpy()
222226

@@ -262,8 +266,7 @@ def _attach_hooks(
262266
# attach hook to the layer_type and to the output layer
263267
if isinstance(model.net[i], layer_type) or i == len(model.net) - 1:
264268
hook, activations = _get_activation(
265-
f"{name}_{instance}_layer_{num_layer}", activations
266-
)
269+
f"{name}_{instance}_layer_{num_layer}", activations)
267270
if isinstance(model.net[i], layer_type):
268271
conv_layer_info.append(model.net[i].kernel_size[0])
269272
handle = model.net[i].register_forward_hook(hook)
@@ -298,8 +301,7 @@ def _attach_hooks(
298301

299302
else:
300303
hook, activations = _get_activation(
301-
f"{name}_{instance}_layer_{num_layer}", activations
302-
)
304+
f"{name}_{instance}_layer_{num_layer}", activations)
303305

304306
handle = model.net[i].register_forward_hook(hook)
305307
handles.append(handle)
@@ -309,8 +311,7 @@ def _attach_hooks(
309311

310312

311313
def aggregate_activations(
312-
activations: Dict[str, npt.NDArray],
313-
) -> Dict[str, npt.NDArray]:
314+
activations: Dict[str, npt.NDArray], ) -> Dict[str, npt.NDArray]:
314315
"""
315316
Aggregates activations by model identifier aka. instance.
316317
This function takes a dictionary of activations where the keys are strings containing model identifiers and layer information,
@@ -387,8 +388,7 @@ def get_activations(
387388
activations = activations or {}
388389

389390
aggregated_activations = aggregate_activations(
390-
process_activations(models, data, session_id, activations, layer_type)
391-
)
391+
process_activations(models, data, session_id, activations, layer_type))
392392

393393
activations_dict = {}
394394
for key, value in aggregated_activations.items():
Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1+
from .base import *
12
from .cka_metric import *
2-
from .rdm_metric import *
3-
from .misc import *
3+
from .decoder import *
44
from .distance import *
5-
from .decoding import *
6-
from .base import *
5+
from .misc import *
6+
from .rdm_metric import *
77
from .tsne import *

cebra_lens/quantification/base.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
1-
from tqdm import tqdm
2-
import numpy as np
31
import pickle
42
import types
5-
from typing import List, Union, Dict
63
from abc import *
74
from pathlib import Path
5+
from typing import Dict, List, Union
6+
7+
import numpy as np
88
import numpy.typing as npt
9+
from tqdm import tqdm
910

1011

1112
class _BaseMetric:
@@ -14,7 +15,8 @@ class _BaseMetric:
1415
"""
1516

1617
@abstractmethod
17-
def compute(self, activations: Dict[str, npt.NDArray]) -> Dict[str, npt.NDArray]:
18+
def compute(self,
19+
activations: Dict[str, npt.NDArray]) -> Dict[str, npt.NDArray]:
1820
"""
1921
Every metric which inherits ``_BaseMetric`` needs to implement a compute function.
2022
The compute function is specific to a metric, e.g. intra-bin distance, RDM, CKA,...
@@ -66,9 +68,8 @@ def save(self, filepath: str, data: Dict[str, npt.NDArray]) -> None:
6668
and the value is a npt.NDArray containing for all the models under that label the calculated data.
6769
"""
6870
filepath = Path(filepath)
69-
custom_filepath = filepath.with_stem(
70-
filepath.stem + f"_{self.__class__.__name__}"
71-
)
71+
custom_filepath = filepath.with_stem(filepath.stem +
72+
f"_{self.__class__.__name__}")
7273
with open(custom_filepath, "wb") as f:
7374
pickle.dump(data, f)
7475

0 commit comments

Comments
 (0)