Skip to content

Commit 48debb2

Browse files
committed
backward compat for btopK sae, support for btopk sae finetuning, fix licence config
1 parent da018d1 commit 48debb2

3 files changed

Lines changed: 42 additions & 12 deletions

File tree

dictionary_learning/dictionary.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -667,6 +667,29 @@ def from_pretrained(
667667

668668
state_dict = th.load(path, weights_only=True)
669669
dict_size, activation_dim = state_dict["encoder.weight"].shape
670+
normalization_keys = [
671+
"target_rms",
672+
"activation_mean",
673+
"activation_std",
674+
"activation_global_scale",
675+
]
676+
is_in_dict = th.tensor([k in state_dict for k in normalization_keys])
677+
if not is_in_dict.all():
678+
if is_in_dict.any():
679+
raise ValueError(
680+
f"Some normalization keys are present in the state dict but not all. Missing keys: {[n for n in normalization_keys if n not in state_dict]}"
681+
)
682+
else:
683+
warn(
684+
"No normalization keys found in the state dict. Assuming no normalization is needed. This is normal for old dictionaries."
685+
)
686+
for key in normalization_keys:
687+
state_dict[key] = (
688+
th.full((activation_dim,), th.nan)
689+
if key in ["activation_mean", "activation_std"]
690+
else th.tensor(th.nan)
691+
)
692+
670693
if k is None:
671694
k = state_dict["k"].item()
672695
elif "k" in state_dict and k != state_dict["k"].item():

dictionary_learning/trainers/batch_top_k.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ def __init__(
2121
layer: int,
2222
lm_name: str,
2323
dict_class: type = BatchTopKSAE,
24+
pretrained_ae: Optional[BatchTopKSAE] = None,
2425
lr: Optional[float] = None,
2526
auxk_alpha: float = 1 / 32,
2627
warmup_steps: int = 1000,
@@ -33,7 +34,7 @@ def __init__(
3334
activation_mean: Optional[t.Tensor] = None,
3435
activation_std: Optional[t.Tensor] = None,
3536
target_rms: float = 1.0,
36-
encoder_init_norm: str = 1.0,
37+
encoder_init_norm: float = 1.0,
3738
):
3839
super().__init__(seed)
3940
assert layer is not None and lm_name is not None
@@ -51,15 +52,18 @@ def __init__(
5152
t.manual_seed(seed)
5253
t.cuda.manual_seed_all(seed)
5354

54-
self.ae = dict_class(
55-
activation_dim,
56-
dict_size,
57-
k,
58-
activation_mean=activation_mean,
59-
activation_std=activation_std,
60-
target_rms=target_rms,
61-
encoder_init_norm=encoder_init_norm,
62-
)
55+
if pretrained_ae is None:
56+
self.ae = dict_class(
57+
activation_dim,
58+
dict_size,
59+
k,
60+
activation_mean=activation_mean,
61+
activation_std=activation_std,
62+
target_rms=target_rms,
63+
encoder_init_norm=encoder_init_norm,
64+
)
65+
else:
66+
self.ae = pretrained_ae
6367

6468
if device is None:
6569
self.device = "cuda" if t.cuda.is_available() else "cpu"

pyproject.toml

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,17 @@ requires = ["setuptools>=45", "setuptools_scm[toml]>=6.2"]
33
build-backend = "setuptools.build_meta"
44
[tool.setuptools_scm]
55

6+
[tool.setuptools.packages.find]
7+
include = ["dictionary_learning*"]
8+
exclude = ["junk*"]
9+
610
[project]
711
dynamic = ["version"]
812
name = "dictionary_learning"
913
description = "A package for dictionary learning via sparse autoencoders on neural network activations"
1014
readme = "README.md"
1115
keywords = ["dictionary learning", "sparse autoencoder", "neural networks"]
12-
13-
license = { text = "MIT" }
16+
license-files = ["LICENSE"]
1417
classifiers = [
1518
"Programming Language :: Python :: 3",
1619
"License :: OSI Approved :: MIT License",

0 commit comments

Comments
 (0)