load_saelens_model.py
"""
Load a sparse dictionary from HuggingFace (using SAELens format) and use it with llamascopium.
Requires: uv add "llamascopium[sae_lens]"
"""

import torch
from transformer_lens import HookedTransformer
from llamascopium.models.sparse_dictionary import SparseDictionary

# Load Gemma Scope 2 sparse dictionary from HuggingFace
sae = SparseDictionary.from_pretrained("gemma-scope-2-1b-pt-res-all:layer_12_width_16k_l0_small").to("cpu")
print(f"Loaded sparse dictionary: {sae.cfg}")

# Load Gemma 3 with TransformerLens
model = HookedTransformer.from_pretrained("google/gemma-3-1b-pt")
model.to("cpu")
model.eval()
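
# Run the prompt, caching activations only at the dictionary's input hook point
# (names_filter keeps run_with_cache from storing every intermediate activation)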
prompt = "The capital of France is"
tokens = model.to_tokens(prompt)
_, cache = model.run_with_cache(tokens, names_filter=[sae.cfg.hook_point_in])
activations = cache[sae.cfg.hook_point_in]
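# activations: [batch, seq, d_model]; the layer-12 residual stream, per the release id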

with torch.no_grad():
    feature_acts = sae.encode(activations)
    reconstructed = sae.decode(feature_acts)
    # L0: how many features fire per token on average; MSE: reconstruction error
    l0 = (feature_acts > 0).sum(dim=-1).float().mean()
    mse = (activations.to(sae.cfg.dtype) - reconstructed).pow(2).mean()
print(f"Prompt: {prompt}")
print(f"Average L0: {l0.item():.1f}")
print(f"Reconstruction MSE: {mse.item():.6f}")