- Overview
- Resolution Order
- Configuration Reference
- Text Data
- Language Model
- SAE
- Mixed Precision
- Optimizer
- Kernel Tuning
- Compilation
- Weights & Biases
- App Server
- Coverage Workloads
- Environment Variables
- Removed Environment Variables (Migration)
- Operational Notes
This document lists every runtime configuration field and all project environment variables used by MistralSAE.
- Resolve the config path from `MSAE_CONFIG_PATH` if set, otherwise `./msae-config.toml`.
- If the file exists, load it into `MSAEConfig`.
- If missing and `MSAE_CREATE_CONFIG_IF_MISSING=true`, generate the file from defaults.
- If missing and auto-create is disabled, run with in-memory defaults.
- Runtime overrides can be applied via `configure(...)` or `with configuration(...):`.
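
The resolution order above can be sketched in a few lines of Python. This is only an illustration, not the project's loader: `resolve_config_source` is a hypothetical helper, and treating only the case-insensitive string `true` as enabling auto-create is an assumption.

```python
import os
from pathlib import Path


def resolve_config_source(env=None):
    """Return (path, action), where action is 'load', 'create' or 'defaults'."""
    env = os.environ if env is None else env

    # Step 1: MSAE_CONFIG_PATH wins when set, otherwise use the default path.
    path = Path(env.get("MSAE_CONFIG_PATH") or "./msae-config.toml")

    # Step 2: an existing file is loaded.
    if path.is_file():
        return path, "load"

    # Steps 3 and 4: a missing file is either generated from defaults or
    # replaced by in-memory defaults (the flag parsing here is an assumption).
    flag = env.get("MSAE_CREATE_CONFIG_IF_MISSING", "false").strip().lower()
    action = "create" if flag == "true" else "defaults"
    return path, action
```

Passing a plain dictionary as `env` keeps the sketch easy to exercise without touching the real process environment.
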
| Field | Values | Default | Description | Details |
|---|---|---|---|---|
| `shuffle_seed` | `int` | `42` | Seed for shuffling when loading a training dataset. | |
| `reload_shuffle_seed` | `int` | `58` | Seed for shuffling when reloading a training dataset. | |
| `continue_final_message` | `bool` | `true` | Whether to tell the model to continue the final message (setting it to false adds instruct tokens). | |
| `P_PDBOOKS` | `float` | `0.97` | Probability of selecting the PDBooks dataset. | |
| `P_CLAIRE` | `float` | `0.01` | Probability of selecting the Claire dataset. | |
| `P_TULU` | `float` | `0.02` | Probability of selecting the Tulu dataset. | |
| `CHUNK_LEN_PDBOOKS` | `int` | `2048` | Chunking length (see `common/_misc.py`). | |
| `CHUNK_LEN_CLAIRE` | `int` | `2048` | | |

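
As an illustration of what the three selection probabilities mean, the sketch below draws dataset names with the default weights. How MistralSAE actually samples datasets is not described here, so `pick_dataset`, the lowercase dataset keys, and the seed are all assumptions made for the example.

```python
import random

# Default mixing probabilities (P_PDBOOKS, P_CLAIRE, P_TULU).
DATASET_PROBS = {"pdbooks": 0.97, "claire": 0.01, "tulu": 0.02}


def pick_dataset(rng):
    """Draw one dataset name according to the configured probabilities."""
    names = list(DATASET_PROBS)
    weights = list(DATASET_PROBS.values())
    return rng.choices(names, weights=weights, k=1)[0]


rng = random.Random(0)  # arbitrary seed, used only to make the demo repeatable
counts = {name: 0 for name in DATASET_PROBS}
for _ in range(10_000):
    counts[pick_dataset(rng)] += 1
```

With these defaults, almost all draws land on PDBooks, so changing `P_CLAIRE` or `P_TULU` should be balanced by an equal change elsewhere to keep the three values summing to 1.
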
| Field | Values | Default | Description | Details |
|---|---|---|---|---|
| `model_id` | `str` | `"mistralai/Mistral-Small-3.2-24B-Instruct-2506"` | Model ID on Hugging Face. | |
| `dtype` | `str` | `"bfloat16"` | Torch dtype for the model. | |
| `device` | `"cuda" \| "cpu"` | `"cuda"` | Device; only `"cuda"` and `"cpu"` are supported. | |
| `non_blocking` | `bool` | `true` | Whether to use non-blocking transfers to device after tokenization. | |
| `context_mode` | `"no_grad" \| "inference_mode"` | `"inference_mode"` | Autograd context used for inference/benchmarking (`'no_grad'` or `'inference_mode'`). | |
| `attention_implementation` | `"sdpa" \| "flash_attention_2"` | `"flash_attention_2"` | Attention backend to use. Currently only `"sdpa"` and `"flash_attention_2"` are supported. | |
| `batch_size` | `int` | `16` | Batch size (number of sequences) for LLM inference. | |
| `seq_token_len` | `int` | `256` | Strict sequence length (in tokens) for LLM inference when compiled. | |
| `sae_layer` | `int` | `30` | Transformer layer after which to place the SAE. For training, only the layers up to this one are loaded and used. | |

| Field | Values | Default | Description | Details |
|---|---|---|---|---|
| `device` | `"cuda" \| "cpu"` | `"cuda"` | Device used for SAE training. | |
| `batch_size` | `int` | `2048` | Target SAE batch size in tokens for training. | |
| `num_steps` | `int` | `600000` | Planned number of SAE optimization steps. | |
| `base_lr` | `float` | `5e-05` | Base learning rate for SAE training. | |
| `grad_clip` | `float` | `1.0` | Gradient clipping max norm for SAE training. | |
| `log_every` | `int` | `50` | Trainer logging period (in optimization steps). | |
| `more_metrics_every` | `int` | `50` | Period (in optimization steps) for heavy diagnostic metrics. `0` disables heavy metrics; if > 0, it must be a multiple of `log_every` and >= `log_every`. | When disabled, heavy diagnostics are neither computed nor logged. |
| `input_dtype` | `str` | `"bfloat16"` | Input activation dtype for SAE training batches. | |
| `latent_multiplier` | `float` | `56.0` | Multiplier used to infer SAE latent dim from input dim. | |
| `latent_dimension` | `int \| null` | `null` | If set, overrides `latent_multiplier`. | |
| `encoder_type` | `"dynamic" \| "topk"` | `"dynamic"` | SAE encoder type (`'dynamic'` or `'topk'`). | |
| `encoder_implementation` | `"gather" \| "scatter" \| "torch"` | `"gather"` | Encoder implementation (`'torch'`, `'gather'`, `'scatter'`). | |
| `decoder_implementation` | `"torch" \| "bag" \| "gather"` | `"gather"` | Decoder implementation (`'torch'`, `'bag'`, `'gather'`). | |
| `topk` | `int` | `64` | Top-k value when `encoder_type='topk'`. | |
| `topk_aux` | `int` | `0` | Auxiliary top-k branch size (currently not fully implemented). | |
| `topk_apply_relu` | `bool` | `true` | Whether to apply ReLU on selected TopK activations. | |
| `tied_init` | `bool` | `true` | Whether to apply tied initialization when building SAE from config. | |
| `start_with_torch` | `bool` | `true` | If true, start training with torch/torch implementations and switch later. | |
| `switch_to_config_impl_after_step` | `int \| null` | `50000` | Step at which to switch from torch/torch to configured implementations. | Used only when `start_with_torch = true`. |
| `nnz_approx_threshold` | `float` | `1e-08` | Threshold used to compute the `nnz_approx` metric. | |
| `lambda_value` | `float` | `5.0` | Weight for the sparsity term in SAE loss. | |
| `lambda_schedule_enabled` | `bool` | `true` | Enable linear schedule for lambda. | |
| `lambda_schedule_start_step` | `int` | `0` | Start step for lambda linear schedule. | |
| `lambda_schedule_end_step` | `int` | `30000` | End step for lambda linear schedule. | |
| `lambda_schedule_start_value` | `float` | `0.0` | Start lambda value for linear schedule. | |
| `lambda_schedule_end_value` | `float` | `5.0` | End lambda value for linear schedule. | |
| `lr_schedule_enabled` | `bool` | `true` | Enable linear learning-rate scale schedule. | |
| `lr_schedule_start_step` | `int` | `480000` | Start step for LR scale linear schedule. | |
| `lr_schedule_end_step` | `int` | `600000` | End step for LR scale linear schedule. | |
| `lr_schedule_start_scale` | `float` | `1.0` | Start LR scale for linear schedule. | |
| `lr_schedule_end_scale` | `float` | `0.2` | End LR scale for linear schedule. | |
| `dead_latent_resampling_enabled` | `bool` | `false` | Enable dead latent resampling and `num_dead_latents` logging. | |
| `dead_latent_resample_every_steps` | `int` | `0` | Resample period in steps. | |
| `dead_latent_resample_min_step` | `int` | `1` | Minimum step before enabling resampling. | |
| `dead_latent_activity_ema_decay` | `float` | `0.99` | EMA decay for latent activity. | |
| `dead_latent_activity_threshold` | `float` | `1e-06` | Dead-latent threshold on activity EMA. | |
| `dead_latent_fallback_decoder_row_norm` | `float` | `0.1` | Fallback decoder row norm when all latents are dead during resampling. | |
| `checkpoint_enabled` | `bool` | `true` | Enable SAE checkpointing. | |
| `checkpoint_directory` | `str` | `"checkpoints/sae"` | Directory for SAE checkpoints. | |
| `checkpoint_save_every_steps` | `int` | `10000` | Checkpoint save period in steps. | |
| `checkpoint_save_final` | `bool` | `true` | Save final checkpoint at training end. | |
| `checkpoint_keep_last_n` | `int \| null` | `2` | If set, keep only the latest N checkpoints. | |
| `checkpoint_resume` | `bool` | `true` | Resume from latest checkpoint by default. | |
| `checkpoint_resume_path` | `str \| null` | `null` | If set, resume specifically from this checkpoint path. | When set, this has priority over automatic latest-checkpoint resume. |
| `checkpoint_save_optimizer_state` | `bool` | `true` | Persist optimizer state into checkpoints. | |
| `checkpoint_load_optimizer_state` | `bool` | `true` | Restore optimizer state from checkpoints. | |
| `checkpoint_strict_shapes` | `bool` | `true` | Enforce strict shape checks when loading checkpoints. | |

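
The lambda and learning-rate schedule fields in the table above both describe a linear ramp between a start and an end step. A minimal sketch of such a ramp follows, using the default values. Holding the value constant before the start step and after the end step is an assumption, and `linear_schedule`, `sparsity_weight`, and `learning_rate` are hypothetical names rather than trainer APIs.

```python
def linear_schedule(step, start_step, end_step, start_value, end_value):
    """Linearly interpolate between two values, clamped outside the ramp."""
    if step <= start_step:
        return start_value
    if step >= end_step:
        return end_value
    frac = (step - start_step) / (end_step - start_step)
    return start_value + frac * (end_value - start_value)


def sparsity_weight(step):
    # Defaults: lambda ramps 0.0 -> 5.0 over steps 0..30000.
    return linear_schedule(step, 0, 30_000, 0.0, 5.0)


def learning_rate(step, base_lr=5e-05):
    # Defaults: the LR scale decays 1.0 -> 0.2 over steps 480000..600000.
    return base_lr * linear_schedule(step, 480_000, 600_000, 1.0, 0.2)
```

Under these defaults the sparsity weight reaches its full value of 5.0 at step 30000, and the learning rate stays at `base_lr` until step 480000 before decaying to one fifth of it by the final step.
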
| Field | Values | Default | Description | Details |
|---|---|---|---|---|
| `master_weights_dtype` | `str` | `"bfloat16"` | Master SAE parameter dtype used for trainable weights. | |
| `forward_autocast_enabled` | `bool` | `false` | Enable autocast during SAE forward and loss computation. | |
| `forward_autocast_dtype` | `str` | `"bfloat16"` | Autocast dtype for SAE forward and loss computation. | |
| `backward_autocast_enabled` | `bool` | `false` | Enable autocast during backward. Defaults to false for FP32 gradient accumulation. | |
| `gradient_dtype` | `str` | `"bfloat16"` | Gradient dtype expected after backward. | |

| Field | Values | Default | Description | Details |
|---|---|---|---|---|
| `optimizer_name` | `"adamw" \| "adam"` | `"adamw"` | Optimizer used by SAE trainer. | |
| `moments_dtype` | `str` | `"bfloat16"` | AdamW optimizer moments dtype target for SAE training. | |
| `adamw_beta1` | `float` | `0.9` | AdamW beta1. | |
| `adamw_beta2` | `float` | `0.999` | AdamW beta2. | |
| `adamw_epsilon` | `float` | `1e-15` | AdamW epsilon. | |
| `adamw_weight_decay` | `float` | `0.0` | AdamW weight decay. | |

| Field | Values | Default | Description | Details |
|---|---|---|---|---|
| `autotune_enabled` | `bool` | `true` | Enable Triton autotune for sparse kernels. | |
| `gather_block_tokens` | `int` | `8` | `BLOCK_TOKENS` for encoder gather kernel when autotune is disabled. | |
| `gather_block_d` | `int` | `512` | `BLOCK_D` for encoder gather kernel when autotune is disabled. | |
| `gather_num_warps` | `int` | `1` | `num_warps` for encoder gather kernel when autotune is disabled. | |
| `gather_num_stages` | `int` | `4` | `num_stages` for encoder gather kernel when autotune is disabled. | |
| `gather_norm_block_latents` | `int` | `8` | `BLOCK_LATENTS` for decoder gather kernel when autotune is disabled. | |
| `gather_norm_block_d` | `int` | `128` | `BLOCK_D` for decoder gather kernel when autotune is disabled. | |
| `gather_norm_num_warps` | `int` | `1` | `num_warps` for decoder gather kernel when autotune is disabled. | |
| `gather_norm_num_stages` | `int` | `3` | `num_stages` for decoder gather kernel when autotune is disabled. | |
| `scatter_block_d` | `int` | `128` | `BLOCK_D` for encoder scatter kernel when autotune is disabled. | |
| `scatter_num_warps` | `int` | `4` | `num_warps` for encoder scatter kernel when autotune is disabled. | |
| `scatter_num_stages` | `int` | `4` | `num_stages` for encoder scatter kernel when autotune is disabled. | |

| Field | Values | Default | Description | Details |
|---|---|---|---|---|
| `float32_matmul_precision` | `"highest" \| "high" \| "medium"` | `"high"` | Torch float32 matmul precision setting (`'highest'`, `'high'`, or `'medium'`). | |
| `llm_compile_variant` | `"none" \| "inductor_default" \| "inductor_reduce_overhead" \| "inductor_max_autotune" \| "inductor_max_autotune_cudagraphs"` | `"inductor_max_autotune"` | `torch.compile` preset for the LLM. Can be `"none"`, `"inductor_default"`, `"inductor_reduce_overhead"`, `"inductor_max_autotune"` or `"inductor_max_autotune_cudagraphs"`. | |
| `llm_fullgraph` | `bool` | `false` | LLM `torch.compile` fullgraph flag. Must remain false with `attention_implementation='flash_attention_2'`. | |
| `sae_compile_variant` | `"none" \| "inductor_default" \| "inductor_reduce_overhead" \| "inductor_max_autotune" \| "inductor_max_autotune_cudagraphs"` | `"inductor_max_autotune"` | `torch.compile` preset for SAE training/inference. Same accepted values as `llm_compile_variant`. | |

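
The constraint between `llm_fullgraph` and `attention_implementation` can be checked before launching a run. The function below is a hypothetical pre-flight check written for illustration; it is not part of MistralSAE, and the project may validate these settings differently.

```python
COMPILE_VARIANTS = {
    "none",
    "inductor_default",
    "inductor_reduce_overhead",
    "inductor_max_autotune",
    "inductor_max_autotune_cudagraphs",
}


def check_compile_settings(llm_compile_variant, llm_fullgraph, attention_implementation):
    """Return a list of problems found; an empty list means the settings look consistent."""
    problems = []
    if llm_compile_variant not in COMPILE_VARIANTS:
        problems.append(f"unknown compile variant: {llm_compile_variant!r}")
    # Documented rule: fullgraph must stay false with flash_attention_2.
    if llm_fullgraph and attention_implementation == "flash_attention_2":
        problems.append("llm_fullgraph must be false with flash_attention_2")
    return problems
```
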
| Field | Values | Default | Description | Details |
|---|---|---|---|---|
| `enabled` | `bool` | `true` | Enable Weights & Biases logging. | |
| `project` | `str` | `"MistralSAE"` | W&B project (used when enabled). | |
| `run_name` | `str \| null` | `null` | W&B run name. If null, a name is auto-generated. | |
| `entity` | `str \| null` | `null` | Optional W&B entity. | |
| `group` | `str \| null` | `null` | Optional W&B group. | |
| `tags` | `list[str] \| null` | `null` | Optional W&B tags. | |
| `mode` | `str \| null` | `null` | Optional W&B mode (e.g. `'offline'`). | |

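
One plausible way these fields could map onto `wandb.init` is sketched below. The mapping is an assumption for illustration: `wandb_init_kwargs` is a hypothetical helper, and the sketch only builds the keyword dictionary (`project`, `name`, `entity`, `group`, `tags`, `mode` are real `wandb.init` parameters) without importing `wandb`.

```python
def wandb_init_kwargs(cfg):
    """Return keyword arguments for wandb.init, or None when logging is disabled."""
    if not cfg.get("enabled", True):
        return None

    kwargs = {"project": cfg.get("project", "MistralSAE")}

    # Config field -> wandb.init parameter; run_name is passed as `name`.
    mapping = {
        "run_name": "name",
        "entity": "entity",
        "group": "group",
        "tags": "tags",
        "mode": "mode",
    }
    for field, param in mapping.items():
        value = cfg.get(field)
        if value is not None:  # null fields are left to the W&B defaults
            kwargs[param] = value
    return kwargs
```

Leaving null fields out of the call, rather than passing `None` explicitly, keeps the behavior aligned with whatever defaults the installed W&B client applies.
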
Settings for the optional web UI used to chat with a steered model.

| Field | Values | Default | Description | Details |
|---|---|---|---|---|
| `host` | `str` | `"127.0.0.1"` | Host/IP used by the steering web app server. | |
| `port` | `int` | `8000` | Port used by the steering web app server. | |
| `default_feature_id` | `int` | `199290` | Default steering SAE feature index in the web UI. | Runtime requests can override this value from the interface. |
| `clamp_scale` | `float` | `3.0` | Global clamp scale multiplier applied to user steering value. | |
| `max_new_tokens` | `int` | `256` | Maximum number of generated tokens per response. | |

| Field | Values | Default | Description | Details |
|---|---|---|---|---|
| `force_config` | `bool` | `false` | If true, ignore test-provided workload overrides and run coverage with config-only values. | Disables test workload kwargs overrides. |
| `num_iters` | `int` | `30` | Default number of measured iterations for coverage benchmarks. | |
| `num_warmup_iters` | `int` | `3` | Default number of warmup iterations for coverage benchmarks. | |
| `max_new_tokens` | `int` | `3` | Default `max_new_tokens` for causal coverage tests. | |
| `stream_dataset` | `bool` | `true` | Whether coverage LM batch generation streams datasets. | |
| `leave_tqdm` | `bool` | `true` | Whether progress bars are shown during coverage benches. | |
| `cleanup_memory` | `bool` | `true` | Whether coverage workloads aggressively clean up CUDA memory. | |
| `llm_seq_lens` | `list[int]` | `[4096, 2048, 1024, 512, 256]` | Default sequence lengths for seq/batch coverage bench. | |
| `llm_batch_sizes` | `list[int]` | `[32, 16, 8, 4, 2, 1]` | Default batch sizes for seq/batch coverage bench. | |
| `compile_contexts` | `list["no_grad" \| "inference_mode"]` | `['no_grad', 'inference_mode']` | Default contexts for LLM compile coverage bench. | |
| `compile_attn_implementations` | `list["sdpa" \| "flash_attention_2"]` | `['sdpa', 'flash_attention_2']` | Default attention implementations for LLM compile coverage bench. | |
| `compile_variants` | `list["none" \| "inductor_default" \| "inductor_reduce_overhead" \| "inductor_max_autotune" \| "inductor_max_autotune_cudagraphs"]` | `['none', 'inductor_default', 'inductor_reduce_overhead', 'inductor_max_autotune', 'inductor_max_autotune_cudagraphs']` | Default compile variants for compile coverage benches. | |
| `compile_fullgraph_values` | `list[bool]` | `[False, True]` | Default fullgraph sweep values for LLM compile coverage bench. | Applied only to compile variants different from `none`. |
| `dynamic_parity_d_in` | `int` | `64` | Default `d_in` for dynamic encoder parity. | |
| `dynamic_parity_d_out` | `int` | `512` | Default `d_out` for dynamic encoder parity. | |
| `dynamic_parity_batch_size` | `int` | `64` | Default batch size for dynamic encoder parity. | |
| `dynamic_parity_target_nnz_per_sample` | `int` | `64` | Default target nnz/sample for dynamic encoder parity. | |
| `topk_parity_d_in` | `int` | `64` | Default `d_in` for topk encoder parity. | |
| `topk_parity_d_out` | `int` | `512` | Default `d_out` for topk encoder parity. | |
| `topk_parity_k` | `int` | `16` | Default `k` for topk encoder parity. | |
| `topk_parity_batch_size` | `int` | `64` | Default batch size for topk encoder parity. | |
| `encoder_bench_d_in` | `int` | `5120` | Default `d_in` for encoder benches. | |
| `encoder_bench_d_out` | `int` | `286720` | Default `d_out` for encoder benches. | |
| `encoder_bench_k` | `int` | `64` | Default `k` for topk encoder bench. | |
| `encoder_bench_batch_size` | `int` | `4096` | Default batch size for encoder benches. | |
| `decoder_parity_d_in` | `int` | `128` | Default `d_in` for decoder parity. | |
| `decoder_parity_d_out` | `int` | `512` | Default `d_out` for decoder parity. | |
| `decoder_parity_batch_size` | `int` | `128` | Default batch size for decoder parity. | |
| `decoder_parity_points_per_sample` | `int` | `32` | Default `points_per_sample` for decoder parity. | |
| `decoder_bench_d_in` | `int` | `286720` | Default `d_in` for decoder benches. | |
| `decoder_bench_d_out` | `int` | `5120` | Default `d_out` for decoder benches. | |
| `decoder_bench_batch_size` | `int` | `4096` | Default batch size for decoder benches. | |
| `decoder_bench_points_per_sample` | `int` | `64` | Default `points_per_sample` for decoder benches. | |
| `sae_parity_d_in` | `int` | `64` | Default `d_in` for SAE parity. | |
| `sae_parity_d_hidden` | `int` | `512` | Default `d_hidden` for SAE parity. | |
| `sae_parity_batch_size` | `int` | `64` | Default batch size for SAE parity. | |
| `sae_parity_k` | `int` | `16` | Default `k` for SAE topk parity. | |
| `sae_parity_topk_aux` | `int` | `0` | Default `topk_aux` for SAE parity. | |
| `sae_parity_topk_apply_relu` | `bool` | `true` | Default `topk_apply_relu` for SAE parity. | |
| `sae_bench_d_in` | `int` | `5120` | Default `d_in` for SAE benches. | |
| `sae_bench_d_hidden` | `int` | `286720` | Default `d_hidden` for SAE benches. | |
| `sae_bench_batch_size` | `int` | `4096` | Default batch size for SAE benches. | |
| `sae_bench_k` | `int` | `64` | Default `k` for SAE topk bench. | |
| `sae_bench_topk_aux` | `int` | `0` | Default `topk_aux` for SAE bench. | |
| `sae_bench_topk_apply_relu` | `bool` | `true` | Default `topk_apply_relu` for SAE bench. | |
| `sae_compile_d_in` | `int` | `5120` | Default `d_in` for SAE compile bench. | |
| `sae_compile_d_hidden` | `int` | `286720` | Default `d_hidden` for SAE compile bench. | |
| `sae_compile_batch_size` | `int` | `4096` | Default batch size for SAE compile bench. | |
| `sae_compile_k` | `int` | `64` | Default `k` for SAE compile bench. | |
| `sae_compile_topk_aux` | `int` | `0` | Default `topk_aux` for SAE compile bench. | |
| `sae_compile_topk_apply_relu` | `bool` | `true` | Default `topk_apply_relu` for SAE compile bench. | |
| `sae_compile_encoder_type` | `"dynamic" \| "topk"` | `"dynamic"` | Default SAE encoder type for compile bench. | |
| `sae_compile_encoder_implementation` | `"gather" \| "scatter" \| "torch"` | `"gather"` | Default SAE encoder implementation for compile bench. | |
| `sae_compile_decoder_implementation` | `"torch" \| "bag" \| "gather"` | `"gather"` | Default SAE decoder implementation for compile bench. | |
| `parity_max_abs_gate` | `float` | `0.001` | Default max-abs gate for parity checks. | |
| `parity_mean_abs_gate` | `float` | `1e-05` | Default mean-abs gate for parity checks. | |
| `parity_seed` | `int` | `42` | Default seed for parity/bench coverage workloads. | |
| `parity_dtype` | `str` | `"float32"` | Default dtype for coverage parity workloads. | |
| `parity_master_weights_dtype` | `str` | `"float32"` | Default master weights dtype for parity workloads. | |
| `parity_forward_autocast_enabled` | `bool` | `false` | Default forward autocast for parity workloads. | |
| `parity_forward_autocast_dtype` | `str` | `"bfloat16"` | Default forward autocast dtype for parity workloads. | |
| `bench_dtype` | `str` | `"bfloat16"` | Default dtype for coverage bench workloads. | |
| `bench_master_weights_dtype` | `str` | `"float32"` | Default master weights dtype for bench workloads. | |
| `bench_forward_autocast_enabled` | `bool` | `true` | Default forward autocast for bench workloads. | |
| `bench_forward_autocast_dtype` | `str` | `"bfloat16"` | Default forward autocast dtype for bench workloads. | |

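
To make the two parity gates concrete, the sketch below compares a reference output with a candidate output element by element. It assumes both gates must pass and that the comparison is inclusive; `parity_check` is a hypothetical helper operating on plain Python floats, not the project's test code.

```python
def parity_check(reference, candidate, max_abs_gate=0.001, mean_abs_gate=1e-05):
    """Compare two equal-length sequences of floats against both gates."""
    if len(reference) != len(candidate) or not reference:
        raise ValueError("inputs must be non-empty and of equal length")

    diffs = [abs(r - c) for r, c in zip(reference, candidate)]
    max_abs = max(diffs)                # worst single-element error
    mean_abs = sum(diffs) / len(diffs)  # average error over all elements

    passed = max_abs <= max_abs_gate and mean_abs <= mean_abs_gate
    return passed, max_abs, mean_abs
```

The max-abs gate catches isolated outliers, while the much tighter mean-abs gate catches a small error spread across every element, which is why a uniform offset of 0.0005 would pass the first gate but fail the second.
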
| Variable | Default | Description |
|---|---|---|
| `MSAE_CONFIG_PATH` | `null` | Path to the TOML configuration file. |
| `MSAE_ENABLE_CONFIG_LOGGING` | `true` | Enables/disables config loading logs. |
| `MSAE_CREATE_CONFIG_IF_MISSING` | `false` | Auto-create a config file if it does not exist. |
| `MSAE_HF_SAE_REPO` | `"Codcordance/Mistral-Small-3.2-24B-Instruct-2506-SAE-tied"` | Default Hugging Face repository used by `mistral_sae` checkpoint download when `--repo-id` is omitted. |

File logging through environment variables was removed from the runtime and CLI. Legacy variables such as
`MSAE_ENABLE_FILE_LOGGING`, `MSAE_LOG_FILE`, and `MSAE_LOG_DIR` are no longer read and have no effect.
Training and runtime logs are emitted to stdout/stderr, and W&B logging remains controlled by `[wandb]`.

- `CONFIG` is immutable at attribute level; use `configure(...)` or `configuration(...)` for overrides.
- In coverage mode, `coverage.force_config=true` enforces config-only workload parameters.
- For reproducibility, archive both `msae-config.toml` and exported `MSAE_*` variables with benchmark artifacts.
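
The reproducibility note can be automated with a small helper. The sketch below is one possible approach and is not part of MistralSAE: `archive_run_config` and the `msae-env.json` file name are hypothetical, and it uses only the standard library.

```python
import json
import os
import shutil
from pathlib import Path


def archive_run_config(out_dir, env=None, config_path="msae-config.toml"):
    """Copy the TOML config (if present) and dump MSAE_* variables to JSON."""
    env = os.environ if env is None else env
    out = Path(out_dir)
    out.mkdir(parents=True, exist_ok=True)

    # Honor MSAE_CONFIG_PATH when it points somewhere other than the default.
    src = Path(env.get("MSAE_CONFIG_PATH") or config_path)
    if src.is_file():
        shutil.copy2(src, out / src.name)

    # Keep only project variables, sorted so the snapshot is stable.
    msae_vars = {k: v for k, v in sorted(env.items()) if k.startswith("MSAE_")}
    (out / "msae-env.json").write_text(json.dumps(msae_vars, indent=2), encoding="utf-8")
    return msae_vars
```

Calling `archive_run_config("artifacts/run-001")` next to the benchmark outputs would store both pieces of configuration in one place; the directory name here is only an example.
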