Skip to content

Commit b640f66

Browse files
polinabinder1claude
andcommitted
interpretability/sae: add Evo2 1B SAE recipe
Mirrors the existing esm2 / codonfm SAE recipes. Pipeline: chunk -> convert (Savanna->MBridge) -> predict_evo2 -> pt_to_parquet -> train Differences from esm2/codonfm are forced by Evo2 specifics: - Hyena/Megatron-Core model, no HF AutoModel path => reuses the existing `predict_evo2` CLI for inference instead of writing a custom extract.py - `pt_to_parquet.py` shim bridges predict_evo2's .pt output to the universal `sae.activation_store` parquet contract - `chunk_fasta.py` preprocessor keeps inputs within the model's trained context length (8192 bp for 1B); Hyena fftconv OOMs on long sequences even at micro-batch=1 - `train.py` is the same as codonfm's, copied verbatim per bionemo-recipes' KISS-over-DRY convention Validated end-to-end on 100 organelle sequences (Evo2 1B layer 12): loss 0.67 -> 0.045, FVU 0.90 -> 0.10, var_exp 0.10 -> 0.90, 2m14s wall. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 60deff4 commit b640f66

7 files changed

Lines changed: 645 additions & 0 deletions

File tree

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# Evo2 SAE Recipe
2+
3+
Train a sparse autoencoder on Evo2 (DNA language model) residual-stream activations.
4+
5+
Pipeline:
6+
7+
```
8+
HF Savanna ckpt --convert--> MBridge ckpt
9+
|
10+
predict_evo2 --embedding-layer N (FASTA in, .pt out)
11+
|
12+
pt_to_parquet shim (.pt -> ActivationStore parquet shards)
13+
|
14+
train.py (TopK SAE)
15+
```
16+
17+
The eval / dashboard stage from the esm2 recipe is intentionally not ported in v1.
18+
19+
## Quick start (1B model, single GPU)
20+
21+
```bash
22+
bash scripts/1b.sh
23+
```
24+
25+
This will:
26+
27+
1. Convert `arcinstitute/savanna_evo2_1b_base` to MBridge format
28+
2. Run `predict_evo2` on the OpenGenome2 organelle FASTA, extracting layer-12 embeddings
29+
3. Convert the .pt outputs to parquet shards
30+
4. Train a TopK SAE (expansion=8, k=32)
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
[build-system]
2+
requires = ["setuptools>=61.0"]
3+
build-backend = "setuptools.build_meta"
4+
5+
[project]
6+
name = "evo2-sae"
7+
version = "0.1.0"
8+
description = "Sparse Autoencoders for the Evo2 DNA language model"
9+
readme = "README.md"
10+
requires-python = ">=3.10"
11+
12+
dependencies = [
13+
"sae",
14+
"torch>=2.0",
15+
"numpy>=1.20",
16+
"tqdm>=4.60",
17+
"pyarrow>=10.0",
18+
]
19+
20+
[tool.setuptools.packages.find]
21+
where = ["src"]
22+
23+
[tool.uv.sources]
24+
sae = { workspace = true }
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
#!/bin/bash
2+
# Evo2 1B SAE pipeline: convert -> predict_evo2 -> pt_to_parquet -> train.
3+
#
4+
# Assumes:
5+
# - bionemo-recipes/recipes/evo2_megatron has been built (.ci_build.sh) and
6+
# its .venv is active, providing predict_evo2 + evo2_convert_savanna_to_mbridge.
7+
# - The sae workspace package is importable in that same venv.
8+
# - HF_TOKEN is set if Savanna checkpoint repo is gated.
9+
#
10+
# Override any of these by exporting before invocation.
11+
12+
set -euo pipefail
13+
14+
EVO2_MEGATRON_DIR="${EVO2_MEGATRON_DIR:-/workspace/bionemo-framework/bionemo-recipes/recipes/evo2_megatron}"
15+
RECIPE_DIR="$(cd "$(dirname "$0")/.." && pwd)"
16+
17+
MODEL="${MODEL:-arcinstitute/savanna_evo2_1b_base}"
18+
MODEL_SIZE="${MODEL_SIZE:-evo2_1b_base}"
19+
LAYER="${LAYER:-12}"
20+
# Trained context length. 1B = 8192. Bump for 7B/40B (context-extended).
21+
CHUNK_BP="${CHUNK_BP:-8192}"
22+
23+
FASTA="${FASTA:-/data/interp/evo2/OpenGenome2/fasta/organelles/organelle_sequences.fasta.gz}"
24+
WORK_ROOT="${WORK_ROOT:-/data/interp/evo2}"
25+
26+
CKPT_DIR="${WORK_ROOT}/checkpoints/${MODEL_SIZE}_mbridge"
27+
PREDICT_DIR="${WORK_ROOT}/activations/${MODEL_SIZE}_layer${LAYER}_pt"
28+
PARQUET_DIR="${WORK_ROOT}/activations/${MODEL_SIZE}_layer${LAYER}_parquet"
29+
OUTPUT_DIR="${WORK_ROOT}/sae/${MODEL_SIZE}_layer${LAYER}"
30+
31+
source "${EVO2_MEGATRON_DIR}/.venv/bin/activate"
32+
33+
echo "============================================================"
34+
echo "STEP 0: Chunk FASTA to <=${CHUNK_BP} bp (model trained context)"
35+
echo "============================================================"
36+
# chunk_fasta.py reads .gz directly and writes plain .fasta; no separate gunzip needed.
37+
INPUT_STEM="$(basename "$FASTA")"
38+
INPUT_STEM="${INPUT_STEM%.gz}"
39+
INPUT_STEM="${INPUT_STEM%.fasta}"
40+
CHUNKED_FASTA="${WORK_ROOT}/scratch/${INPUT_STEM}_chunked${CHUNK_BP}.fasta"
41+
if [[ -f "$CHUNKED_FASTA" ]]; then
42+
echo "Reusing existing chunked FASTA: $CHUNKED_FASTA"
43+
else
44+
python "${RECIPE_DIR}/scripts/chunk_fasta.py" \
45+
--input "$FASTA" \
46+
--output "$CHUNKED_FASTA" \
47+
--window "$CHUNK_BP"
48+
fi
49+
FASTA="$CHUNKED_FASTA"
50+
51+
echo "============================================================"
52+
echo "STEP 1: Convert Savanna -> MBridge"
53+
echo "============================================================"
54+
if [[ ! -f "${CKPT_DIR}/latest_checkpointed_iteration.txt" ]]; then
55+
evo2_convert_savanna_to_mbridge \
56+
--savanna-ckpt-path "$MODEL" \
57+
--mbridge-ckpt-dir "$CKPT_DIR" \
58+
--model-size "$MODEL_SIZE" \
59+
--tokenizer-path "${EVO2_MEGATRON_DIR}/tokenizers/nucleotide_fast_tokenizer_512"
60+
else
61+
echo "Reusing existing checkpoint at $CKPT_DIR"
62+
fi
63+
64+
echo "============================================================"
65+
echo "STEP 2: Extract layer-${LAYER} embeddings (predict_evo2)"
66+
echo "============================================================"
67+
mkdir -p "$PREDICT_DIR"
68+
if compgen -G "${PREDICT_DIR}/predictions__*.pt" > /dev/null; then
69+
echo "Reusing existing .pt files in $PREDICT_DIR"
70+
else
71+
predict_evo2 \
72+
--fasta "$FASTA" \
73+
--ckpt-dir "$CKPT_DIR" \
74+
--output-dir "$PREDICT_DIR" \
75+
--embedding-layer "$LAYER" \
76+
--micro-batch-size 1 \
77+
--devices 1 \
78+
--write-interval batch
79+
fi
80+
81+
echo "============================================================"
82+
echo "STEP 3: Convert .pt -> parquet ActivationStore"
83+
echo "============================================================"
84+
if [[ -f "${PARQUET_DIR}/metadata.json" ]]; then
85+
echo "Reusing existing parquet shards at $PARQUET_DIR"
86+
else
87+
python "${RECIPE_DIR}/scripts/pt_to_parquet.py" \
88+
--predict-dir "$PREDICT_DIR" \
89+
--output "$PARQUET_DIR" \
90+
--model-name "$MODEL" \
91+
--layer "$LAYER"
92+
fi
93+
94+
echo "============================================================"
95+
echo "STEP 4: Train TopK SAE"
96+
echo "============================================================"
97+
python "${RECIPE_DIR}/scripts/train.py" \
98+
--cache-dir "$PARQUET_DIR" \
99+
--model-path "$MODEL" \
100+
--layer "$LAYER" \
101+
--model-type topk \
102+
--expansion-factor 8 --top-k 32 \
103+
--auxk 64 --auxk-coef 0.03125 \
104+
--init-pre-bias \
105+
--n-epochs 3 \
106+
--batch-size 4096 \
107+
--lr 3e-4 \
108+
--log-interval 50 \
109+
--no-wandb \
110+
--output-dir "$OUTPUT_DIR" \
111+
--checkpoint-dir "${OUTPUT_DIR}/checkpoints" \
112+
--checkpoint-steps 999999
113+
114+
echo "============================================================"
115+
echo "DONE: SAE checkpoint at ${OUTPUT_DIR}/checkpoints/checkpoint_final.pt"
116+
echo "============================================================"
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: LicenseRef-Apache2
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Chunk a FASTA into <=N-bp windows so predict_evo2 stays inside the model's trained context.
17+
18+
Evo2 1B was trained with seq_length=8192; longer inputs OOM in the Hyena
19+
fftconv path (intermediates scale super-linearly with L). For 7B/40B raise
20+
--window to whatever those checkpoints were context-extended to.
21+
22+
Non-overlapping windows by default. Each chunk gets a header of the form
23+
">{orig_id}:{start}-{end}" so downstream parquet can be back-mapped.
24+
"""
25+
26+
import argparse
27+
import gzip
28+
from pathlib import Path
29+
30+
31+
def parse_fasta(path: Path):
32+
"""Yield (seq_id, sequence) tuples from a FASTA file (transparently handles .gz)."""
33+
opener = gzip.open if path.suffix == ".gz" else open
34+
seq_id, parts = None, []
35+
with opener(path, "rt") as f:
36+
for line in f:
37+
line = line.rstrip()
38+
if line.startswith(">"):
39+
if seq_id is not None:
40+
yield seq_id, "".join(parts)
41+
seq_id = line[1:].split()[0]
42+
parts = []
43+
else:
44+
parts.append(line)
45+
if seq_id is not None:
46+
yield seq_id, "".join(parts)
47+
48+
49+
def main():
50+
"""Read input FASTA, write non-overlapping <=window-bp chunks to output FASTA."""
51+
p = argparse.ArgumentParser()
52+
p.add_argument("--input", type=Path, required=True)
53+
p.add_argument("--output", type=Path, required=True)
54+
p.add_argument("--window", type=int, default=8192)
55+
args = p.parse_args()
56+
57+
n_in = n_out = bp_out = 0
58+
args.output.parent.mkdir(parents=True, exist_ok=True)
59+
with open(args.output, "w") as out:
60+
for seq_id, seq in parse_fasta(args.input):
61+
n_in += 1
62+
for start in range(0, len(seq), args.window):
63+
end = min(start + args.window, len(seq))
64+
chunk = seq[start:end]
65+
out.write(f">{seq_id}:{start}-{end}\n{chunk}\n")
66+
n_out += 1
67+
bp_out += len(chunk)
68+
69+
print(f"Chunked {n_in} sequences -> {n_out} chunks ({bp_out:,} bp) at window={args.window}")
70+
71+
72+
if __name__ == "__main__":
73+
main()
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: LicenseRef-Apache2
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Convert predict_evo2 .pt outputs to SAE ActivationStore parquet shards.
17+
18+
predict_evo2 with --embedding-layer writes dicts of:
19+
hidden_embeddings: [B, S, H] (bf16)
20+
pad_mask: [B, S] (1 = valid token, 0 = padding)
21+
seq_idx, tokens: metadata, ignored here
22+
23+
We read each file, mask out padding, flatten to [N_tokens, H], and append
24+
to an ActivationStore so train.py's load_activations() can consume it.
25+
"""
26+
27+
import argparse
28+
import json
29+
from pathlib import Path
30+
31+
import torch
32+
from sae.activation_store import ActivationStore, ActivationStoreConfig
33+
from tqdm import tqdm
34+
35+
36+
def main():
37+
"""Walk predict_evo2 .pt files, mask padding, and write to an ActivationStore."""
38+
p = argparse.ArgumentParser()
39+
p.add_argument("--predict-dir", type=Path, required=True, help="Dir containing predictions__*.pt")
40+
p.add_argument("--output", type=Path, required=True, help="ActivationStore output dir")
41+
p.add_argument("--model-name", type=str, required=True, help="Stamped into metadata.json")
42+
p.add_argument("--layer", type=int, required=True, help="Stamped into metadata.json")
43+
p.add_argument("--shard-size", type=int, default=100_000)
44+
args = p.parse_args()
45+
46+
pt_files = sorted(args.predict_dir.rglob("predictions__*.pt"))
47+
if not pt_files:
48+
raise FileNotFoundError(f"No predictions__*.pt under {args.predict_dir}")
49+
50+
store = ActivationStore(args.output, ActivationStoreConfig(shard_size=args.shard_size))
51+
n_sequences = 0
52+
for pt in tqdm(pt_files, desc="pt->parquet"):
53+
d = torch.load(pt, map_location="cpu", weights_only=False)
54+
hidden = d["hidden_embeddings"]
55+
mask = d["pad_mask"].bool()
56+
flat = hidden[mask].float()
57+
store.append(flat)
58+
n_sequences += hidden.shape[0]
59+
60+
store.finalize(metadata={"model_name": args.model_name, "layer": args.layer, "n_sequences": n_sequences})
61+
print(json.dumps(store.metadata, indent=2))
62+
63+
64+
if __name__ == "__main__":
65+
main()

0 commit comments

Comments
 (0)