
codcordance/mistral-sae


MistralSAE

Optimized training of sparse autoencoders on Mistral models


NB: Blog post is still being written.

MistralSAE is a tool for optimized online¹ training of sparse autoencoders (SAEs) on Mistral 3 models (currently the only supported family). It includes:

  • reusable custom Triton kernels, autograd functions and wrappers for faster training and a smaller VRAM footprint;
  • "model slicing" optimizations for faster activation inference (and a smaller VRAM footprint too);
  • an end-to-end configurable training pipeline, with extensive metric logging to wandb;
  • a test and benchmarking suite;
  • an optional basic webapp to chat with a steered model.

¹ Training is performed on the fly: model activations are computed from text batches and used to update the SAE directly, without being precomputed or stored.
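To make the "online" idea concrete, here is a toy NumPy sketch of a classic ReLU SAE trained this way. It is not MistralSAE's code: the random tanh projection in fake_activations is a hypothetical stand-in for the (sliced) Mistral forward pass, and the loss (squared reconstruction error plus an L1 penalty on latents), plain SGD and hand-derived gradients are assumptions chosen for brevity, where the project uses its own Triton kernels and configurable pipeline. The point is only the data flow: each batch of activations is produced, consumed by one update, and dropped.

```python
import numpy as np


def fake_activations(rng, model_w, batch_size):
    """Hypothetical stand-in for a model forward pass on one text batch."""
    tokens = rng.standard_normal((batch_size, model_w.shape[0]))
    return np.tanh(tokens @ model_w)


class ReluSAE:
    """Classic ReLU SAE: z = relu(x W_enc + b_enc), x_hat = z W_dec + b_dec."""

    def __init__(self, d_model, n_latents, rng):
        self.W_enc = rng.standard_normal((d_model, n_latents)) * 0.1
        self.b_enc = np.zeros(n_latents)
        self.W_dec = rng.standard_normal((n_latents, d_model)) * 0.1
        self.b_dec = np.zeros(d_model)

    def loss_and_grads(self, x, l1_coeff):
        batch = x.shape[0]
        pre = x @ self.W_enc + self.b_enc
        z = np.maximum(pre, 0.0)
        x_hat = z @ self.W_dec + self.b_dec
        err = x_hat - x
        # Mean squared reconstruction error + L1 sparsity penalty (z >= 0).
        loss = (err ** 2).sum() / batch + l1_coeff * z.sum() / batch
        # Backpropagate by hand through decoder, ReLU and encoder.
        d_xhat = 2.0 * err / batch
        grads = {"W_dec": z.T @ d_xhat, "b_dec": d_xhat.sum(axis=0)}
        d_z = d_xhat @ self.W_dec.T + l1_coeff / batch
        d_pre = d_z * (pre > 0)
        grads["W_enc"] = x.T @ d_pre
        grads["b_enc"] = d_pre.sum(axis=0)
        return loss, grads

    def step(self, grads, lr):
        # Plain SGD update on every parameter.
        for name, g in grads.items():
            setattr(self, name, getattr(self, name) - lr * g)


def train_online(n_steps=200, batch_size=32, d_model=16, n_latents=64,
                 lr=1e-2, l1_coeff=1e-3, seed=0):
    rng = np.random.default_rng(seed)
    model_w = rng.standard_normal((8, d_model))
    sae = ReluSAE(d_model, n_latents, rng)
    losses = []
    for _ in range(n_steps):
        x = fake_activations(rng, model_w, batch_size)  # computed on the fly
        loss, grads = sae.loss_and_grads(x, l1_coeff)
        sae.step(grads, lr)  # x goes out of scope: nothing is stored
        losses.append(float(loss))
    return sae, losses
```

Memory use stays constant in the number of training tokens, since only one batch of activations exists at a time; the trade-off is that the model forward pass is paid again on every run instead of once.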

Installation

The project is compatible with CPython 3.11 and later (it should also work with Python 3.10²) and with CUDA 12.6 and later or ROCm. I highly recommend using uv. To install, clone the repository and run

uv sync --extra {cu126|cu128|cu130|cu131|rocm} [--extra steering]

Use the extra marker matching your CUDA/ROCm version. You can also add the --extra steering marker to install the requirements for the steering chat webapp. Please note that this will download and use³ the prebuilt Flash Attention wheels provided through my Flash Attention index (some of which were not built by me). Then, activate the environment by running

# on Linux:
source .venv/bin/activate

# on Windows:
.venv\Scripts\activate 

² I wouldn't recommend it, since it is significantly slower and its standard library lacks TOML parsing and some of the typing features available in later versions (backward compatibility is provided by tomli and typing_extensions).

³ Not on ROCm: please refer to the official Flash Attention repository for installation instructions. There is also one missing wheel, for Python 3.14 on Windows with CUDA 12.8 (the CUDA 12.6 wheel should work as a replacement).

Usage

Most features are available through the CLI. Run mistralsae [--help] or mistralsae <command> [--help] to display help and syntax. Please consult the CLI documentation for more information. Here are the main uses of the CLI:

Training

mistralsae train

Before training, you can configure the pipeline by creating a config with mistralsae config create. You will also need to authenticate with wandb by running wandb login.

The current training setup relies on the following datasets:

Since the last one is a gated dataset, you will need to request access and log in before training (you can run hf auth login to do so). Please refer to the config documentation and the components documentation for details on the implementation and for the list of logged metrics. Also see the Extending to other models and datasets section below.

WebUI (chat with a steered model)

mistralsae app

If you just want to test steering without running the whole training, you can download my public checkpoint for mistralai/Mistral-Small-3.2-24B-Instruct-2506 by running mistralsae checkpoint download (you can override the default repository with the MSAE_HF_SAE_REPO environment variable). The default feature (184859) is the Eiffel Tower feature.
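A common way to steer with an SAE feature is to add a scaled copy of that feature's decoder direction to the residual stream at the layer the SAE was trained on. The NumPy sketch below shows only that arithmetic; whether the webapp uses exactly this rule (unit-normalized direction, same shift at every position) is an assumption, and steer is a hypothetical helper, not part of the project.

```python
import numpy as np


def steer(resid: np.ndarray, W_dec: np.ndarray, feature: int, strength: float) -> np.ndarray:
    """Shift residual activations along one SAE feature's decoder direction.

    resid: (seq_len, d_model) activations at the hooked layer.
    W_dec: (n_latents, d_model) SAE decoder matrix.
    """
    direction = W_dec[feature] / np.linalg.norm(W_dec[feature])
    # The same vector is added at every position (broadcast over seq_len);
    # a new array is returned and resid is left untouched.
    return resid + strength * direction
```

With the downloaded checkpoint, feature would be 184859 and strength a value tuned by hand: too small has no visible effect, too large tends to derail generation.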

Extending to other models and datasets

Currently, the code only supports Mistral 3 models and the datasets listed above. It should also work as is on base models, but note that it formats data with an instruct template.

The code can easily be adapted to other models and datasets.

  • The encoders, decoders and kernels submodules do not depend on any other component (config, logger, ...) and can thus be imported and used as is. As their names suggest, they contain the different implementations (Triton kernel, autograd function and wrapper) of the SAE encoder and decoder. For more details, please consult the code and the components documentation.
  • To add a new dataset, create a subclass of TextDataset in the data module and implement parse_entry(self, entry), which must return a list of conversations (each conversation is a list of Mistral ChatMessageType messages). In __init__, call super().__init__(path, split, prob) to register the Hugging Face dataset path, split, and sampling probability used by TextDatasetsManager. Loading, shuffling/reloading, and conversion to ChatCompletionRequest are already handled by the base class. For concrete patterns, see DS_PDBooks, DS_Claire, and DS_Tulu in text_datasets.py.
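The pattern described in the last bullet can be sketched as follows. To keep the example self-contained, TextDataset, UserMessage and AssistantMessage are minimal hypothetical stand-ins: in the project you would import TextDataset from the data module and build Mistral message objects instead, and the real base class also handles loading and shuffling. DS_MyQA, its Hugging Face path and its question/answer columns are made up for illustration, and returning an empty list for malformed rows is an assumption about what the base class tolerates.

```python
from dataclasses import dataclass
from typing import Any


# Hypothetical stand-ins for the Mistral message types.
@dataclass
class UserMessage:
    content: str


@dataclass
class AssistantMessage:
    content: str


class TextDataset:
    """Stand-in base class: only records what super().__init__ registers."""

    def __init__(self, path: str, split: str, prob: float) -> None:
        self.path = path
        self.split = split
        self.prob = prob

    def parse_entry(self, entry: dict[str, Any]) -> list[list[Any]]:
        raise NotImplementedError


class DS_MyQA(TextDataset):
    """Hypothetical dataset whose rows look like {"question": ..., "answer": ...}."""

    def __init__(self, prob: float = 0.1) -> None:
        # Hypothetical Hugging Face path, split and sampling probability.
        super().__init__("my-org/my-qa-dataset", "train", prob)

    def parse_entry(self, entry: dict[str, Any]) -> list[list[Any]]:
        question = (entry.get("question") or "").strip()
        answer = (entry.get("answer") or "").strip()
        if not question or not answer:
            return []  # malformed row: produce no conversation
        # One conversation made of one user turn and one assistant turn.
        return [[UserMessage(content=question), AssistantMessage(content=answer)]]
```

Since parse_entry returns a list of conversations, a single row can also yield several conversations, for example one per question in a multi-question entry.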

Benchmark and test results

Benchmark and test results can be found in benchmarks.md. The default config is based on the best-performing settings found for training a "dynamic" (classic ReLU) SAE on an RTX PRO 6000 Blackwell GPU.


Planned Features

Currently, the TopK implementation is not complete (the auxiliary loss is not computed as described in the OpenAI paper). I plan to work on that (and on dead latent resampling) when I have time, along with new implementations for JumpReLU and Gated SAEs. I may also add support for more models and datasets.
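For comparison with the classic ReLU SAE, a TopK SAE replaces the ReLU plus L1 penalty with a hard selection of the k largest latents per token. The NumPy sketch below shows only that activation, in one common formulation; it is not the project's implementation and does not include the auxiliary loss mentioned above.

```python
import numpy as np


def topk_activation(pre: np.ndarray, k: int) -> np.ndarray:
    """Keep the k largest ReLU'd pre-activations in each row, zero the rest."""
    z = np.maximum(pre, 0.0)
    # argpartition moves the indices of the k largest entries to the last k columns.
    idx = np.argpartition(z, -k, axis=-1)[..., -k:]
    out = np.zeros_like(z)
    np.put_along_axis(out, idx, np.take_along_axis(z, idx, axis=-1), axis=-1)
    return out
```

Sparsity is then fixed by k rather than tuned through an L1 coefficient, which is one of the main reasons TopK SAEs are attractive.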

Acknowledgments

This project is inspired by and benefits from the following works:
