Configuration

Overview

This document lists every runtime configuration field and all project environment variables used by MistralSAE.

Resolution Order

  1. Resolve config path from MSAE_CONFIG_PATH if set, otherwise ./msae-config.toml.
  2. If the file exists, load it into MSAEConfig.
  3. If missing and MSAE_CREATE_CONFIG_IF_MISSING=true, generate the file from defaults.
  4. If missing and auto-create is disabled, run with in-memory defaults.
  5. Apply runtime overrides with configure(...) or inside a with configuration(...): block.
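Steps 1 to 4 can be sketched as a small decision function. This is only an illustration of the order described above, not the project's loader: resolve_config_source, _env_flag, the returned action strings, and the accepted truthy spellings are all assumptions.

```python
import os
from pathlib import Path

DEFAULT_CONFIG_NAME = "msae-config.toml"


def _env_flag(env, name, default=False):
    """Interpret an environment value as a boolean (truthy spellings are assumed)."""
    raw = env.get(name)
    if raw is None:
        return default
    return raw.strip().lower() in {"1", "true", "yes", "on"}


def resolve_config_source(env=None, cwd=None):
    """Return (path, action) for steps 1-4 of the resolution order.

    action is "load", "create_from_defaults", or "in_memory_defaults"
    (hypothetical labels used only in this sketch).
    """
    env = os.environ if env is None else env
    cwd = Path.cwd() if cwd is None else Path(cwd)

    # Step 1: an explicit path wins, otherwise fall back to ./msae-config.toml.
    raw_path = env.get("MSAE_CONFIG_PATH")
    path = Path(raw_path) if raw_path else cwd / DEFAULT_CONFIG_NAME

    # Step 2: an existing file is loaded.
    if path.is_file():
        return path, "load"
    # Step 3: missing file and auto-create enabled -> generate it from defaults.
    if _env_flag(env, "MSAE_CREATE_CONFIG_IF_MISSING"):
        return path, "create_from_defaults"
    # Step 4: missing file and auto-create disabled -> defaults stay in memory.
    return path, "in_memory_defaults"
```

Passing env and cwd explicitly keeps the function easy to test without touching the real process environment.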

Configuration Reference

Text Data ([textdata])

Field Values Default Description Details
shuffle_seed int 42 Seed for shuffling when loading a training dataset.
reload_shuffle_seed int 58 Seed for shuffling when reloading a training dataset.
continue_final_message bool true Whether to tell the model to continue the final message (setting this to false adds instruct tokens).
P_PDBOOKS float 0.97 Probability of selecting the PDBooks dataset.
P_CLAIRE float 0.01 Probability of selecting the Claire dataset.
P_TULU float 0.02 Probability of selecting the Tulu dataset.
CHUNK_LEN_PDBOOKS int 2048 Chunking length for the PDBooks dataset (see common/_misc.py).
CHUNK_LEN_CLAIRE int 2048 Chunking length for the Claire dataset (see common/_misc.py).
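How the three selection probabilities are consumed is not spelled out in this table. One plausible reading is an independent weighted draw per sample. The sketch below assumes that reading; the lowercase dataset keys and pick_dataset are hypothetical names.

```python
import random

# Default probabilities from the [textdata] table (keys are illustrative).
DATASET_PROBS = {"pdbooks": 0.97, "claire": 0.01, "tulu": 0.02}


def pick_dataset(rng, probs=DATASET_PROBS):
    """Draw one dataset name according to the configured probabilities."""
    total = sum(probs.values())
    if abs(total - 1.0) > 1e-9:
        raise ValueError(f"dataset probabilities must sum to 1, got {total}")
    names = list(probs)
    return rng.choices(names, weights=[probs[n] for n in names], k=1)[0]


rng = random.Random(42)  # 42 mirrors the shuffle_seed default
counts = {name: 0 for name in DATASET_PROBS}
for _ in range(10_000):
    counts[pick_dataset(rng)] += 1
```

Over many draws, counts should land close to the configured 97 / 1 / 2 split. The sum check catches a config edit that changes one probability without adjusting the others.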

Language Model ([llm])

Field Values Default Description Details
model_id str "mistralai/Mistral-Small-3.2-24B-Instruct-2506" Model ID on Hugging Face.
dtype str "bfloat16" torch dtype for the model.
device "cuda" | "cpu" "cuda" Device for the LLM. Only "cuda" and "cpu" are supported.
non_blocking bool true Whether to use non-blocking transfers to device after tokenization.
context_mode "no_grad" | "inference_mode" "inference_mode" Autograd context used for inference/benchmarking ('no_grad' or 'inference_mode').
attention_implementation "sdpa" | "flash_attention_2" "flash_attention_2" Attention backend to use. Only "sdpa" and "flash_attention_2" are currently supported.
batch_size int 16 Batch size (number of sequences) for LLM inference.
seq_token_len int 256 Strict sequence length (in tokens) for LLM inference when compiled.
sae_layer int 30 Transformer layer after which to place the SAE. For training, only the layers up to and including this one are loaded and used.

SAE ([sae])

Field Values Default Description Details
device "cuda" | "cpu" "cuda" Device used for SAE training.
batch_size int 2048 Target SAE batch size in tokens for training.
num_steps int 600000 Planned number of SAE optimization steps.
base_lr float 5e-05 Base learning rate for SAE training.
grad_clip float 1.0 Gradient clipping max norm for SAE training.
log_every int 50 Trainer logging period (in optimization steps).
more_metrics_every int 50 Period (in optimization steps) for heavy diagnostic metrics. 0 disables heavy metrics; if > 0, it must be a multiple of log_every and >= log_every. When disabled, heavy diagnostics are neither computed nor logged.
input_dtype str "bfloat16" Input activation dtype for SAE training batches.
latent_multiplier float 56.0 Multiplier used to infer SAE latent dim from input dim.
latent_dimension int | null null SAE latent dimension. If set, this overrides latent_multiplier.
encoder_type "dynamic" | "topk" "dynamic" SAE encoder type ('dynamic' or 'topk').
encoder_implementation "gather" | "scatter" | "torch" "gather" Encoder implementation ('torch', 'gather', 'scatter').
decoder_implementation "torch" | "bag" | "gather" "gather" Decoder implementation ('torch', 'bag', 'gather').
topk int 64 Top-k value when encoder_type='topk'.
topk_aux int 0 Auxiliary top-k branch size (currently not fully implemented).
topk_apply_relu bool true Whether to apply ReLU on selected TopK activations.
tied_init bool true Whether to apply tied initialization when building SAE from config.
start_with_torch bool true If true, start training with torch/torch implementations and switch later.
switch_to_config_impl_after_step int | null 50000 Step at which to switch from torch/torch to configured implementations. Used only when start_with_torch = true.
nnz_approx_threshold float 1e-08 Threshold used to compute nnz_approx metric.
lambda_value float 5.0 Weight for the sparsity term in SAE loss.
lambda_schedule_enabled bool true Enable linear schedule for lambda.
lambda_schedule_start_step int 0 Start step for lambda linear schedule.
lambda_schedule_end_step int 30000 End step for lambda linear schedule.
lambda_schedule_start_value float 0.0 Start lambda value for linear schedule.
lambda_schedule_end_value float 5.0 End lambda value for linear schedule.
lr_schedule_enabled bool true Enable linear learning-rate scale schedule.
lr_schedule_start_step int 480000 Start step for LR scale linear schedule.
lr_schedule_end_step int 600000 End step for LR scale linear schedule.
lr_schedule_start_scale float 1.0 Start LR scale for linear schedule.
lr_schedule_end_scale float 0.2 End LR scale for linear schedule.
dead_latent_resampling_enabled bool false Enable dead latent resampling and num_dead_latents logging.
dead_latent_resample_every_steps int 0 Resample period in steps.
dead_latent_resample_min_step int 1 Minimum step before enabling resampling.
dead_latent_activity_ema_decay float 0.99 EMA decay for latent activity.
dead_latent_activity_threshold float 1e-06 Dead-latent threshold on activity EMA.
dead_latent_fallback_decoder_row_norm float 0.1 Fallback decoder row norm when all latents are dead during resampling.
checkpoint_enabled bool true Enable SAE checkpointing.
checkpoint_directory str "checkpoints/sae" Directory for SAE checkpoints.
checkpoint_save_every_steps int 10000 Checkpoint save period in steps.
checkpoint_save_final bool true Save final checkpoint at training end.
checkpoint_keep_last_n int | null 2 If set, keep only the latest N checkpoints.
checkpoint_resume bool true Resume from latest checkpoint by default.
checkpoint_resume_path str | null null If set, resume specifically from this checkpoint path. When set, this has priority over automatic latest-checkpoint resume.
checkpoint_save_optimizer_state bool true Persist optimizer state into checkpoints.
checkpoint_load_optimizer_state bool true Restore optimizer state from checkpoints.
checkpoint_strict_shapes bool true Enforce strict shape checks when loading checkpoints.
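The lambda and LR-scale schedules are both described as linear between a start step and an end step. A minimal sketch of that interpolation follows, assuming the value is clamped to the start and end values outside the scheduled range (the table does not say what happens there). It also shows how latent_multiplier and latent_dimension could combine. All function names are hypothetical.

```python
def linear_schedule(step, start_step, end_step, start_value, end_value):
    """Interpolate linearly between two values, clamped outside [start_step, end_step]."""
    if end_step <= start_step:
        return end_value if step >= end_step else start_value
    frac = (step - start_step) / (end_step - start_step)
    frac = min(max(frac, 0.0), 1.0)
    return start_value + frac * (end_value - start_value)


def lambda_at(step):
    # Defaults: ramp 0.0 -> 5.0 over steps 0..30000.
    return linear_schedule(step, 0, 30_000, 0.0, 5.0)


def lr_at(step, base_lr=5e-05):
    # Defaults: scale 1.0 -> 0.2 over steps 480000..600000, applied to base_lr.
    return base_lr * linear_schedule(step, 480_000, 600_000, 1.0, 0.2)


def latent_dim(d_in, latent_multiplier=56.0, latent_dimension=None):
    # An explicit latent_dimension takes priority over the multiplier.
    if latent_dimension is not None:
        return latent_dimension
    return int(d_in * latent_multiplier)
```

With the defaults, lambda_at(15_000) is 2.5 and the learning rate stays at base_lr until step 480000. For an input dimension of 5120, latent_dim(5120) gives 286720, which matches the d_in and d_out defaults used by the encoder benches in the coverage table.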

Mixed Precision ([mixed_precision])

Field Values Default Description Details
master_weights_dtype str "bfloat16" Master SAE parameter dtype used for trainable weights.
forward_autocast_enabled bool false Enable autocast during SAE forward and loss computation.
forward_autocast_dtype str "bfloat16" Autocast dtype for SAE forward and loss computation.
backward_autocast_enabled bool false Enable autocast during backward. Defaults to false for FP32 gradient accumulation.
gradient_dtype str "bfloat16" Gradient dtype expected after backward.

Optimizer ([optimizer])

Field Values Default Description Details
optimizer_name "adamw" | "adam" "adamw" Optimizer used by SAE trainer.
moments_dtype str "bfloat16" AdamW optimizer moments dtype target for SAE training.
adamw_beta1 float 0.9 AdamW beta1.
adamw_beta2 float 0.999 AdamW beta2.
adamw_epsilon float 1e-15 AdamW epsilon.
adamw_weight_decay float 0.0 AdamW weight decay.

Kernel Tuning ([kernels])

Field Values Default Description Details
autotune_enabled bool true Enable Triton autotune for sparse kernels.
gather_block_tokens int 8 BLOCK_TOKENS for encoder gather kernel when autotune is disabled.
gather_block_d int 512 BLOCK_D for encoder gather kernel when autotune is disabled.
gather_num_warps int 1 num_warps for encoder gather kernel when autotune is disabled.
gather_num_stages int 4 num_stages for encoder gather kernel when autotune is disabled.
gather_norm_block_latents int 8 BLOCK_LATENTS for decoder gather kernel when autotune is disabled.
gather_norm_block_d int 128 BLOCK_D for decoder gather kernel when autotune is disabled.
gather_norm_num_warps int 1 num_warps for decoder gather kernel when autotune is disabled.
gather_norm_num_stages int 3 num_stages for decoder gather kernel when autotune is disabled.
scatter_block_d int 128 BLOCK_D for encoder scatter kernel when autotune is disabled.
scatter_num_warps int 4 num_warps for encoder scatter kernel when autotune is disabled.
scatter_num_stages int 4 num_stages for encoder scatter kernel when autotune is disabled.

Compilation ([compilation])

Field Values Default Description Details
float32_matmul_precision "highest" | "high" | "medium" "high" torch float32 matmul precision setting ('highest', 'high', or 'medium').
llm_compile_variant "none" | "inductor_default" | "inductor_reduce_overhead" | "inductor_max_autotune" | "inductor_max_autotune_cudagraphs" "inductor_max_autotune" torch.compile preset for the LLM. Can be "none", "inductor_default", "inductor_reduce_overhead", "inductor_max_autotune" or "inductor_max_autotune_cudagraphs".
llm_fullgraph bool false LLM torch.compile fullgraph flag. Must remain false with attention_implementation='flash_attention_2'.
sae_compile_variant "none" | "inductor_default" | "inductor_reduce_overhead" | "inductor_max_autotune" | "inductor_max_autotune_cudagraphs" "inductor_max_autotune" torch.compile preset for SAE training/inference. Same accepted values as llm_compile_variant.
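The constraint between llm_fullgraph and the attention backend can be checked before any model is built. The sketch below is an assumed validation helper, not the project's own; check_compilation and its return format are hypothetical.

```python
COMPILE_VARIANTS = (
    "none",
    "inductor_default",
    "inductor_reduce_overhead",
    "inductor_max_autotune",
    "inductor_max_autotune_cudagraphs",
)


def check_compilation(llm_compile_variant, llm_fullgraph, attention_implementation):
    """Return a list of problems; an empty list means the combination is allowed."""
    problems = []
    if llm_compile_variant not in COMPILE_VARIANTS:
        problems.append(f"unknown llm_compile_variant: {llm_compile_variant!r}")
    # Documented rule: fullgraph must stay false with flash_attention_2.
    if llm_fullgraph and attention_implementation == "flash_attention_2":
        problems.append("llm_fullgraph must be false with flash_attention_2")
    return problems
```

Returning a list rather than raising on the first issue lets a caller report every misconfiguration in one pass.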

Weights & Biases ([wandb])

Field Values Default Description Details
enabled bool true Enable Weights and Biases logging.
project str "MistralSAE" W&B project (used when enabled).
run_name str | null null W&B run name. If null, a name is auto-generated.
entity str | null null Optional W&B entity.
group str | null null Optional W&B group.
tags list[str] | null null Optional W&B tags.
mode str | null null Optional W&B mode (e.g. 'offline').

App Server ([app])

Settings for the optional web UI used to chat with a steered model.

Field Values Default Description Details
host str "127.0.0.1" Host/IP used by the steering web app server.
port int 8000 Port used by the steering web app server.
default_feature_id int 199290 Default steering SAE feature index in the web UI. Runtime requests can override this value from the interface.
clamp_scale float 3.0 Global clamp scale multiplier applied to user steering value.
max_new_tokens int 256 Maximum number of generated tokens per response.

Coverage Workloads ([coverage])

Field Values Default Description Details
force_config bool false If true, ignore test-provided workload overrides and run coverage with config-only values. Disables test workload kwargs overrides.
num_iters int 30 Default number of measured iterations for coverage benchmarks.
num_warmup_iters int 3 Default number of warmup iterations for coverage benchmarks.
max_new_tokens int 3 Default max_new_tokens for causal coverage tests.
stream_dataset bool true Whether coverage LM batch generation streams datasets.
leave_tqdm bool true Whether progress bars are shown during coverage benches.
cleanup_memory bool true Whether coverage workloads aggressively clean up CUDA memory.
llm_seq_lens list[int] [4096, 2048, 1024, 512, 256] Default sequence lengths for seq/batch coverage bench.
llm_batch_sizes list[int] [32, 16, 8, 4, 2, 1] Default batch sizes for seq/batch coverage bench.
compile_contexts list["no_grad" | "inference_mode"] ["no_grad", "inference_mode"] Default contexts for LLM compile coverage bench.
compile_attn_implementations list["sdpa" | "flash_attention_2"] ["sdpa", "flash_attention_2"] Default attention implementations for LLM compile coverage bench.
compile_variants list["none" | "inductor_default" | "inductor_reduce_overhead" | "inductor_max_autotune" | "inductor_max_autotune_cudagraphs"] ["none", "inductor_default", "inductor_reduce_overhead", "inductor_max_autotune", "inductor_max_autotune_cudagraphs"] Default compile variants for compile coverage benches.
compile_fullgraph_values list[bool] [false, true] Default fullgraph sweep values for LLM compile coverage bench. Applied only to compile variants other than "none".
dynamic_parity_d_in int 64 Default d_in for dynamic encoder parity.
dynamic_parity_d_out int 512 Default d_out for dynamic encoder parity.
dynamic_parity_batch_size int 64 Default batch size for dynamic encoder parity.
dynamic_parity_target_nnz_per_sample int 64 Default target nnz/sample for dynamic encoder parity.
topk_parity_d_in int 64 Default d_in for topk encoder parity.
topk_parity_d_out int 512 Default d_out for topk encoder parity.
topk_parity_k int 16 Default k for topk encoder parity.
topk_parity_batch_size int 64 Default batch size for topk encoder parity.
encoder_bench_d_in int 5120 Default d_in for encoder benches.
encoder_bench_d_out int 286720 Default d_out for encoder benches.
encoder_bench_k int 64 Default k for topk encoder bench.
encoder_bench_batch_size int 4096 Default batch size for encoder benches.
decoder_parity_d_in int 128 Default d_in for decoder parity.
decoder_parity_d_out int 512 Default d_out for decoder parity.
decoder_parity_batch_size int 128 Default batch size for decoder parity.
decoder_parity_points_per_sample int 32 Default points_per_sample for decoder parity.
decoder_bench_d_in int 286720 Default d_in for decoder benches.
decoder_bench_d_out int 5120 Default d_out for decoder benches.
decoder_bench_batch_size int 4096 Default batch size for decoder benches.
decoder_bench_points_per_sample int 64 Default points_per_sample for decoder benches.
sae_parity_d_in int 64 Default d_in for SAE parity.
sae_parity_d_hidden int 512 Default d_hidden for SAE parity.
sae_parity_batch_size int 64 Default batch size for SAE parity.
sae_parity_k int 16 Default k for SAE topk parity.
sae_parity_topk_aux int 0 Default topk_aux for SAE parity.
sae_parity_topk_apply_relu bool true Default topk_apply_relu for SAE parity.
sae_bench_d_in int 5120 Default d_in for SAE benches.
sae_bench_d_hidden int 286720 Default d_hidden for SAE benches.
sae_bench_batch_size int 4096 Default batch size for SAE benches.
sae_bench_k int 64 Default k for SAE topk bench.
sae_bench_topk_aux int 0 Default topk_aux for SAE bench.
sae_bench_topk_apply_relu bool true Default topk_apply_relu for SAE bench.
sae_compile_d_in int 5120 Default d_in for SAE compile bench.
sae_compile_d_hidden int 286720 Default d_hidden for SAE compile bench.
sae_compile_batch_size int 4096 Default batch size for SAE compile bench.
sae_compile_k int 64 Default k for SAE compile bench.
sae_compile_topk_aux int 0 Default topk_aux for SAE compile bench.
sae_compile_topk_apply_relu bool true Default topk_apply_relu for SAE compile bench.
sae_compile_encoder_type "dynamic" | "topk" "dynamic" Default SAE encoder type for compile bench.
sae_compile_encoder_implementation "gather" | "scatter" | "torch" "gather" Default SAE encoder implementation for compile bench.
sae_compile_decoder_implementation "torch" | "bag" | "gather" "gather" Default SAE decoder implementation for compile bench.
parity_max_abs_gate float 0.001 Default max-abs gate for parity checks.
parity_mean_abs_gate float 1e-05 Default mean-abs gate for parity checks.
parity_seed int 42 Default seed for parity/bench coverage workloads.
parity_dtype str "float32" Default dtype for coverage parity workloads.
parity_master_weights_dtype str "float32" Default master weights dtype for parity workloads.
parity_forward_autocast_enabled bool false Default forward autocast for parity workloads.
parity_forward_autocast_dtype str "bfloat16" Default forward autocast dtype for parity workloads.
bench_dtype str "bfloat16" Default dtype for coverage bench workloads.
bench_master_weights_dtype str "float32" Default master weights dtype for bench workloads.
bench_forward_autocast_enabled bool true Default forward autocast for bench workloads.
bench_forward_autocast_dtype str "bfloat16" Default forward autocast dtype for bench workloads.
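The compile coverage defaults describe a sweep over contexts, attention implementations, compile variants, and fullgraph values, with fullgraph applied only to variants other than "none". The sketch below enumerates such a sweep, assuming a full Cartesian product and using None as a placeholder where fullgraph does not apply. compile_sweep is a hypothetical name, and the sketch does not filter combinations the project may reject, such as fullgraph with flash_attention_2.

```python
from itertools import product


def compile_sweep(contexts, attn_impls, variants, fullgraph_values):
    """List (context, attention, variant, fullgraph) cases for a compile bench."""
    cases = []
    for context, attn, variant in product(contexts, attn_impls, variants):
        # fullgraph is swept only when torch.compile is actually used.
        flags = fullgraph_values if variant != "none" else [None]
        for fullgraph in flags:
            cases.append((context, attn, variant, fullgraph))
    return cases


cases = compile_sweep(
    ["no_grad", "inference_mode"],
    ["sdpa", "flash_attention_2"],
    [
        "none",
        "inductor_default",
        "inductor_reduce_overhead",
        "inductor_max_autotune",
        "inductor_max_autotune_cudagraphs",
    ],
    [False, True],
)
```

With the default lists this enumerates 2 × 2 × (1 + 4 × 2) = 36 cases, which gives a rough idea of how long a full compile coverage run can take.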

Environment Variables

Variable Default Description
MSAE_CONFIG_PATH null Path to the TOML configuration file.
MSAE_ENABLE_CONFIG_LOGGING true Enables/disables config loading logs.
MSAE_CREATE_CONFIG_IF_MISSING false Auto-create a config file if it does not exist.
MSAE_HF_SAE_REPO "Codcordance/Mistral-Small-3.2-24B-Instruct-2506-SAE-tied" Default Hugging Face repository used by mistral_sae checkpoint download when --repo-id is omitted.
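Because these variables influence which configuration is loaded, it can help to record them next to benchmark artifacts. The following is a minimal sketch of such a snapshot; snapshot_msae_env and env_snapshot_json are hypothetical helpers, and JSON is just one convenient output format.

```python
import json
import os


def snapshot_msae_env(env=None):
    """Collect every MSAE_* variable from the given mapping (default: os.environ)."""
    env = os.environ if env is None else env
    return {name: env[name] for name in sorted(env) if name.startswith("MSAE_")}


def env_snapshot_json(env=None):
    """Serialize the snapshot as stable, human-readable JSON."""
    return json.dumps(snapshot_msae_env(env), indent=2, sort_keys=True)
```

A caller could write env_snapshot_json() to a file stored alongside a copy of msae-config.toml.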

Removed Environment Variables (Migration)

File logging through environment variables was removed from the runtime and CLI. Legacy variables such as MSAE_ENABLE_FILE_LOGGING, MSAE_LOG_FILE, and MSAE_LOG_DIR are no longer read and have no effect. Training and runtime logs are emitted to stdout/stderr, and W&B logging remains controlled by [wandb].
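Since the legacy variables are silently ignored, a launch script may want to flag them during migration. The sketch below checks only the three names listed above, which may not be the complete set of removed variables; find_legacy_vars is a hypothetical helper.

```python
import os

# Names taken from the paragraph above; the list may not be exhaustive.
LEGACY_VARS = ("MSAE_ENABLE_FILE_LOGGING", "MSAE_LOG_FILE", "MSAE_LOG_DIR")


def find_legacy_vars(env=None):
    """Return the removed logging variables that are still set."""
    env = os.environ if env is None else env
    return [name for name in LEGACY_VARS if name in env]
```

For example, a wrapper script could print a warning for each name returned by find_legacy_vars() before starting training.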

Operational Notes

  • CONFIG is immutable at the attribute level; use configure(...) or a with configuration(...): block for overrides.
  • In coverage mode, coverage.force_config=true enforces config-only workload parameters.
  • For reproducibility, archive both msae-config.toml and exported MSAE_* variables with benchmark artifacts.
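The first note describes an immutable config object with overrides applied through dedicated entry points. One common way to get that behavior is a frozen dataclass that is replaced wholesale rather than mutated. The sketch below illustrates the pattern only: SAESettings and get_config are hypothetical, and the real configure(...) and configuration(...) signatures and semantics are not documented here, so they may differ.

```python
from contextlib import contextmanager
from dataclasses import FrozenInstanceError, dataclass, replace


@dataclass(frozen=True)
class SAESettings:
    # Two fields borrowed from the [sae] table; the real config is much larger.
    batch_size: int = 2048
    base_lr: float = 5e-05


_current = SAESettings()


def get_config():
    """Return the currently active (immutable) settings object."""
    return _current


def configure(**overrides):
    """Replace the active settings with an updated copy (assumed to persist)."""
    global _current
    _current = replace(_current, **overrides)
    return _current


@contextmanager
def configuration(**overrides):
    """Apply overrides for the duration of a with-block, then restore."""
    global _current
    previous = _current
    _current = replace(previous, **overrides)
    try:
        yield _current
    finally:
        _current = previous
```

In this sketch, direct assignment such as get_config().batch_size = 1 raises FrozenInstanceError, while a with configuration(batch_size=4096): block changes the value only inside the block.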