Skip to content

Commit 95281ac

Browse files
committed
feat: support resuming SAE training from pretrained checkpoints
1 parent b5fbd92 commit 95281ac

4 files changed

Lines changed: 39 additions & 5 deletions

File tree

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -25,6 +25,8 @@ dependencies = [
2525
"more-itertools>=10.7.0",
2626
"json-repair>=0.44.1",
2727
"cattrs>=26.1.0",
28+
"jupyter>=1.1.1",
29+
"ipykernel>=7.2.0",
2830
]
2931
requires-python = ">=3.11,<3.13"
3032
readme = "README.md"

src/llamascopium/initializer.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -183,6 +183,9 @@ def initialize_sae_from_config(
183183
device_mesh=device_mesh,
184184
)
185185

186+
if cfg.sae_pretrained_name_or_path is not None:
187+
return sae
188+
186189
sae = self.initialize_parameters(sae)
187190
if sae.cfg.norm_activation == "dataset-wise":
188191
if activation_norm is None:

src/llamascopium/models/sparse_dictionary.py

Lines changed: 25 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -132,6 +132,12 @@ class SparseDictionaryConfig(BaseModelConfig, ABC):
132132
top_k: int = 50
133133
"""The k value to use for the topk family of activation functions. For vanilla TopK, the L0 norm of the feature activations will be exactly equal to `top_k`."""
134134

135+
sae_pretrained_name_or_path: str | None = None
136+
"""Optional pretrained SAE path or identifier used to restore model weights."""
137+
138+
strict_loading: bool = True
139+
"""Whether to strictly enforce an exact state_dict key match when loading pretrained weights."""
140+
135141
use_triton_kernel: bool = False
136142
"""Whether to use the Triton SpMM kernel for the sparse matrix multiplication. Currently only supported for vanilla SAE."""
137143

@@ -148,7 +154,7 @@ def d_sae(self) -> int:
148154
return d_sae
149155

150156
@classmethod
151-
def from_pretrained(cls, pretrained_name_or_path: str, **kwargs):
157+
def from_pretrained(cls, pretrained_name_or_path: str, strict_loading: bool = True, **kwargs):
152158
"""Load the config of the sparse dictionary from a pretrained name or path. Config is read from <pretrained_name_or_path>/config.json (for local storage) or <repo_id>/<name>/config.json (for HuggingFace Hub).
153159
154160
Args:
@@ -171,6 +177,9 @@ def from_pretrained(cls, pretrained_name_or_path: str, **kwargs):
171177
with open(path, "r") as f:
172178
sae_config = json.load(f)
173179

180+
sae_config["sae_pretrained_name_or_path"] = pretrained_name_or_path
181+
sae_config["strict_loading"] = strict_loading
182+
174183
if cls is SparseDictionaryConfig:
175184
cls = SAE_TYPE_TO_CONFIG_CLASS[sae_config["sae_type"]]
176185

@@ -179,6 +188,9 @@ def from_pretrained(cls, pretrained_name_or_path: str, **kwargs):
179188
def save_hyperparameters(self, sae_path: str | Path, remove_loading_info: bool = True):
180189
assert os.path.exists(sae_path), f"{sae_path} does not exist. Unable to save hyperparameters."
181190
d = self.model_dump()
191+
if remove_loading_info:
192+
d.pop("sae_pretrained_name_or_path", None)
193+
d.pop("strict_loading", None)
182194

183195
with open(os.path.join(sae_path, "config.json"), "w") as f:
184196
json.dump(d, f, indent=4)
@@ -517,6 +529,16 @@ def from_config(cls, cfg: SparseDictionaryConfig, device_mesh: DeviceMesh | None
517529
if cls is SparseDictionary:
518530
cls = SAE_TYPE_TO_MODEL_CLASS[cfg.sae_type]
519531

532+
if cfg.sae_pretrained_name_or_path is not None:
533+
return cls.from_pretrained(
534+
cfg.sae_pretrained_name_or_path,
535+
device_mesh=device_mesh,
536+
fold_activation_scale=False,
537+
strict_loading=cfg.strict_loading,
538+
device=cfg.device,
539+
dtype=cfg.dtype,
540+
)
541+
520542
model = cls(cfg, device_mesh)
521543
total_params = sum(param.numel() for param in model.parameters()) / 1e9
522544
logger.info(f"Initializing {cfg.sae_type} with {total_params:.2f} B parameters")
@@ -534,7 +556,8 @@ def from_local(
534556
):
535557
"""Load a pretrained sparse dictionary from a local directory."""
536558

537-
cfg = SparseDictionaryConfig.from_pretrained(path, **kwargs)
559+
cfg = SparseDictionaryConfig.from_pretrained(path, strict_loading=strict_loading, **kwargs)
560+
cfg.sae_pretrained_name_or_path = None
538561
model = cls.from_config(cfg, device_mesh=device_mesh)
539562

540563
if path.endswith(".pt") or path.endswith(".safetensors") or path.endswith(".dcp"):

uv.lock

Lines changed: 9 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)