Skip to content

Commit d0b1812

Browse files
authored
Merge pull request #182 from OpenMOSS/dev
Dev
2 parents 8bde367 + 54f2d19 commit d0b1812

5 files changed

Lines changed: 107 additions & 12 deletions

File tree

docs/index.md

Lines changed: 43 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,32 +24,37 @@ This library provides:
2424
To add our library as a project dependency, run:
2525

2626
```bash
27-
uv add lm-saes
27+
uv add lm-saes==2.0.0b16
2828
```
2929

3030
We also support [Ascend NPU](https://github.com/Ascend/pytorch) as an accelerator backend. To add our library as a project dependency with NPU dependency constraints, run:
3131

3232
```bash
33-
uv add lm-saes[npu]
33+
uv add lm-saes[npu]==2.0.0b16
3434
```
3535

3636
=== "Pip"
3737

3838
Of course, you can also directly use [pip](https://pypi.org/project/pip/) to install our library. To install our library with pip, run:
3939

4040
```bash
41-
pip install lm-saes
41+
pip install lm-saes==2.0.0b16
4242
```
4343

4444
We also support [Ascend NPU](https://github.com/Ascend/pytorch) as an accelerator backend. To install our library with NPU dependency constraints, run:
4545

4646
```bash
47-
pip install lm-saes[npu]
47+
pip install lm-saes[npu]==2.0.0b16
4848
```
4949

5050
### Load a trained Sparse Autoencoder from HuggingFace
5151

52-
WIP
52+
Load any Sparse Autoencoder or other sparse dictionaries in `Language-Model-SAEs` or SAELens format.
53+
54+
```python
55+
# Load Gemma Scope 2 SAE
56+
sae = AbstractSparseAutoEncoder.from_pretrained("gemma-scope-2-1b-pt-res-all:layer_12_width_16k_l0_small")
57+
```
5358

5459
### Training a Sparse Autoencoder
5560

@@ -121,11 +126,42 @@ train_sae(settings)
121126

122127
### Analyze a trained Sparse Autoencoder
123128

124-
WIP
129+
Requires setting up [MongoDB](https://www.mongodb.com/). See [analyze-saes](analyze-saes.md) for details.
130+
131+
```python
132+
settings = AnalyzeSAESettings(
133+
sae=PretrainedSAE(pretrained_name_or_path="path/to/sae", device="cuda"),
134+
sae_name="pythia-160m-sae",
135+
activation_factory=ActivationFactoryConfig(
136+
sources=[ActivationFactoryDatasetSource(name="SlimPajama-3B")],
137+
target=ActivationFactoryTarget.ACTIVATIONS_2D,
138+
hook_points=["blocks.6.hook_resid_post"],
139+
batch_size=16,
140+
context_size=2048,
141+
),
142+
model=LanguageModelConfig(model_name="EleutherAI/pythia-160m", device="cuda"),
143+
model_name="pythia-160m",
144+
datasets={"SlimPajama-3B": DatasetConfig(dataset_name_or_path="Hzfinfdu/SlimPajama-3B")},
145+
analyzer=FeatureAnalyzerConfig(total_analyzing_tokens=100_000_000),
146+
mongo=MongoDBConfig(),
147+
device_type="cuda",
148+
)
149+
150+
analyze_sae(settings)
151+
```
125152

126153
### Convert trained Sparse Autoencoder to SAELens format
127154

128-
WIP
155+
Requires `sae_lens` package available. Supports ReLU, JumpReLU, and TopK SAEs.
156+
157+
```python
158+
from lm_saes import SparseAutoEncoder
159+
160+
sae = SparseAutoEncoder.from_pretrained("path/to/sae")
161+
sae_saelens = sae.to_saelens(model_name="pythia-160m")
162+
```
163+
164+
You can use the `sae_saelens` with any tools compatible to SAELens.
129165

130166
## Citation
131167

examples/load_saelens_model.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
"""
2+
Load a Sparse Autoencoder from HuggingFace (using SAELens format) and use it with lm-saes.
3+
4+
Requires: uv add "lm-saes[sae_lens]"
5+
"""
6+
7+
import torch
8+
from transformer_lens import HookedTransformer
9+
10+
from lm_saes.abstract_sae import AbstractSparseAutoEncoder
11+
12+
# Load Gemma Scope 2 SAE from HuggingFace
13+
sae = AbstractSparseAutoEncoder.from_pretrained("gemma-scope-2-1b-pt-res-all:layer_12_width_16k_l0_small").to("cpu")
14+
15+
print(f"Loaded SAE: {sae.cfg}")
16+
17+
# Load Gemma 3 with TransformerLens
18+
model = HookedTransformer.from_pretrained("google/gemma-3-1b-pt")
19+
model.to("cpu")
20+
model.eval()
21+
22+
prompt = "The capital of France is"
23+
tokens = model.to_tokens(prompt)
24+
_, cache = model.run_with_cache(tokens, names_filter=[sae.cfg.hook_point_in])
25+
activations = cache[sae.cfg.hook_point_in]
26+
27+
with torch.no_grad():
28+
feature_acts = sae.encode(activations)
29+
reconstructed = sae.decode(feature_acts)
30+
31+
l0 = (feature_acts > 0).sum(dim=-1).float().mean()
32+
mse = (activations.to(sae.cfg.dtype) - reconstructed).pow(2).mean()
33+
print(f"Prompt: {prompt}")
34+
print(f"Average L0: {l0.item():.1f}")
35+
print(f"Reconstruction MSE: {mse.item():.6f}")

pyproject.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ npu = ["torch-npu==2.9.0rc1"]
114114

115115
triton = ["triton"]
116116

117-
sae_lens = ["sae-lens>=6.22.3"]
117+
sae_lens = ["sae-lens>=6.37.5"]
118118

119119
[[tool.uv.index]]
120120
name = "torch-cpu"
@@ -206,4 +206,8 @@ update_changelog_on_bump = true
206206
version_files = [
207207
"README.md:pip install lm-saes==",
208208
"README.md:uv add lm-saes==",
209+
"docs/index.md:uv add lm-saes==",
210+
"docs/index.md:uv add lm-saes[npu]==",
211+
"docs/index.md:pip install lm-saes==",
212+
"docs/index.md:pip install lm-saes[npu]==",
209213
]

src/lm_saes/utils/auto.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,26 @@ def auto_infer_pretrained_sae_type(pretrained_name_or_path: str) -> PretrainedSA
4949
elif repo_exists(repo_id=repo_id):
5050
return PretrainedSAEType.HUGGINGFACE
5151

52+
likely_saelens = "/" not in repo_id
53+
if likely_saelens:
54+
if importlib.util.find_spec("sae_lens") is None:
55+
raise ValueError(
56+
f"Pretrained name or path {pretrained_name_or_path} is likely in SAELens format, but SAELens is not installed."
57+
)
58+
else:
59+
from sae_lens.loading.pretrained_saes_directory import get_pretrained_saes_directory
60+
61+
lookups = get_pretrained_saes_directory()
62+
if lookups.get(repo_id) is None:
63+
raise ValueError(
64+
f"Pretrained name or path {pretrained_name_or_path} is likely in SAELens format, but {repo_id} is not a valid SAELens release. If you are sure this is a valid SAELens release, try upgrading SAELens to the latest version."
65+
)
66+
if lookups[repo_id].saes_map.get(name) is None:
67+
raise ValueError(
68+
f"Pretrained name or path {pretrained_name_or_path} is likely in SAELens format, but {name} is not a valid ID in release {repo_id}. If you are sure this is a valid ID, try upgrading SAELens to the latest version."
69+
)
70+
return PretrainedSAEType.SAELENS
71+
5272
raise ValueError(
5373
f"Pretrained name or path {pretrained_name_or_path} is not found on disk, nor on HuggingFace, nor in SAELens."
5474
)

uv.lock

Lines changed: 4 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)