Commit e407165

evo2 SAE recipe: streaming extract + train (7B layer-26) (NVIDIA-BioNeMo#1621)
## What

Adds the **Evo2 SAE recipe** under `recipes/evo2/`, the runnable pipeline that turns Evo2 activations into a trained SAE:

```
chunk FASTA → stream-extract (extract.py) → train (train.py)
```

`scripts/7b.sh` runs all three and reproduces our **layer-26 7B (`normalize_input`)** run.

## Contents

- **`extract.py`**: streaming activation extractor that reuses `predict_evo2` for the Megatron forward but writes layer-L activations **directly to a parquet `ActivationStore`** (no intermediate `.pt`). Model-agnostic (1B/7B/40B).
- **`train.py`**: trains a TopK/ReLU SAE from the activation cache. Never loads the model (reads only the cache; `--model-path`/`--layer` validate metadata only).
- **`chunk_fasta.py`**, **`7b.sh`** (orchestrator), **`pyproject.toml`**.

_(README/usage docs deferred to a follow-up, once the pipeline is merged and exercised end-to-end.)_

## Opt-in training fixes (from the merged `sae` PR NVIDIA-BioNeMo#1619)

All default to the **previous behavior**; omit them to reproduce a baseline run exactly:

| flag | effect |
|---|---|
| `--aggregate-loss` | batch-level FVU + AuxK loss vs the per-token ratio |
| `--dead-count-global` | count dead-latent inactivity in total tokens (× world_size) under DDP |
| `--mix-shards N` | shuffle + blend N shards/batch (replaces the old `--shards-per-buffer`) |
| `--presample-shards N` | spread the pre-bias-init sample across N shards |

The DDP per-epoch batch cap is computed from each rank's assigned shards + `all_reduce(MIN)`, so it stays correct when `mix_shards>1` shuffles the shard list.

## Model format

Assumes a local **Evo2 MBridge checkpoint** (`--ckpt-dir`, loaded by `predict_evo2`). Getting it (via **NGC pull** or nemo2→MBridge convert) is a documented **prerequisite**, not recipe code (no Savanna conversion baked in). (Contrast CodonFM, whose Encodon model is an HF/TransformerEngine `.safetensors`: a different model family + runtime, hence the different extractor.)
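The DDP per-epoch batch cap described under "Opt-in training fixes" can be illustrated with a small single-process simulation. This is a sketch under stated assumptions, not the `sae` trainer's code: the helper names `assign_shards` and `per_epoch_batch_cap`, the round-robin shard assignment, and the shard row counts are all hypothetical, and a plain `min()` stands in for `all_reduce(MIN)`.

```python
import random


def assign_shards(shard_ids, world_size, rank):
    """Round-robin shard assignment (an assumed policy, for illustration only)."""
    return shard_ids[rank::world_size]


def per_epoch_batch_cap(shard_rows, world_size, batch_size, seed=0):
    """Return (cap, per_rank_batches) for one epoch.

    shard_rows maps a shard id to the number of activation rows it holds.
    The cap is the number of batches every rank can run before any rank
    runs out of data.
    """
    shard_ids = sorted(shard_rows)
    # Shuffling the shard list (as mix_shards > 1 would) changes which
    # shards each rank receives, so the count is derived per rank.
    random.Random(seed).shuffle(shard_ids)

    per_rank_batches = []
    for rank in range(world_size):
        mine = assign_shards(shard_ids, world_size, rank)
        rows = sum(shard_rows[s] for s in mine)
        per_rank_batches.append(rows // batch_size)

    # In real DDP each rank only knows its own count; the minimum would come
    # from torch.distributed.all_reduce(..., op=ReduceOp.MIN).
    return min(per_rank_batches), per_rank_batches


if __name__ == "__main__":
    rows = {0: 5000, 1: 4096, 2: 3000, 3: 8192, 4: 1024, 5: 7000}
    cap, per_rank = per_epoch_batch_cap(rows, world_size=2, batch_size=1024)
    print(f"per-rank batches: {per_rank}, cap: {cap}")
```

Taking the minimum means every rank issues the same number of optimizer steps, which avoids a rank with less data leaving the others blocked in a collective.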
## Duplicate code & planned consolidation

This recipe deliberately **mirrors the existing CodonFM recipe** rather than deduplicating now:

- **`train.py` is essentially the same file as `codonfm/scripts/train.py`**: byte-identical before the Evo2 skin + the four opt-in flags. Both are thin wrappers over the shared `sae.training.Trainer`.
- **`extract.py` shares its skeleton with `codonfm/scripts/extract.py`**: both run a model forward and stream into a `sae.ActivationStore`, including a near-identical **per-rank shard merge** (`_merge_temp_stores` ≈ CodonFM's `_merge_rank_stores`). Only the model-loading differs (Evo2 via Megatron `predict_evo2` vs CodonFM's HF/TE loader).

Consolidating either means factoring the shared **train-CLI** and the **parquet-shard merge** into the `sae` package **and migrating CodonFM onto them**, which is a cross-recipe refactor touching another recipe. To keep this PR single-concern (and avoid re-tangling, which this PR stack exists to undo), that dedup is a **planned follow-up** (alongside the duplicated-dashboard-component dedup), not part of this PR.

## Tests

Following the existing recipes' convention (CodonFM, ESM2 ship **no recipe-level tests**), the tested logic lives in the **`sae` package**; e.g. the opt-in flags' config round-trip is covered by `sae/tests/test_topk.py` (NVIDIA-BioNeMo#1619). This recipe is a thin driver over that tested package. The streaming `extract.py` and the DDP shard-count path need a multi-GPU run to verify.

## Supersedes

Replaces NVIDIA-BioNeMo#1579 (recipe) and NVIDIA-BioNeMo#1583 (tangled extract+recipe), which also smuggled the now-merged `evo2_megatron` (NVIDIA-BioNeMo#1618) and `sae` (NVIDIA-BioNeMo#1619) changes. Closed both in favor of this clean, single-concern recipe carved off `main`.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Signed-off-by: Polina Binder <pbinder@nvidia.com>
Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>

1 parent 8e5a865 commit e407165

5 files changed

Lines changed: 854 additions & 0 deletions

pyproject.toml
Lines changed: 25 additions & 0 deletions
@@ -0,0 +1,25 @@
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"

[project]
name = "evo2-sae"
version = "0.1.0"
description = "Sparse Autoencoders for the Evo2 DNA language model"
requires-python = ">=3.10"

dependencies = [
    "sae",
    "torch>=2.0",
    "numpy>=1.20",
    "pyarrow>=23.0.0",
]

# No package code lives here yet; the recipe is just an entry-point for
# scripts/ that depends on the shared `sae` workspace package. Declare no
# packages so setuptools doesn't try to discover anything.
[tool.setuptools]
packages = []

[tool.uv.sources]
sae = { workspace = true }

scripts/7b.sh
Lines changed: 95 additions & 0 deletions
@@ -0,0 +1,95 @@
#!/bin/bash
# Evo2 7B layer-26 SAE recipe: chunk FASTA -> stream-extract activations -> train SAE.
# This reproduces the layer26_7B (normalize_input) run.
#
# Prerequisites (this recipe does NOT download or convert the model):
#   - An Evo2 7B *MBridge* checkpoint directory (CKPT_DIR). Obtain it from NGC, e.g.:
#       ngc registry model download-version "nvidia/clara/evo2:7b_<ver>" --dest "${WORK_ROOT}/checkpoints"
#     (or convert a nemo2 checkpoint to MBridge with the evo2_megatron converter).
#   - bionemo-recipes/recipes/evo2_megatron built (.ci_build.sh) with its .venv active,
#     providing `predict_evo2`.
#   - The `sae` workspace package importable in that same venv.
#
# Override any of these by exporting before invocation.

set -euo pipefail

EVO2_MEGATRON_DIR="${EVO2_MEGATRON_DIR:-/workspace/bionemo-framework/bionemo-recipes/recipes/evo2_megatron}"
RECIPE_DIR="$(cd "$(dirname "$0")/.." && pwd)"

LAYER="${LAYER:-26}"
# Context length the activations were extracted at (the model is context-extended; we
# trained the SAE on 8192-bp chunks).
CHUNK_BP="${CHUNK_BP:-8192}"

# An Evo2 7B MBridge checkpoint directory (see prerequisites above).
CKPT_DIR="${CKPT_DIR:?Set CKPT_DIR to an Evo2 7B MBridge checkpoint directory (see header)}"
FASTA="${FASTA:?Set FASTA to the (prok+euk) input sequences}"
WORK_ROOT="${WORK_ROOT:-/data/interp/evo2}"

NPROC="${NPROC:-8}"  # GPUs / DP ranks
MAX_TOKENS="${MAX_TOKENS:-1000000000}"

PARQUET_DIR="${WORK_ROOT}/activations/evo2_7b_layer${LAYER}_parquet"
OUTPUT_DIR="${WORK_ROOT}/sae/evo2_7b_layer${LAYER}"

source "${EVO2_MEGATRON_DIR}/.venv/bin/activate"

echo "============================================================"
echo "STEP 0: Chunk FASTA to <=${CHUNK_BP} bp"
echo "============================================================"
INPUT_STEM="$(basename "$FASTA")"; INPUT_STEM="${INPUT_STEM%.gz}"; INPUT_STEM="${INPUT_STEM%.fasta}"
CHUNKED_FASTA="${WORK_ROOT}/scratch/${INPUT_STEM}_chunked${CHUNK_BP}.fasta"
if [[ -f "$CHUNKED_FASTA" ]]; then
    echo "Reusing existing chunked FASTA: $CHUNKED_FASTA"
else
    python "${RECIPE_DIR}/scripts/chunk_fasta.py" --input "$FASTA" --output "$CHUNKED_FASTA" --window "$CHUNK_BP"
fi

echo "============================================================"
echo "STEP 1: Stream-extract layer-${LAYER} activations -> parquet ActivationStore (no .pt)"
echo "============================================================"
if [[ -f "${PARQUET_DIR}/metadata.json" ]]; then
    echo "Reusing existing parquet shards at $PARQUET_DIR"
else
    torchrun --nproc_per_node="$NPROC" "${RECIPE_DIR}/scripts/extract.py" \
        --ckpt-dir "$CKPT_DIR" \
        --embedding-layer "$LAYER" \
        --fasta "$CHUNKED_FASTA" \
        --activation-store-dir "$PARQUET_DIR" \
        --max-tokens "$MAX_TOKENS" \
        --micro-batch-size 4 \
        --dtype fp32
fi

echo "============================================================"
echo "STEP 2: Train TopK SAE (layer26_7B normalize_input config)"
echo "============================================================"
# unset a leaked key so ~/.netrc wins; clara-discovery is the wandb entity.
unset WANDB_API_KEY || true
export WANDB_ENTITY="${WANDB_ENTITY:-clara-discovery}"
torchrun --nproc_per_node="$NPROC" "${RECIPE_DIR}/scripts/train.py" \
    --cache-dir "$PARQUET_DIR" \
    --model-path "$CKPT_DIR" \
    --layer "$LAYER" \
    --model-type topk \
    --expansion-factor 16 --top-k 128 \
    --normalize-input \
    --auxk 2048 --auxk-coef 0.03125 \
    --dead-tokens-threshold 10000000 \
    --init-pre-bias \
    --n-epochs 1 \
    --batch-size 1024 \
    --lr 1e-4 --lr-schedule cosine --lr-min 1e-5 --warmup-steps 1000 \
    --max-grad-norm 1.0 \
    --mix-shards 10 \
    --dp-size "$NPROC" \
    --log-interval 100 \
    --wandb --wandb-project evo2-sae-v2-diverse --wandb-run-name "layer${LAYER}_7B_normalize_input" \
    --output-dir "$OUTPUT_DIR" \
    --checkpoint-dir "${OUTPUT_DIR}/checkpoints" \
    --checkpoint-steps 2000

echo "============================================================"
echo "DONE: SAE checkpoint at ${OUTPUT_DIR}/checkpoints/checkpoint_final.pt"
echo "============================================================"

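STEP 2 passes `--lr 1e-4 --lr-schedule cosine --lr-min 1e-5 --warmup-steps 1000`. Assuming these flags mean a linear warmup followed by cosine decay down to `lr_min` (a common convention; the `sae` trainer's actual schedule is not shown in this commit and may differ), the learning rate at a given step could be sketched as below. The function name `lr_at_step` and the `total_steps` parameter are hypothetical; the script does not state the total step count.

```python
import math


def lr_at_step(step, total_steps, lr=1e-4, lr_min=1e-5, warmup_steps=1000):
    """Linear warmup to `lr`, then cosine decay to `lr_min` (assumed shape)."""
    if step < warmup_steps:
        # Ramp linearly from lr / warmup_steps up to lr.
        return lr * (step + 1) / warmup_steps
    # Fraction of the post-warmup phase that has elapsed, clamped to [0, 1].
    decay_steps = max(1, total_steps - warmup_steps)
    progress = min(1.0, (step - warmup_steps) / decay_steps)
    return lr_min + 0.5 * (lr - lr_min) * (1.0 + math.cos(math.pi * progress))


if __name__ == "__main__":
    total = 20_000  # hypothetical run length, for illustration only
    for s in (0, 500, 999, 1000, 10_500, 20_000):
        print(f"step {s:>6}: lr = {lr_at_step(s, total):.3e}")
```

Under these assumptions the rate peaks at the end of warmup and reaches `lr_min` only at the final step, so a run stopped early ends at a higher learning rate than `--lr-min`.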
scripts/chunk_fasta.py
Lines changed: 77 additions & 0 deletions
@@ -0,0 +1,77 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-Apache2
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Chunk a FASTA into <=N-bp windows so predict_evo2 stays inside the model's trained context.

Evo2 1B was trained with seq_length=8192; longer inputs OOM in the Hyena
fftconv path (intermediates scale super-linearly with L). For 7B/40B raise
--window to whatever those checkpoints were context-extended to.

Non-overlapping windows by default. Each chunk gets a header of the form
">{orig_id}:{start}-{end}" so downstream parquet can be back-mapped.
"""

import argparse
import gzip
from pathlib import Path


def parse_fasta(path: Path):
    """Yield (seq_id, sequence) tuples from a FASTA file (transparently handles .gz)."""
    opener = gzip.open if path.suffix == ".gz" else open
    seq_id, parts = None, []
    with opener(path, "rt") as f:
        for line in f:
            line = line.rstrip()
            if line.startswith(">"):
                if seq_id is not None:
                    yield seq_id, "".join(parts)
                seq_id = line[1:].split()[0]
                parts = []
            else:
                parts.append(line)
    if seq_id is not None:
        yield seq_id, "".join(parts)


def main():
    """Read input FASTA, write non-overlapping <=window-bp chunks to output FASTA."""
    p = argparse.ArgumentParser()
    p.add_argument("--input", type=Path, required=True)
    p.add_argument("--output", type=Path, required=True)
    p.add_argument("--window", type=int, default=8192)
    args = p.parse_args()
    if args.window <= 0:
        p.error("--window must be a positive integer")
    if args.input.resolve() == args.output.resolve():
        p.error("--input and --output must be different files")

    n_in = n_out = bp_out = 0
    args.output.parent.mkdir(parents=True, exist_ok=True)
    with open(args.output, "w") as out:
        for seq_id, seq in parse_fasta(args.input):
            n_in += 1
            for start in range(0, len(seq), args.window):
                end = min(start + args.window, len(seq))
                chunk = seq[start:end]
                out.write(f">{seq_id}:{start}-{end}\n{chunk}\n")
                n_out += 1
                bp_out += len(chunk)

    print(f"Chunked {n_in} sequences -> {n_out} chunks ({bp_out:,} bp) at window={args.window}")


if __name__ == "__main__":
    main()
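
Each chunk header has the form `>{orig_id}:{start}-{end}`, so a downstream consumer can recover the source coordinates. A minimal sketch of that back-mapping, where `parse_chunk_header` is a hypothetical helper (not part of this commit): it splits on the last colon so that sequence IDs which themselves contain colons still parse, and it treats `start`/`end` as the 0-based, end-exclusive offsets that `main()` writes.

```python
def parse_chunk_header(header):
    """Split '>{orig_id}:{start}-{end}' into (orig_id, start, end)."""
    header = header.strip().lstrip(">")
    # rpartition splits on the LAST colon, keeping any colons in orig_id.
    orig_id, sep, span = header.rpartition(":")
    if not sep:
        raise ValueError(f"not a chunk header: {header!r}")
    start_s, dash, end_s = span.partition("-")
    if not dash:
        raise ValueError(f"missing start-end span: {header!r}")
    start, end = int(start_s), int(end_s)
    if not 0 <= start < end:
        raise ValueError(f"invalid span {start}-{end} in {header!r}")
    return orig_id, start, end


if __name__ == "__main__":
    for h in (">chr1:0-8192", ">chr1:8192-10000", ">lcl|a:b:16384-20000"):
        print(parse_chunk_header(h))
```

Because the offsets are end-exclusive, `end - start` is the chunk length in bp, and position `i` within a chunk corresponds to position `start + i` in the original sequence.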
