Commit 275f841

Merge pull request #185 from OpenMOSS/dev
Dev
2 parents: f1b0642 + 6e0a064

25 files changed

Lines changed: 1444 additions & 1205 deletions

.github/workflows/docs.yml

Lines changed: 2 additions & 2 deletions
@@ -1,8 +1,8 @@
 name: docs
 on:
   push:
-    branches:
-      - main
+    tags:
+      - "v*" # Triggers on version tags

 permissions:
   contents: write

.gitignore

Lines changed: 3 additions & 0 deletions
@@ -168,6 +168,9 @@ cython_debug/
 # option (not recommended) you can uncomment the following to ignore the entire idea folder.
 #.idea/

+# VS Code
+.vscode/
+
 ### Python Patch ###
 # Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
 poetry.toml

docs/index.md

Lines changed: 18 additions & 4 deletions
@@ -51,10 +51,24 @@ This library provides:

 Load any Sparse Autoencoder or other sparse dictionaries in `Language-Model-SAEs` or SAELens format.

-```python
-# Load Gemma Scope 2 SAE
-sae = AbstractSparseAutoEncoder.from_pretrained("gemma-scope-2-1b-pt-res-all:layer_12_width_16k_l0_small")
-```
+=== "Language-Model-SAEs"
+
+    ```python
+    # Load Llama Scope 2 Transcoder
+    sae = AbstractSparseAutoEncoder.from_pretrained(
+        "OpenMOSS-Team/Llama-Scope-2-Qwen3-1.7B:transcoder/8x/k128/layer12_transcoder_8x_k128",
+        fold_activation_scale=False
+    )
+    ```
+
+=== "SAELens"
+
+    ```python
+    # Load Gemma Scope 2 SAE
+    sae = AbstractSparseAutoEncoder.from_pretrained(
+        "gemma-scope-2-1b-pt-res-all:layer_12_width_16k_l0_small",
+    )
+    ```

 ### Training a Sparse Autoencoder

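For a quick sanity check after loading, the returned dictionary can be run through an encode/decode round trip. A minimal sketch, assuming the loaded object exposes `cfg.d_model` and `cfg.dtype` (the `encode`/`decode` calls match the `examples/load_hf_model.py` file added later in this commit):

```python
import torch

from lm_saes.abstract_sae import AbstractSparseAutoEncoder

sae = AbstractSparseAutoEncoder.from_pretrained(
    "gemma-scope-2-1b-pt-res-all:layer_12_width_16k_l0_small",
)

# Random activations stand in for real model activations; this only checks shapes.
x = torch.randn(4, sae.cfg.d_model, dtype=sae.cfg.dtype)  # cfg.d_model assumed
feature_acts = sae.encode(x)
reconstruction = sae.decode(feature_acts)
print(feature_acts.shape, reconstruction.shape)
```
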
docs/models/lorsa.md

Lines changed: 59 additions & 52 deletions
@@ -1,6 +1,33 @@
 # Low-Rank Sparse Attention (Lorsa)

-Low-Rank Sparse Attention (Lorsa) is a specialized sparse dictionary architecture designed to decompose attention layers into interpretable sparse components. Unlike standard SAEs that treat attention as a black box, Lorsa explicitly models the query-key-value structure while maintaining sparsity and interpretability. Lorsa decomposes attention computations into interpretable sparse features that preserve positional information through explicit query-key attention mechanisms. This allows for fine-grained analysis of attention patterns and understanding how models route information based on both content and position.
+Low-Rank Sparse Attention (Lorsa) is a specialized sparse dictionary architecture designed to decompose attention layers into interpretable sparse components. Unlike standard SAEs that treat attention as a black box, Lorsa explicitly models the query-key-value structure while maintaining sparsity and interpretability.
+
+Given an input sequence \(X \in \mathbb{R}^{n \times d}\), Lorsa has:
+
+- \(n_{\text{qk\_heads}}\) QK heads, each with projections \(W_q^h, W_k^h \in \mathbb{R}^{d \times d_{\text{qk\_head}}}\)
+- \(n_{\text{ov\_heads}}\) rank-1 OV heads, each with projections \(\mathbf{w}_v^i \in \mathbb{R}^{d \times 1}\), \(\mathbf{w}_o^i \in \mathbb{R}^{1 \times d}\)
+
+Every group of \(n_{\text{ov\_heads}} / n_{\text{qk\_heads}}\) consecutive OV heads shares the same QK head. Denote the QK head assigned to OV head \(i\) as \(h(i)\). The forward pass for each OV head \(i\) is:
+
+\[
+\begin{aligned}
+Q^{h(i)} &= X W_q^{h(i)}, \quad K^{h(i)} = X W_k^{h(i)} \\
+A^{h(i)} &= \operatorname{softmax}\!\left(\frac{Q^{h(i)} {(K^{h(i)})}^\top}{\sqrt{d_{\text{qk\_head}}}}\right) \in \mathbb{R}^{n \times n} \\
+\tilde{\mathbf{z}}^i &= A^{h(i)}\, (X \mathbf{w}_v^i) \in \mathbb{R}^{n \times 1}
+\end{aligned}
+\]
+
+The pre-activations across all OV heads are then passed through a sparsity-inducing activation function \(\sigma(\cdot)\):
+
+\[
+[\mathbf{z}^0, \ldots, \mathbf{z}^{n_{\text{ov\_heads}}-1}] = \sigma([\tilde{\mathbf{z}}^0, \ldots, \tilde{\mathbf{z}}^{n_{\text{ov\_heads}}-1}])
+\]
+
+The final output sums the contributions of all OV heads weighted by their activations:
+
+\[
+\hat{Y} = \sum_{i=0}^{n_{\text{ov\_heads}}-1} \mathbf{z}^i\, (\mathbf{w}_o^i)^\top \in \mathbb{R}^{n \times d}
+\]

 The architecture was introduced in [*Towards Understanding the Nature of Attention with Low-Rank Sparse Decomposition*](https://openreview.net/forum?id=9A2etpDFIB) (ICLR 2026), which proposes using sparse dictionary learning to address *attention superposition*—the challenge of disentangling attention-mediated interactions between features at different token positions. For detailed architectural specifications and mathematical formulations, please refer to this paper.

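To make the shapes concrete, here is a minimal PyTorch sketch of the forward pass added above. The tensor names, dimensions, and the TopK activation are illustrative assumptions for this note, not the library's implementation; causal masking is omitted for brevity.

```python
import torch

n, d = 8, 64                      # sequence length, model dimension
n_qk_heads, d_qk_head = 4, 16
n_ov_heads = 256                  # a multiple of n_qk_heads
group = n_ov_heads // n_qk_heads  # consecutive OV heads sharing one QK head
top_k = 32

X = torch.randn(n, d)
W_q = torch.randn(n_qk_heads, d, d_qk_head)
W_k = torch.randn(n_qk_heads, d, d_qk_head)
w_v = torch.randn(n_ov_heads, d)  # rank-1 value directions
w_o = torch.randn(n_ov_heads, d)  # rank-1 output directions

# One attention pattern per QK head
Q = torch.einsum("nd,hdk->hnk", X, W_q)
K = torch.einsum("nd,hdk->hnk", X, W_k)
A = torch.softmax(Q @ K.transpose(-1, -2) / d_qk_head**0.5, dim=-1)  # (h, n, n)

# Each OV head i reuses the pattern of its assigned QK head h(i)
A_per_ov = A.repeat_interleave(group, dim=0)     # (n_ov_heads, n, n)
v = X @ w_v.T                                    # (n, n_ov_heads)
z_pre = torch.einsum("inm,mi->ni", A_per_ov, v)  # pre-activations (n, n_ov_heads)

# Sparsity: keep the top-k pre-activations per token
vals, idx = z_pre.topk(top_k, dim=-1)
z = torch.zeros_like(z_pre).scatter_(-1, idx, vals)

Y_hat = z @ w_o                                  # (n, d), sum of rank-1 OV contributions
print(Y_hat.shape)
```
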
@@ -63,26 +90,30 @@ lorsa_config = LorsaConfig(

 #### Attention Dimensions

+We recommend setting `d_qk_head` to match the target model's head dimension. `n_qk_heads` can be freely chosen: a natural starting point is `n_qk_heads = n_heads * expansion_factor` (where `n_heads` is the number of attention heads in the target attention layer), though a smaller value (no smaller than `n_heads`) is also reasonable if you want to reduce Lorsa's parameter count.
+
 | Parameter | Type | Description | Default |
 |-----------|------|-------------|---------|
-| `n_qk_heads` | `int` | Number of query-key attention heads | Required |
-| `d_qk_head` | `int` | Dimension per query-key head | Required |
-| `n_ctx` | `int` | Maximum context length / sequence length | Required |
+| `n_qk_heads` | `int` | Number of QK heads. | Required |
+| `d_qk_head` | `int` | Dimension per QK head. | Required |
+| `n_ctx` | `int` | Maximum context length. | Required |

-!!! note "Number of Value Heads"
-    The number of value heads (output features) is automatically computed as: `n_ov_heads = expansion_factor * d_model` (same as `d_sae`). The `ov_group_size` is `n_ov_heads // n_qk_heads`.
+!!! note "Number of OV Heads"
+    The number of OV heads is automatically computed as: `n_ov_heads = expansion_factor * d_model` (same as `d_sae`).

 #### Positional Embeddings

+It is strongly recommended to copy the positional embedding parameters directly from the target model's implementation. Incorrect settings will make it harder for Lorsa to learn the target attention patterns.
+
 | Parameter | Type | Description | Default |
 |-----------|------|-------------|---------|
 | `positional_embedding_type` | `str` | Type of positional embedding: `"rotary"` or `"none"` | `"rotary"` |
 | `rotary_dim` | `int` | Dimension of rotary embeddings (typically `d_qk_head`) | Required |
 | `rotary_base` | `int` | Base for rotary embeddings frequency | `10000` |
-| `rotary_adjacent_pairs` | `bool` | Whether to apply RoPE on adjacent pairs vs. all dimensions | `True` |
-| `rotary_scale` | `int` | Scaling factor for rotary embeddings | `1` |
+| `rotary_adjacent_pairs` | `bool` | Whether to apply RoPE on adjacent pairs | `True` |
+| `rotary_scale` | `int` | Scaling factor of the head dimension for rotary embeddings | `1` |

-#### NTK-Aware RoPE (for Llama 3.1 and 3.2 herd models)
+#### NTK-Aware RoPE (only for Llama 3.1 and 3.2 herd models)

 | Parameter | Type | Description | Default |
 |-----------|------|-------------|---------|
@@ -92,15 +123,30 @@ lorsa_config = LorsaConfig(
 | `NTK_by_parts_high_freq_factor` | `float` | High-frequency component scaling factor | `1.0` |
 | `old_context_len` | `int` | Original context length before scaling | `2048` |

-#### Attention Settings
+#### Attention Computation Details

 | Parameter | Type | Description | Default |
 |-----------|------|-------------|---------|
-| `attn_scale` | `float \| None` | Attention scaling factor. If `None`, uses $\frac{1}{\sqrt{d_{\text{qk\_head}}}}$ | `None` |
+| `attn_scale` | `float | None` | Attention scaling factor. If `None`, uses $\frac{1}{\sqrt{d_{\text{qk\_head}}}}$ | `None` |
 | `use_post_qk_ln` | `bool` | Apply LayerNorm/RMSNorm after computing Q and K projections | `False` |
-| `normalization_type` | `str \| None` | Normalization type: `"LN"` (LayerNorm) or `"RMS"` (RMSNorm). Only used when `use_post_qk_ln=True` | `None` |
+| `normalization_type` | `str | None` | Normalization type: `"LN"` (LayerNorm) or `"RMS"` (RMSNorm). Only used when `use_post_qk_ln=True` | `None` |
 | `eps` | `float` | Epsilon for numerical stability in normalization | `1e-6` |

+### Initialization Strategy
+
+For Lorsa, initialization from the original model's attention weights is highly recommended:
+
+```python
+InitializerConfig(
+    grid_search_init_norm=True,
+    initialize_lorsa_with_mhsa=True,  # Initialize Q, K from attention weights
+    initialize_W_D_with_active_subspace=True,  # Initialize V, O from attention weights
+    model_layer=13,  # Specify layer to extract attention weights from
+)
+```
+
+This initialization helps Lorsa start from a good approximation of the attention computation.
+
 ## Training

 ### Basic Training Setup
@@ -125,31 +171,7 @@ settings = TrainLorsaSettings(
     sae=LorsaConfig(
         hook_point_in="blocks.13.ln1.hook_normalized",
         hook_point_out="blocks.13.hook_attn_out",
-        d_model=2048,
-        expansion_factor=32,
-
-        # Attention configuration
-        n_qk_heads=16,
-        d_qk_head=128,
-        n_ctx=2048,
-
-        # RoPE configuration
-        positional_embedding_type="rotary",
-        rotary_dim=128,
-        rotary_base=1000000,
-        rotary_adjacent_pairs=False,
-
-        # Sparsity
-        act_fn="topk",
-        top_k=256,
-
-        # Normalization
-        use_post_qk_ln=True,
-        normalization_type="RMS",
-        eps=1e-6,
-
-        dtype=torch.float32,
-        device="cuda",
+        # ... other settings ...
     ),
     initializer=InitializerConfig(
         grid_search_init_norm=True,
@@ -196,21 +218,6 @@ settings = TrainLorsaSettings(
 train_lorsa(settings)
 ```

-### Initialization Strategy
-
-For Lorsa, initialization from the original model's attention weights is highly recommended:
-
-```python
-InitializerConfig(
-    grid_search_init_norm=True,
-    initialize_lorsa_with_mhsa=True,  # Initialize Q, K from attention weights
-    initialize_W_D_with_active_subspace=True,  # Initialize V, O from attention weights
-    model_layer=13,  # Specify layer to extract attention weights from
-)
-```
-
-This initialization helps Lorsa start from a good approximation of the attention computation.
-
 ### Important Training Considerations

 1. **Sequence batching**: Since Lorsa operates on sequences, `batch_size` in `ActivationFactoryConfig` represents the number of sequences (not tokens). The effective token batch size is `batch_size * n_ctx`.

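As a worked example of the dimension recommendations in the lorsa.md changes above (the target-layer sizes are illustrative assumptions, taken from the example config that this commit removes):

```python
# Assumed target attention layer: 16 heads of size 128, d_model = 2048.
n_heads, d_head, d_model = 16, 128, 2048
expansion_factor = 32

d_qk_head = d_head                        # match the model's head dimension
n_qk_heads = n_heads * expansion_factor   # natural starting point: 512
n_ov_heads = expansion_factor * d_model   # computed by the library: 65536
ov_group_size = n_ov_heads // n_qk_heads  # 128 OV heads per QK head
print(d_qk_head, n_qk_heads, n_ov_heads, ov_group_size)
```
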
docs/models/sae.md

Lines changed: 12 additions & 1 deletion
@@ -1,6 +1,17 @@
 # Sparse Autoencoder (SAE)

-Sparse Autoencoders (SAEs) are the foundational architecture for learning interpretable features from language model activations. They decompose neural network activations into sparse, interpretable features that help address the superposition problem. An SAE consists of an encoder that maps model activations to a higher-dimensional latent space and a decoder that reconstructs the original activations. The key innovation is enforcing sparsity through activation functions or regularization, which encourages the model to learn monosemantic features—where each feature represents a single concept.
+Sparse Autoencoders (SAEs) are the foundational architecture for learning interpretable features from language model activations. They decompose neural network activations into sparse, interpretable features that help address the superposition problem.
+
+Given a model activation vector \(\mathbf{x} \in \mathbb{R}^{d_{\text{model}}}\), an SAE first **encodes** it into a high-dimensional sparse latent representation, then **decodes** it back to reconstruct the original activation:
+
+\[
+\begin{aligned}
+\mathbf{z} &= \sigma(W_E \mathbf{x} + \mathbf{b}_E) \in \mathbb{R}^{d_{\text{SAE}}} \\
+\hat{\mathbf{x}} &= W_D \mathbf{z} + \mathbf{b}_D \in \mathbb{R}^{d_{\text{model}}}
+\end{aligned}
+\]
+
+where \(W_E \in \mathbb{R}^{d_{\text{SAE}} \times d_{\text{model}}}\) and \(W_D \in \mathbb{R}^{d_{\text{model}} \times d_{\text{SAE}}}\) are the encoder and decoder weight matrices, \(\mathbf{b}_E, \mathbf{b}_D\) are bias terms, and \(\sigma(\cdot)\) is a sparsity-inducing activation function (e.g., ReLU, TopK). The model is trained to minimize the reconstruction loss \(\|\mathbf{x} - \hat{\mathbf{x}}\|^2\) while keeping \(\mathbf{z}\) sparse, encouraging each dimension of \(\mathbf{z}\) to correspond to a monosemantic feature.

 The architecture was introduced in foundational works including [*Sparse Autoencoders Find Highly Interpretable Features in Language Models*](https://arxiv.org/abs/2309.08600) and [*Towards Monosemanticity: Decomposing Language Models With Dictionary Learning*](https://transformer-circuits.pub/2023/monosemantic-features). For detailed architectural specifications and mathematical formulations, please refer to these papers.

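A minimal sketch of the encode/decode equations added above, using ReLU as the sparsity-inducing activation; the weight initialization and dimensions are illustrative assumptions, not the library's implementation:

```python
import torch

d_model, d_sae = 768, 768 * 32

W_E = torch.randn(d_sae, d_model) * 0.01  # encoder weights
b_E = torch.zeros(d_sae)
W_D = torch.randn(d_model, d_sae) * 0.01  # decoder weights
b_D = torch.zeros(d_model)

x = torch.randn(d_model)           # a model activation vector
z = torch.relu(W_E @ x + b_E)      # sparse latent features, shape (d_sae,)
x_hat = W_D @ z + b_D              # reconstruction, shape (d_model,)
recon_loss = (x - x_hat).pow(2).sum()
print(z.shape, x_hat.shape, recon_loss.item())
```
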
docs/models/transcoder.md

Lines changed: 5 additions & 26 deletions
@@ -6,39 +6,18 @@ Transcoders were introduced in the following papers: [*Automatically Identifying

 ## Configuration

-Transcoders use the same `SAEConfig` class as standard SAEs. All sparse dictionary models inherit common parameters from `BaseSAEConfig`. See the [Common Configuration Parameters](overview.md#common-configuration-parameters) section for the full list of inherited parameters.
+Transcoders use the same `SAEConfig` and `InitializerConfig` as standard SAEs. See the [SAE configuration guide](sae.md#configuration) for the full parameter reference.

-### Transcoder-Specific Parameters
+The only essential difference is that `hook_point_in` and `hook_point_out` must point to **different** locations—typically the input and output of the MLP sublayer you want to decompose:

 ```python
-from lm_saes import SAEConfig
-import torch
-
 transcoder_config = SAEConfig(
-    # Transcoder-specific: different hook points
-    hook_point_in="blocks.6.ln2.hook_normalized",  # Input to MLP
-    hook_point_out="blocks.6.hook_mlp_out",  # Output from MLP
-    use_glu_encoder=False,
-
-    # Common parameters (documented in Sparse Dictionaries overview)
-    d_model=768,
-    expansion_factor=32,
-    act_fn="topk",
-    top_k=64,
-    dtype=torch.float32,
-    device="cuda",
+    hook_point_in="blocks.6.ln2.hook_normalized",  # before MLP
+    hook_point_out="blocks.6.hook_mlp_out",  # after MLP
+    ...
 )
 ```

-| Parameter | Type | Description | Default |
-|-----------|------|-------------|---------|
-| `hook_point_in` | `str` | Hook point before the computational unit (e.g., `blocks.L.ln2.hook_normalized` for MLP input). Must differ from `hook_point_out` for transcoders | Required |
-| `hook_point_out` | `str` | Hook point after the computational unit (e.g., `blocks.L.hook_mlp_out` for MLP output). Must differ from `hook_point_in` for transcoders | Required |
-| `use_glu_encoder` | `bool` | Whether to use a Gated Linear Unit (GLU) in the encoder. GLU can improve expressiveness but increases parameter count | `False` |
-
-!!! important "Transcoder vs SAE"
-    When `hook_point_in != hook_point_out`, the configuration defines a transcoder rather than a standard SAE. This allows the model to learn the transformation between two different points in the network.
-
 ### Initialization Strategy

 Proper initialization is crucial for training high-quality transcoders. We recommend the following configuration:

examples/load_hf_model.py

Lines changed: 37 additions & 0 deletions
@@ -0,0 +1,37 @@
+"""
+Load a Transcoder from HuggingFace.
+"""
+
+import torch
+from transformer_lens import HookedTransformer
+
+from lm_saes.abstract_sae import AbstractSparseAutoEncoder
+
+# Load a Llama Scope 2 transcoder for Qwen3-1.7B from HuggingFace
+sae = AbstractSparseAutoEncoder.from_pretrained(
+    "OpenMOSS-Team/Llama-Scope-2-Qwen3-1.7B:transcoder/8x/k128/layer12_transcoder_8x_k128",
+    fold_activation_scale=False,
+).to("cpu")
+
+print(f"Loaded SAE: {sae.cfg}")
+
+# Load Qwen3-1.7B with TransformerLens
+model = HookedTransformer.from_pretrained("Qwen/Qwen3-1.7B")
+model.to("cpu")
+model.eval()
+
+prompt = "The capital of France is"
+tokens = model.to_tokens(prompt)
+_, cache = model.run_with_cache(tokens, names_filter=[sae.cfg.hook_point_in, sae.cfg.hook_point_out])
+x = cache[sae.cfg.hook_point_in]  # transcoder input activations
+label = cache[sae.cfg.hook_point_out]  # transcoder reconstruction target
+
+with torch.no_grad():
+    feature_acts = sae.encode(x)
+    reconstructed = sae.decode(feature_acts)
+
+l0 = (feature_acts > 0).sum(dim=-1).float().mean()
+mse = (label.to(sae.cfg.dtype) - reconstructed).pow(2).mean()
+print(f"Prompt: {prompt}")
+print(f"Average L0: {l0.item():.1f}")
+print(f"Reconstruction MSE: {mse.item():.6f}")

mkdocs.yml

Lines changed: 1 addition & 1 deletion
@@ -50,7 +50,7 @@ nav:
       - Overview: models/overview.md
       - Sparse Autoencoder: models/sae.md
      - Transcoder: models/transcoder.md
-      - Cross Layer Transcoder: models/clt.md
+      # - Cross Layer Transcoder: models/clt.md
       - Low-Rank Sparse Attention: models/lorsa.md
   - Analyze SAEs: analyze-saes.md
   - Distributed Guidelines: distributed-guidelines.md
