You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
## A python library for mechanistic interpretability of CEBRA models
5
4
6
-
This Python codebase allows for neural representation analysis of CEBRA models. It contains tools to help answer the question: **What representations is my model learning?** We can get a glimpse of what the models learn by looking at the NN units themselves after the model is trained, using “neuroscientist methods” such as CKA, PCA/tSNE (See Sandbrink et al 2023). Precisely these "neuroscientist methods" are implemented in this codebase.
The current version of CEBRA-Lens supports specific analysis on the Allen Institute visual coding dataset ([DeVries et al, Nature Neuro., 2020](https://www.nature.com/articles/s41593-019-0550-9)) and Hippocampus dataset ([Grosmark & Buzáki, Science, 2016](https://www.science.org/doi/full/10.1126/science.aad1935)), and for general analysis on other datasets.
9
7
10
-
## 🔍 Analysis
8
+
**CEBRA-Lens** is a Python library for analyzing and interpreting neural representations learned by models trained with [CEBRA](https://github.com/AdaptiveMotorControlLab/cebra). It provides tools for mechanistic interpretability, allowing users to probe, visualize, and understand the structure of learned embeddings. The library is designed to support in-depth analysis of representational geometry, feature selectivity, and latent space dynamics in neuroscience and beyond. 👋 We welcome contributions and will continue to expand the library in the coming years.
11
9
12
-
Implemented "neuroscientist methods" for neural representation analysis are presented below.
🚨 Make sure that the environment in which you trained the CEBRA models in **has the same torch version** as the environment used for CEBRA-Lens.
15
+
16
+
```{Hint} Familiar with python packages and conda? Quick Install Guide:
17
+
```bash
18
+
conda create -n CEBRAlens python=3.12
19
+
conda activate CEBRAlens
20
+
conda install -c conda-forge pytables==3.8.0
21
+
22
+
# install PyTorch with your desired CUDA version (or for CPU only)- check their website: https://pytorch.org/get-started/locally/
23
+
# example: GPU version of pytorch for CUDA 11.3
24
+
conda install pytorch cudatoolkit=11.3 -c pytorch
25
+
26
+
# install CEBRA and CEBRA-lens
27
+
pip install --pre 'cebra[datasets,demos]
28
+
pip install -- cebralens
29
+
```
30
+
31
+
## 🦓🔍 Analysis Methods
32
+
33
+
Implemented mechanistic interpretability methods for neural representation analysis are presented below.
13
34
14
35
### Model performance analysis
15
36
@@ -47,31 +68,9 @@ These analyses quantify the change in the distance calculated per layer in a mod
47
68
48
69
<imgsrc="figures/analysis.png"alt="analysis">
49
70
50
-
## 📚 Codebase folder structure
51
-
52
-
Below is the folder structure of the repository with the main folder and files. The `cebra_lens` folder contains all the code for the analysis with the metric class definitions in the `quantification` folder, the `demos` folder contains the usage jupyter notebooks and finally there is a `tests` folder which contains some pytest for the repo.
53
-
54
-
CEBRA_lens/
55
-
├── README.md
56
-
├── cebra_lens/
57
-
│ ├── quantification/
58
-
│ │ ├── base.py
59
-
│ │ ├── cka_metric.py
60
-
│ │ ├── decoding.py
61
-
│ │ ├── distance.py
62
-
│ │ ├── misc.py
63
-
│ │ ├── rdm_metric.py
64
-
│ │ └── tsne.py
65
-
│ ├── activations.py
66
-
│ ├── matplotlib.py
67
-
│ ├── utils_allen.py
68
-
│ ├── utils_hpc.py
69
-
│ └── utils.py
70
-
│
71
-
├── demos/
72
-
│ ├── UsageDemoVISUAL.ipynb
73
-
│ └── UsageDemoGENERAL.ipynb
74
-
└── tests/
71
+
# Demo
72
+
73
+
The current version of CEBRA-Lens supports specific analysis on the Allen Institute visual coding dataset ([DeVries et al, Nature Neuro., 2020](https://www.nature.com/articles/s41593-019-0550-9)) and Hippocampus dataset ([Grosmark & Buzáki, Science, 2016](https://www.science.org/doi/full/10.1126/science.aad1935)), and for general analysis on other datasets. See the example notebooks we provide.
75
74
76
75
## 📊Usage
77
76
@@ -97,12 +96,22 @@ fig = lens.plot_metric(
97
96
)
98
97
```
99
98
100
-
The full demonstration of the usage is in the form of 2 jupyter notebooks:
101
-
- UsageDemoVISUAL: analysis on the Allen visual dataset, [here](https://github.com/AdaptiveMotorControlLab/CEBRA-lens/blob/eloise/tests/demos/UsageDemoVISUAL.ipynb)
102
-
- UsageDemoGENERAL: analysis on the Hippocampus dataset, but without specific dataset functions, [here](https://github.com/AdaptiveMotorControlLab/CEBRA-lens/blob/eloise/tests/demos/UsageDemoGENERAL.ipynb)
99
+
#### Jupyter Notebooks
100
+
- UsageDemoVISUAL: analysis on the Allen visual dataset, [here](https://github.com/AdaptiveMotorControlLab/CEBRA-lens/blob/main/demos/UsageDemoVISUAL.ipynb).
101
+
- UsageDemoGENERAL: analysis on the Hippocampus dataset, but without specific dataset functions, [here](https://github.com/AdaptiveMotorControlLab/CEBRA-lens/blob/main/demos/UsageDemoGENERAL.ipynb).
103
102
104
103
These two notebooks showcase the different approach when analyzing a pre-defined dataset and a non-defined dataset.
105
104
105
+
106
+
# Acknowledgements
107
+
108
+
- This repository contains the code for [Eloise's](https://github.com/eloisehabek) semester's project "Engineering software for neural representation analysis"(SPRING 2025),
109
+
building on [Riccardo's](https://github.com/riccardoprog) semester project "Exploring nonlinear encoders for robust vision decoding" (FALL 2024).
110
+
- The work was supervised by [Célia Benquet](https://github.com/CeliaBenquet) and [Mackenzie](https://github.com/MMathisLab) at the Mathis Laboratory of Adaptive Intelligence.
111
+
- We thank the [DeepDraw project](https://elifesciences.org/articles/81499) for some [source code](https://github.com/amathislab/DeepDraw) and analysis methods.
112
+
113
+
# Other helpful tips:
114
+
106
115
## 📥 Download dataset
107
116
108
117
The `utils.py` file contains a overarching `get_data` function which checks for a pre-defined dataset label and accordingly loads the data based on specific functions for the dataset. If you want to load data from a non-defined dataset, you need to first import the loading function inside the `utils.py` file as so:
0 commit comments