NB: Blog post is still being written.
MistralSAE is a tool for optimized online¹ training of sparse autoencoders (SAEs) on (currently) Mistral 3 models. It includes:
- reusable custom Triton kernels, autograd functions and wrappers for faster training and a smaller VRAM footprint;
- "model slicing" optimizations for faster activation inference (and smaller VRAM footprint too);
- end-to-end configurable training pipeline, with extensive metrics logging on wandb;
- test and benchmarking suite;
- optional basic webapp to chat with a steered model.
¹ Training is performed on the fly: model activations are computed from text batches and used to update the SAE directly, without precomputing or storing them.
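To make the "online" idea in footnote ¹ concrete, here is a minimal pure-Python sketch. It is not the project's implementation: toy_activations is a hypothetical stand-in for the language model forward pass, and TinyReLUSAE is an illustrative ReLU autoencoder with hand-written gradients and plain SGD. The only point is the data flow: each batch of activations is produced lazily, used for an update, and then discarded.

```python
import random
from typing import Iterable, Iterator, List

Vector = List[float]

def toy_activations(texts: List[str], d_model: int) -> List[Vector]:
    """Hypothetical stand-in for a model forward pass: one vector per text."""
    batch = []
    for text in texts:
        rng = random.Random(text)  # deterministic pseudo-activations per text
        batch.append([rng.uniform(-1.0, 1.0) for _ in range(d_model)])
    return batch

def activation_stream(text_batches: Iterable[List[str]], d_model: int) -> Iterator[List[Vector]]:
    """Yield activations lazily; nothing is precomputed or kept around."""
    for texts in text_batches:
        yield toy_activations(texts, d_model)

class TinyReLUSAE:
    """Illustrative ReLU SAE: loss = ||x_hat - x||^2 + l1 * sum(latents)."""

    def __init__(self, d_model: int, n_latents: int, lr: float = 0.05, l1: float = 1e-3, seed: int = 0):
        rng = random.Random(seed)
        self.w_enc = [[rng.gauss(0.0, 0.1) for _ in range(d_model)] for _ in range(n_latents)]
        self.b_enc = [0.0] * n_latents
        self.w_dec = [[rng.gauss(0.0, 0.1) for _ in range(n_latents)] for _ in range(d_model)]
        self.b_dec = [0.0] * d_model
        self.lr, self.l1 = lr, l1

    def step(self, x: Vector) -> float:
        """One SGD update on a single activation vector; returns the loss before the update."""
        # Forward pass: encode with ReLU, then decode.
        pre = [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(self.w_enc, self.b_enc)]
        f = [max(0.0, p) for p in pre]
        x_hat = [sum(w * fj for w, fj in zip(row, f)) + b for row, b in zip(self.w_dec, self.b_dec)]
        resid = [xh - xi for xh, xi in zip(x_hat, x)]
        loss = sum(r * r for r in resid) + self.l1 * sum(f)
        # Backward pass, computed before any weight is modified.
        d_f = [sum(2.0 * resid[i] * self.w_dec[i][j] for i in range(len(x))) + self.l1
               for j in range(len(f))]
        d_pre = [g if p > 0.0 else 0.0 for g, p in zip(d_f, pre)]
        # Parameter updates.
        for i, r in enumerate(resid):
            self.b_dec[i] -= self.lr * 2.0 * r
            for j, fj in enumerate(f):
                self.w_dec[i][j] -= self.lr * 2.0 * r * fj
        for j, g in enumerate(d_pre):
            self.b_enc[j] -= self.lr * g
            for k, xk in enumerate(x):
                self.w_enc[j][k] -= self.lr * g * xk
        return loss

def train_online(text_batches: Iterable[List[str]], d_model: int = 4, n_latents: int = 8):
    """Consume text batches once, updating the SAE as activations arrive."""
    sae = TinyReLUSAE(d_model, n_latents)
    losses = []
    for acts in activation_stream(text_batches, d_model):
        for x in acts:
            losses.append(sae.step(x))
    return sae, losses
```

Calling train_online([["a", "b"], ["c"]]) performs three updates and never holds more than one batch of activations in memory, which is the property the footnote describes.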
The project is compatible with CPython 3.11 and later (it should work with Python 3.10 too²) and with
CUDA 12.6 and later, or ROCm. I highly recommend using uv. To install, clone the repository and run
uv sync --extra {cu126|cu128|cu130|cu131|rocm} [--extra steering]
Use the appropriate CUDA/ROCm extra marker. You can also add the --extra steering marker to install the requirements
for the steering chat webapp. Please note that this will download and use³ the prebuilt Flash Attention wheels
provided through my Flash Attention index (some not built by me).
Then, activate the environment by running
# on Linux:
source .venv/bin/activate
# on Windows:
.venv\Scripts\activate
² I wouldn't recommend it, since it is significantly slower and lacks the typing and TOML parsing features that later versions provide in the standard library (backward compatibility is achieved with tomli and typing_extensions).
³ Not on ROCm. Please refer to the official repository for the installation instructions. There is also one missing wheel for Python 3.14 on Windows with CUDA 12.8 (the CUDA 12.6 version should work as a replacement).
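Footnote ² mentions that Python 3.10 support relies on the tomli backport. The usual import-fallback pattern for this looks roughly like the sketch below. The TOML keys (training, batch_size, sae_kind) and the load_config helper are made up for illustration and are not the project's actual config schema or API.

```python
try:
    import tomllib  # part of the standard library since Python 3.11
except ModuleNotFoundError:
    import tomli as tomllib  # backport exposing the same API on Python 3.10

# Hypothetical config content, only to show that parsing works the same way
# with either module.
EXAMPLE_TOML = """
[training]
batch_size = 8
sae_kind = "relu"
"""

def load_config(text: str) -> dict:
    """Parse a TOML document from a string into nested dictionaries."""
    return tomllib.loads(text)
```

With this pattern the rest of the code can refer to tomllib without caring which Python version is running.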
Most of the features can be used through the CLI. Run mistralsae [--help] or mistralsae <command> [--help] to
display help/syntax. Please consult the CLI documentation for more information. Here are the main
uses of the CLI:
mistralsae train
You can configure the pipeline beforehand by creating the config with mistralsae config create. You will also need to
authenticate with the wandb login command first.
The current training setup relies on the following datasets:
Since the last one is a gated dataset, you will need to request access and log in before training (you can run
hf auth login to do so). Please refer to the config documentation and
components documentation for details on the implementation and on the list of logged metrics.
Also see the Extending to other models and datasets section below.
mistralsae app
If you just want to test the steering without running the whole training, you can download my public checkpoint for
mistralai/Mistral-Small-3.2-24B-Instruct-2506 by running mistralsae checkpoint download (or override the default
with the MSAE_HF_SAE_REPO env variable). The default feature (184859) is the Eiffel Tower feature.
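As an illustration of the environment-variable override mentioned above, a tool can resolve the checkpoint repository along these lines. This is a sketch of the general pattern, not the project's code: resolve_sae_repo is a hypothetical helper, and DEFAULT_SAE_REPO is a placeholder because the real default repository name is not stated here.

```python
import os
from typing import Mapping, Optional

# Placeholder value: the actual default repository is not given in this README.
DEFAULT_SAE_REPO = "example-user/example-sae-checkpoint"

def resolve_sae_repo(env: Optional[Mapping[str, str]] = None) -> str:
    """Return the repo named by MSAE_HF_SAE_REPO, or the default if unset or empty."""
    env = os.environ if env is None else env
    return env.get("MSAE_HF_SAE_REPO") or DEFAULT_SAE_REPO
```

Setting MSAE_HF_SAE_REPO in the shell before running the download command would then point it at another checkpoint repository.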
Currently, the code only supports Mistral 3 models and the datasets listed above. While it should work as is on base models, please note that it formats data with an instruct template.
The code can easily be adapted to other models and datasets.
- The encoders, decoders and kernels submodules do not depend on any other component (config, logger, ...), and can thus be imported and used as is. As their names suggest, they contain the different implementations (Triton kernel, autograd and wrapper) for the SAE encoder and decoder parts. For more details, please consult the code and the components documentation.
- To add a new dataset, create a subclass of TextDataset in the data module and implement parse_entry(self, entry), which must return a list of conversations (each conversation is a list of Mistral ChatMessageType messages). In __init__, call super().__init__(path, split, prob) to register the Hugging Face dataset path, split, and sampling probability used by TextDatasetsManager. Loading, shuffling/reloading, and conversion to ChatCompletionRequest are already handled by the base class. For concrete patterns, see DS_PDBooks, DS_Claire, and DS_Tulu in text_datasets.py.
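The dataset recipe above can be sketched as follows. To keep the example runnable on its own, TextDataset is replaced by a minimal stub and messages are plain dictionaries. In the real project you would subclass the actual TextDataset and return Mistral chat message objects instead. DS_Example, the dataset path, and the question/answer field names are all hypothetical.

```python
from typing import Any, Dict, List

Message = Dict[str, str]      # stand-in for the Mistral chat message types
Conversation = List[Message]

class TextDataset:
    """Stub of the project's base class, reduced to what this sketch needs."""

    def __init__(self, path: str, split: str, prob: float) -> None:
        self.path, self.split, self.prob = path, split, prob

    def parse_entry(self, entry: Dict[str, Any]) -> List[Conversation]:
        raise NotImplementedError

class DS_Example(TextDataset):
    """Hypothetical question/answer dataset mapped to one-turn conversations."""

    def __init__(self) -> None:
        # Registers the dataset path, split, and sampling probability.
        super().__init__("example-org/example-dataset", "train", 0.25)

    def parse_entry(self, entry: Dict[str, Any]) -> List[Conversation]:
        question = str(entry.get("question") or "").strip()
        answer = str(entry.get("answer") or "").strip()
        if not question or not answer:
            return []  # skip incomplete rows
        return [[
            {"role": "user", "content": question},
            {"role": "assistant", "content": answer},
        ]]
```

Returning a list of conversations (rather than a single one) lets one raw entry expand into several training conversations, or into none when the entry is unusable.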
Benchmark and test results can be found in benchmarks.md. The default config is based on the optimal settings for training a "dynamic" (classic ReLU) SAE on an RTX PRO 6000 Blackwell GPU.
Currently, the TopK implementation is not complete (the auxiliary loss is not yet computed as described in the OpenAI paper). I plan to work on that (and on dead latent resampling) when I have time, along with new implementations for JumpReLU or Gated SAEs. I may also work on support for more models and datasets.
This project is inspired by and benefits from the following works:
- Templeton, et al., "Scaling Monosemanticity: Extracting Interpretable Features from Claude 3 Sonnet", Transformer Circuits Thread, 2024 ;
- Bricken, et al., "Towards Monosemanticity: Decomposing Language Models With Dictionary Learning", Transformer Circuits Thread, 2023 ;
- Gao, et al., "Scaling and evaluating sparse autoencoders", OpenAI, 2024 ;
- Anthropic, "The Golden Gate Claude", 2024 ;
- Anthropic, "Transformer Circuits Update (June 2024)", 2024 ;
- Anthropic, "Transformer Circuits Update (April 2024) — Training SAEs", 2024 ;
- Anthropic, "Transformer Circuits Update (January 2025)", 2025.