Commit 6aa5103

sae: config-gated dead-latent/FVU training fixes + inter-shard shuffle & pre-bias init shard sampling (NVIDIA-BioNeMo#1619)
## Why

Training TopK SAEs on Evo2 activations hit a severe **dead-latent** problem (a large fraction of features never fired, wasting capacity). `normalize_input` (already merged) fixed most of it; this PR adds the remaining **training-dynamics fixes we found necessary for Evo2 SAE training**.

**Every change defaults to the previous behavior and is opt-in**, so you can reproduce or continue prior training runs **exactly as before**, and enable each fix only when you want it. The training recipe opts in; both `topk` options serialize in the checkpoint config so a reloaded SAE keeps its behavior.

## Changes: default = previous behavior, opt in per flag

**1. Dead-latent inactivity counted in *total* tokens**: `dead_count_global` (default `False` = previous per-rank count)

The auxk revival fires once a latent has been inactive for `dead_tokens_threshold` (10M) tokens, but the counter advanced by *this rank's* micro-batch, so under DDP it ran `world_size`× too slow and revival kicked in `world_size`× too late (≈80M effective tokens on 8 GPUs). Opt in with `dead_count_global=True` to count total tokens (× world_size); the `all_reduce(MIN)` still means "fired on any rank ⇒ reset."

**2. Aggregate FVU + auxk loss**: `aggregate_loss` (bool, default `False` = previous per-token)

The per-token loss ratio `mean_t(mse_t / var_t)` down-weights rare high-variance tokens, starving the latents that specialize on them (notably Evo2's heavy-tailed **sink tokens**) → they die. Opt in with `aggregate_loss=True` for a batch-level ratio (which also matches the reported `var_exp` metric). This single bool also fixes the **auxk residual** end-to-end: `False` keeps the previous `x - recon + pre_bias`; `True` uses the corrected `x - recon` (the true error, not `pre_bias`-dominated).

**3. Shuffle + blend shards**: `mix_shards` (int, default `1` = previous)

Shards are written in corpus order (all prokaryota, then all eukaryota). A contiguous per-rank slice trains a rank on one kingdom then switches mid-epoch → a visible **FVU cliff**. `mix_shards=1` (default) = previous behavior (one shard at a time, contiguous slice). Set `mix_shards=N>1` to **globally shuffle the shard list** before the per-rank split (so each rank gets a cross-section) **and** buffer/blend N shards per batch (≈N shards of peak RAM).

**4. Spread the pre-bias-init sample**: `sample(num_shards=…)` (default `1` = previous single shard)

`pre_bias` is initialized to the geometric median of a sample of activations (so the SAE starts centered). A single-shard sample biases it toward whatever is first in corpus order (one kingdom) → mis-centered init → more dead latents. Set `num_shards>1` to draw the sample across that many random shards spanning the store (≈one shard of peak RAM, since each shard is sub-sampled then freed).

## How to opt in (what the Evo2 recipe sets)

```python
TopKSAE(..., aggregate_loss=True, dead_count_global=True)
store.get_streaming_dataloader(..., mix_shards=8)  # shuffle + blend 8 shards
pre_bias0 = geometric_median(store.sample(n, num_shards=8))  # sample across 8 shards
```

The training recipe (separate PR) exposes these as CLI flags (`--aggregate-loss`, `--dead-count-global`, `--mix-shards`, `--presample-shards`).

## Opt-out summary

| behavior | knob | default | opt in |
|---|---|---|---|
| global dead-token count | `dead_count_global` (bool) | `False` | `True` |
| aggregate FVU + auxk loss | `aggregate_loss` (bool) | `False` | `True` |
| shard shuffle + blending | `mix_shards` (int) | `1` | `>1` |
| spread pre-init sample | `sample(num_shards=)` | `1` | `>1` |

## Tests: `sae/tests/test_topk.py` (CPU, no GPU)

Global-vs-local dead-token counting, the aggregate-FVU formula (`mse.mean()/var.mean()`), and that the opted-in flags round-trip through `_get_config()`.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Signed-off-by: Polina Binder <pbinder@nvidia.com>
Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
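
The "≈80M effective tokens on 8 GPUs" figure in change 1 can be checked with a little arithmetic. The sketch below is illustrative only: the per-rank micro-batch size of 4096 and the `>=` comparison against the threshold are assumptions, and `steps_until_dead` is a hypothetical helper, not part of the `sae` package.

```python
def steps_until_dead(threshold: int, micro_batch: int, world_size: int, count_global: bool) -> int:
    """Steps of continuous inactivity before the counter reaches `threshold`.

    Assumes a latent is flagged dead once its counter is >= threshold.
    """
    # Per-rank counting advances by the local micro-batch only;
    # global counting advances by micro_batch * world_size per step.
    increment = micro_batch * (world_size if count_global else 1)
    return -(-threshold // increment)  # integer ceil(threshold / increment)


THRESHOLD = 10_000_000  # dead_tokens_threshold (10M, from the commit message)
MICRO_BATCH = 4096      # assumed per-rank micro-batch size
WORLD_SIZE = 8

for count_global in (False, True):
    steps = steps_until_dead(THRESHOLD, MICRO_BATCH, WORLD_SIZE, count_global)
    # Tokens actually consumed by the whole job when revival first becomes possible.
    total_tokens = steps * MICRO_BATCH * WORLD_SIZE
    print(f"count_global={count_global}: {steps} steps, {total_tokens:,} total tokens")
```

Under these assumptions the per-rank counter needs 2442 steps (about 80M tokens across the job) while the global counter needs 306 steps (about 10M tokens), which is the `world_size`× gap the commit message describes.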
1 parent 78e9fa0 commit 6aa5103

3 files changed

Lines changed: 205 additions & 32 deletions

bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/activation_store.py

Lines changed: 75 additions & 12 deletions
@@ -332,18 +332,25 @@ def get_streaming_dataloader(
         rank: int = 0,
         world_size: int = 1,
         max_shards: Optional[int] = None,
+        mix_shards: int = 1,
     ) -> DataLoader:
         """Get a streaming DataLoader that reads one shard at a time from disk.

-        Each rank gets a disjoint slice of shards. Peak RAM per rank is ~1 shard.
+        Each rank gets a disjoint slice of shards. Peak RAM per rank is ~mix_shards shards.

         Args:
             batch_size: Batch size for training
-            shuffle: Whether to shuffle shard order and within-shard data
+            shuffle: Whether to shuffle within-shard (and within-buffer) data
             seed: Random seed for reproducibility
             rank: This rank's index (0-indexed)
             world_size: Total number of ranks
             max_shards: Limit total shards used (for subsampling). None = all.
+            mix_shards: How many shards to blend together. 1 (default) = previous behavior
+                (one shard at a time, contiguous per-rank slice, no global shuffle). >1
+                globally shuffles the shard list before the per-rank split (so each rank gets
+                a cross-section — found needed for Evo2 training, where shards are kingdom-
+                ordered, to avoid an fvu cliff) AND buffers/mixes that many shards per batch
+                (at ~mix_shards shards of peak RAM).

         Returns:
             DataLoader yielding [batch_size, hidden_dim] tensors
@@ -357,8 +364,16 @@ def get_streaming_dataloader(
         if n_total > 1 and pq.read_metadata(last_shard_path).num_rows < shard_size:
             n_total -= 1

-        # Assign equal shards to each rank (drop remainder to keep DDP in sync)
+        # When mixing (mix_shards > 1), shuffle the shard list BEFORE splitting across ranks
+        # so each rank gets a random cross-section of the whole parquet, not a contiguous
+        # slice. Found needed for Evo2 training, where shards are sequence-ordered (e.g. all
+        # prok then all euk): a contiguous per-rank slice trains a rank on one kingdom then
+        # switches, causing an fvu cliff. Deterministic across ranks via the shared seed.
+        # mix_shards == 1 keeps the previous contiguous behavior. Then assign equal shards per
+        # rank (drop remainder to keep DDP in sync).
         all_indices = list(range(n_total))
+        if mix_shards > 1:
+            np.random.default_rng(seed if seed is not None else 0).shuffle(all_indices)
         per_rank = n_total // world_size
         my_indices = all_indices[rank * per_rank : (rank + 1) * per_rank]
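
The shuffle-then-split logic in this hunk can be sketched in isolation. This is a minimal stand-in, assuming the stdlib `random.Random` in place of `np.random.default_rng`; `assign_shards` is a hypothetical name used only for illustration.

```python
import random
from typing import List


def assign_shards(n_total: int, rank: int, world_size: int, mix_shards: int = 1, seed: int = 0) -> List[int]:
    """Return the shard indices one rank would read.

    Every rank runs the same seeded shuffle, so all ranks agree on the
    permutation and their slices stay disjoint.
    """
    indices = list(range(n_total))
    if mix_shards > 1:
        random.Random(seed).shuffle(indices)  # stand-in for the NumPy shuffle in the diff
    per_rank = n_total // world_size          # remainder shards are dropped
    return indices[rank * per_rank : (rank + 1) * per_rank]


# Eight corpus-ordered shards: 0-3 from one kingdom, 4-7 from the other (hypothetical layout).
for mix in (1, 4):
    slices = [assign_shards(8, rank, world_size=2, mix_shards=mix) for rank in range(2)]
    print(f"mix_shards={mix}: rank0={slices[0]} rank1={slices[1]}")
```

With `mix_shards=1` rank 0 always receives the first contiguous block, which is the single-kingdom slice the comment warns about; with `mix_shards>1` each rank's slice is drawn from the shared permutation instead.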
@@ -368,11 +383,46 @@ def get_streaming_dataloader(
             batch_size=batch_size,
             shuffle=shuffle,
             seed=seed,
+            mix_shards=mix_shards,
         )

         # batch_size=None: dataset already yields pre-formed batches
         return DataLoader(dataset, batch_size=None, num_workers=0)

+    def sample(self, n: int, seed: int = 0, num_shards: int = 1) -> torch.Tensor:
+        """Return ~n activation rows for pre-bias (geometric-median) init.
+
+        Defaults to a single shard, i.e. the previous behavior. Set ``num_shards`` > 1 to
+        draw from that many random shards spanning the whole parquet: we found this needed
+        for Evo2 SAE training, where shards are written in corpus order (e.g. all prokaryota
+        then all eukaryota), so a single-shard sample biases the geometric-median pre-bias
+        toward one kingdom and worsens dead latents. Peak RAM ~one shard (each is sub-sampled
+        then freed before the next).
+
+        Args:
+            n: Number of activation rows to return.
+            seed: RNG seed for the shard/row sampling and the final permutation; sampling is
+                deterministic given this seed.
+            num_shards: Number of shards to sample across, clamped to ``[1, self.n_shards]``.
+                1 (default) = previous single-shard behavior; >1 spreads the sample.
+
+        Returns:
+            A float ``torch.Tensor`` of shape ``(n, D)`` of sampled pre-bias activation rows
+            (``torch.from_numpy`` on concatenated per-shard slices loaded via
+            ``self._load_shard``), deterministic for the given ``seed``.
+        """
+        rng = np.random.default_rng(seed)
+        k = min(self.n_shards, max(1, num_shards))
+        chosen = rng.choice(self.n_shards, size=k, replace=False)
+        per = -(-n // k)  # ceil(n / k)
+        parts = []
+        for i in chosen:
+            shard = self._load_shard(int(i))
+            take = min(per, len(shard))
+            parts.append(shard[rng.choice(len(shard), size=take, replace=False)])
+        rows = torch.from_numpy(np.concatenate(parts)).float()
+        return rows[torch.randperm(len(rows), generator=torch.Generator().manual_seed(seed))][:n]
+
     def get_dataloader(
         self,
         batch_size: int = 4096,
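
The sampling scheme in `sample()` (pick `k` shards, take `ceil(n / k)` rows from each, permute, truncate to `n`) can be illustrated without NumPy or torch. The sketch below is an approximation under stated assumptions: shards are plain Python lists, `sample_rows` is a hypothetical name, and the stdlib `random` module replaces both RNGs used in the diff.

```python
import random
from typing import List, Sequence, Tuple

Row = Tuple[int, int]  # (shard id, row id), a stand-in for an activation row


def sample_rows(shards: Sequence[List[Row]], n: int, seed: int = 0, num_shards: int = 1) -> List[Row]:
    rng = random.Random(seed)
    k = min(len(shards), max(1, num_shards))  # clamp to [1, number of shards]
    chosen = rng.sample(range(len(shards)), k)
    per = -(-n // k)                          # ceil(n / k) rows per chosen shard
    rows: List[Row] = []
    for i in chosen:
        shard = shards[i]
        take = min(per, len(shard))           # a short shard contributes fewer rows
        rows.extend(rng.sample(shard, take))
    rng.shuffle(rows)                         # final permutation before truncating
    return rows[:n]


shards = [[(s, r) for r in range(100)] for s in range(6)]
spread = sample_rows(shards, n=10, num_shards=4)
single = sample_rows(shards, n=10, num_shards=1)
print(sorted({s for s, _ in spread}), sorted({s for s, _ in single}))
```

Because of the `min(per, len(shard))` clamp, this sketch returns fewer than `n` rows when the chosen shards are too small to supply them, so callers that need exactly `n` rows may want to check the length of the result.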
@@ -491,12 +541,16 @@ def __init__(
         batch_size: int = 4096,
         shuffle: bool = True,
         seed: Optional[int] = None,
+        mix_shards: int = 1,
     ):
         self.store = store
         self.shard_indices = shard_indices
         self.batch_size = batch_size
         self.shuffle = shuffle
         self.seed = seed
+        # Shards to accumulate before flushing batches. >1 mixes rows across that
+        # many shards (true inter-shard shuffling) instead of one shard at a time.
+        self.mix_shards = max(1, mix_shards)
         self.max_batches = None  # Set externally to cap iteration (for DDP sync)

         # Approximate length: total tokens in assigned shards / batch_size
@@ -511,20 +565,29 @@ def __iter__(self) -> Iterator[torch.Tensor]:
             rng.shuffle(indices)

         buffer = None
+        shards_loaded = 0
         n_yielded = 0
-        for shard_idx in indices:
+        for shard_pos, shard_idx in enumerate(indices):
             shard = torch.from_numpy(self.store._load_shard(shard_idx)).float()
             if self.shuffle:
                 shard = shard[torch.randperm(len(shard))]
-
             buffer = torch.cat([buffer, shard]) if buffer is not None else shard
-
-            while len(buffer) >= self.batch_size:
-                if self.max_batches is not None and n_yielded >= self.max_batches:
-                    return
-                yield buffer[: self.batch_size]
-                buffer = buffer[self.batch_size :]
-                n_yielded += 1
+            shards_loaded += 1
+
+            # Flush only once mix_shards shards are buffered (or this is
+            # the last shard), shuffling the whole buffer first so each batch
+            # mixes rows from that many different parts of the parquet.
+            is_last = shard_pos == len(indices) - 1
+            if shards_loaded >= self.mix_shards or is_last:
+                if self.shuffle and self.mix_shards > 1:
+                    buffer = buffer[torch.randperm(len(buffer))]
+                while len(buffer) >= self.batch_size:
+                    if self.max_batches is not None and n_yielded >= self.max_batches:
+                        return
+                    yield buffer[: self.batch_size]
+                    buffer = buffer[self.batch_size :]
+                    n_yielded += 1
+                shards_loaded = 0

         # Yield remainder as a partial batch (skip if capped)
         if self.max_batches is None and buffer is not None and len(buffer) > 0:
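
The buffering behavior of `__iter__` (accumulate `mix_shards` shards, shuffle the buffer, flush full batches, carry the remainder forward) can be sketched as a plain generator. This is a simplified illustration, not the repository's implementation: it assumes list-based shards, always shuffles, omits the `max_batches` cap, and uses the hypothetical name `iter_batches`.

```python
import random
from typing import Iterator, List, Sequence, Tuple

Row = Tuple[str, int]  # (shard label, row id)


def iter_batches(shards: Sequence[List[Row]], batch_size: int, mix_shards: int = 1, seed: int = 0) -> Iterator[List[Row]]:
    rng = random.Random(seed)
    mix_shards = max(1, mix_shards)
    buffer: List[Row] = []
    shards_loaded = 0
    for pos, shard in enumerate(shards):
        rows = list(shard)
        rng.shuffle(rows)                 # within-shard shuffle
        buffer.extend(rows)
        shards_loaded += 1
        is_last = pos == len(shards) - 1
        if shards_loaded >= mix_shards or is_last:
            if mix_shards > 1:
                rng.shuffle(buffer)       # blend rows across the buffered shards
            while len(buffer) >= batch_size:
                yield buffer[:batch_size]
                buffer = buffer[batch_size:]
            shards_loaded = 0
    if buffer:
        yield buffer                      # partial final batch


shards = [[("a", i) for i in range(6)], [("b", i) for i in range(6)]]
unmixed = list(iter_batches(shards, batch_size=4, mix_shards=1))
mixed = list(iter_batches(shards, batch_size=4, mix_shards=2))
print([sorted({label for label, _ in batch}) for batch in unmixed])
print([sorted({label for label, _ in batch}) for batch in mixed])
```

With `mix_shards=1` the first batch can only contain rows from the first shard, whereas with `mix_shards=2` every batch is drawn from the shuffled two-shard buffer. Peak memory in the sketch grows with the buffer, mirroring the "~mix_shards shards of peak RAM" note in the docstring.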

bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/architectures/topk.py

Lines changed: 67 additions & 20 deletions
@@ -22,6 +22,7 @@
 from typing import Any, Dict, Optional, Tuple

 import torch
+import torch.distributed as dist
 import torch.nn as nn
 import torch.nn.functional as F

@@ -59,6 +60,14 @@ class TopKSAE(SparseAutoencoder):
         auxk: Number of auxiliary latents for dead latent loss (None = disabled)
         auxk_coef: Coefficient for auxiliary loss (default: 1/32)
         dead_tokens_threshold: Tokens of inactivity before latent is considered dead (default 10M per Gao et al.)
+        aggregate_loss: If False (default), reduce the FVU and AuxK losses per-token (the
+            previous mean-of-per-row ratios). If True, use a single batch-level
+            ``mse.mean() / var.mean()`` ratio, which stops rare high-variance tokens from
+            being down-weighted (and thus their latents dying).
+        dead_count_global: If True, accumulate dead-latent inactivity counts across all DDP
+            ranks (total tokens = micro-batch x world_size); if False (default), count this
+            rank's micro-batch only. True makes the dead-threshold / AuxK revival fire on time
+            under data parallelism.
         init_encoder_from_decoder: If True, initialize encoder weights as transpose
             of decoder weights. From OpenAI paper: this + AuxK → nearly 0% dead latents.
     """
@@ -72,6 +81,8 @@ def __init__(
         auxk: Optional[int] = None,
         auxk_coef: float = 1 / 32,
         dead_tokens_threshold: int = 10_000_000,
+        aggregate_loss: bool = False,
+        dead_count_global: bool = False,
         init_encoder_from_decoder: bool = True,
         init_pre_bias: bool = True,
         decoder_impl: str = "dense",
@@ -94,6 +105,12 @@ def __init__(
         if decoder_impl not in ("dense", "triton"):
             raise ValueError(f"decoder_impl must be 'dense' or 'triton', got {decoder_impl!r}")
         self.decoder_impl = decoder_impl
+        # False (default = previous per-token reduction) | True (batch-level aggregate FVU/auxk
+        # ratio; opt in to fix dead latents starved by the per-token ratio on rare high-var tokens).
+        self.aggregate_loss = aggregate_loss
+        # False (default = previous per-rank count) | True (count inactivity in TOTAL tokens,
+        # x world_size, so dead-latent revival fires on time under DDP; opt in).
+        self.dead_count_global = dead_count_global

         # Pre-bias (subtracted from normalized input, added to output before denorm)
         self.pre_bias = nn.Parameter(torch.zeros(input_dim))
@@ -125,6 +142,8 @@ def _get_config(self) -> Dict[str, Any]:
             "auxk": self.auxk,
             "auxk_coef": self.auxk_coef,
             "dead_tokens_threshold": self.dead_tokens_threshold,
+            "aggregate_loss": self.aggregate_loss,
+            "dead_count_global": self.dead_count_global,
         }

     def _init_encoder_from_decoder(self) -> None:
@@ -288,8 +307,17 @@ def _update_dead_latent_stats(self, codes: torch.Tensor) -> None:
         # Check which latents were active (any sample in batch had activation > threshold)
         active_mask = (codes.abs() > 1e-3).any(dim=0)  # [hidden_dim]

-        # Reset counter for active latents, increment by token count for inactive
-        n_tokens = codes.shape[0]
+        # dead_count_global=True increments by GLOBAL tokens, not this rank's micro-batch:
+        # each of the world_size ranks processes codes.shape[0] tokens per step, so the
+        # inactivity counter must advance by codes.shape[0] * world_size to match
+        # dead_tokens_threshold's intended units (total training tokens). The default
+        # (per-rank count) makes the threshold (and auxk revival) trigger world_size x too
+        # late under DDP. The trainer's all_reduce(MIN) preserves "fired on any rank => reset".
+        if self.dead_count_global and dist.is_available() and dist.is_initialized():
+            world_size = dist.get_world_size()
+        else:
+            world_size = 1
+        n_tokens = codes.shape[0] * world_size
         self.stats_last_nonzero = torch.where(
             active_mask, torch.zeros_like(self.stats_last_nonzero), self.stats_last_nonzero + n_tokens
         )
@@ -331,25 +359,38 @@ def _compute_auxk_loss(
         # Decode auxiliary latents using only dead decoder columns (avoids full-width matmul)
         recon_aux = F.linear(codes_aux, self.decoder.weight[:, dead_indices], self.decoder.bias)

-        # Target is the residual (what primary reconstruction missed)
-        # Work in normalized space for the aux loss
+        # Target is the residual (what primary reconstruction missed).
+        # The corrected residual is x - recon (the actual reconstruction error). The legacy
+        # non-normalized form `x - recon + pre_bias` simplifies to `x - decoder(codes)`, whose
+        # norm is dominated by ||pre_bias|| rather than the actual error, weakening the aux
+        # gradient by ~(||pre_bias|| / ||error||)^2. Gated on aggregate_loss so False
+        # reproduces the previous auxk loss end-to-end; True uses the fix.
         if self.normalize_input and norm_info is not None:
-            # Normalize x to match the space where encoding happened
+            # Normalize x to match the space where encoding happened (already correct in both modes)
             x_norm = (x - norm_info["mu"]) / norm_info["std"]
             # Reuse codes from forward pass instead of re-encoding (or a precomputed
             # normalized recon, e.g. from the sparse/triton decode path).
             if recon_norm is None:
                 recon_norm = self.decoder(codes) + self.pre_bias
             residual = x_norm - recon_norm.detach()
+        elif not self.aggregate_loss:
+            residual = x - recon.detach() + self.pre_bias.detach()  # legacy (previous behavior)
         else:
-            residual = x - recon.detach() + self.pre_bias.detach()
-
-        # Normalized MSE: MSE / variance of target
-        mse = (recon_aux - residual).pow(2).mean(dim=-1)  # [batch]
-        target_var = residual.pow(2).mean(dim=-1)  # [batch]
-
-        # Avoid division by zero, use nan_to_num like OpenAI
-        normalized_mse = (mse / (target_var + 1e-8)).mean()
+            residual = x - recon.detach()  # corrected: the true reconstruction error
+
+        # AuxK normalized MSE: how much of the residual the dead latents recover. Default
+        # (aggregate_loss=False) is the legacy per-token ratio (mse_t / target_var_t), which
+        # up-weights already-well-reconstructed (small residual) tokens and down-weights the
+        # big missed structure dead latents should grab — mis-targeting revival and letting
+        # dead latents persist. aggregate_loss=True aggregates over the whole batch instead.
+        if not self.aggregate_loss:
+            mse = (recon_aux - residual).pow(2).mean(dim=-1)
+            target_var = residual.pow(2).mean(dim=-1)
+            normalized_mse = (mse / (target_var + 1e-8)).mean()
+        else:
+            mse = (recon_aux - residual).pow(2).mean()
+            target_var = residual.pow(2).mean()
+            normalized_mse = mse / (target_var + 1e-8)

         return normalized_mse
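
A small numeric example may help show why the legacy residual is dominated by `pre_bias`. It assumes, as the comment in this hunk implies, that the non-normalized reconstruction is `recon = decoder(codes) + pre_bias`; all vectors below are made-up values chosen for illustration.

```python
from typing import List

Vec = List[float]


def add(a: Vec, b: Vec) -> Vec:
    return [p + q for p, q in zip(a, b)]


def sub(a: Vec, b: Vec) -> Vec:
    return [p - q for p, q in zip(a, b)]


def sq_norm(v: Vec) -> float:
    return sum(p * p for p in v)


pre_bias = [10.0, -10.0, 10.0]   # large learned offset (hypothetical values)
decoded = [0.4, 0.1, -0.3]       # decoder(codes), hypothetical
recon = add(decoded, pre_bias)   # assumed: recon = decoder(codes) + pre_bias
x = [10.5, -10.0, 9.5]

legacy = add(sub(x, recon), pre_bias)  # x - recon + pre_bias, equals x - decoder(codes)
corrected = sub(x, recon)              # x - recon, the actual reconstruction error

print("legacy    ||r||^2 =", round(sq_norm(legacy), 4))
print("corrected ||r||^2 =", round(sq_norm(corrected), 4))
```

In this example the legacy target has a squared norm of about 300 while the true error has about 0.06, so a loss normalized by the legacy target's variance is scaled down by several orders of magnitude relative to the corrected one.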
@@ -433,13 +474,19 @@ def loss(self, x: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
         # Update dead latent stats
         self._update_dead_latent_stats(codes)

-        # Primary reconstruction loss (FVU: fraction of variance unexplained)
-        # Center by pre_bias (learned per-dim mean) so denominator reflects
-        # actual signal variance, consistent with var_exp metric
-        mse = (recon - x).pow(2).mean(dim=-1)  # [batch]
-        x_centered = x - self.pre_bias
-        x_var = x_centered.pow(2).mean(dim=-1)  # [batch]
-        recon_loss = (mse / (x_var + 1e-8)).mean()
+        # Primary reconstruction loss (FVU: fraction of variance unexplained), centered by
+        # pre_bias to match the reported var_exp metric. Default (aggregate_loss=False) is the
+        # legacy per-token ratio mean_t(mse_t / x_var_t), which over-weights low-variance tokens
+        # and down-weights rare high-variance ones, starving the latents specialized on them.
+        # aggregate_loss=True uses a single batch-level mse.mean() / var.mean() ratio instead.
+        if not self.aggregate_loss:
+            mse = (recon - x).pow(2).mean(dim=-1)
+            x_var = (x - self.pre_bias).pow(2).mean(dim=-1)
+            recon_loss = (mse / (x_var + 1e-8)).mean()
+        else:
+            mse = (recon - x).pow(2).mean()
+            x_var = (x - self.pre_bias).pow(2).mean()
+            recon_loss = mse / (x_var + 1e-8)

         # Sparsity metric (for logging)
         l0 = (codes != 0).float().sum(dim=-1).mean()
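
The difference between the two reductions is easiest to see as a weight on each token's squared error. The sketch below works on already-computed per-token `mse_t` and `var_t` values (hypothetical numbers: three ordinary tokens and one high-variance "sink" token) and assumes every token has the same dimensionality, so the mean of per-token means equals the batch mean.

```python
from typing import List

EPS = 1e-8


def fvu_per_token(mse_t: List[float], var_t: List[float]) -> float:
    """Legacy reduction: mean over tokens of mse_t / var_t."""
    return sum(m / (v + EPS) for m, v in zip(mse_t, var_t)) / len(mse_t)


def fvu_aggregate(mse_t: List[float], var_t: List[float]) -> float:
    """Batch-level reduction: mean(mse) / mean(var)."""
    mean_mse = sum(mse_t) / len(mse_t)
    mean_var = sum(var_t) / len(var_t)
    return mean_mse / (mean_var + EPS)


mse_t = [0.1, 0.1, 0.1, 50.0]    # last entry: poorly reconstructed sink token
var_t = [1.0, 1.0, 1.0, 100.0]   # last entry: 100x the variance of the others
n = len(mse_t)

# Gradient weight each reduction places on a token's squared error (d loss / d mse_t).
per_token_weights = [1.0 / (n * (v + EPS)) for v in var_t]
aggregate_weight = 1.0 / (n * (sum(var_t) / n + EPS))

print("per-token FVU:", round(fvu_per_token(mse_t, var_t), 4))
print("aggregate FVU:", round(fvu_aggregate(mse_t, var_t), 4))
print("per-token weights:", [round(w, 4) for w in per_token_weights])
print("aggregate weight (same for every token):", round(aggregate_weight, 4))
```

In the per-token form the sink token's error receives roughly 1/100 of the weight given to an ordinary token, while the aggregate form weights every token's error equally. This matches the motivation in the code comment, though the numbers here are illustrative rather than measured.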
Lines changed: 63 additions & 0 deletions
@@ -0,0 +1,63 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-Apache2
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for TopKSAE training-quality options: loss reduction + global dead-latent counting."""
+
+import torch
+from sae.architectures import topk as topk_mod
+from sae.architectures.topk import TopKSAE
+
+
+def _make_sae(**kw):
+    torch.manual_seed(0)
+    return TopKSAE(input_dim=8, hidden_dim=16, top_k=4, normalize_input=False, **kw)
+
+
+def test_recon_loss_aggregate_matches_batch_fvu():
+    """aggregate_loss=True equals the batch-level FVU mse.mean()/var.mean()."""
+    x = torch.randn(32, 8)
+    sae = _make_sae(aggregate_loss=True)
+    recon = sae.forward_with_aux(x)["recon"]
+    expected = (recon - x).pow(2).mean() / ((x - sae.pre_bias).pow(2).mean() + 1e-8)
+    assert torch.allclose(sae.loss(x)["total"], expected)
+
+
+def test_dead_latent_count_global_vs_local(monkeypatch):
+    """dead_count_global advances the inactivity counter by tokens x world_size; else local."""
+    # Pretend we're in a 4-rank distributed run.
+    monkeypatch.setattr(topk_mod.dist, "is_available", lambda: True)
+    monkeypatch.setattr(topk_mod.dist, "is_initialized", lambda: True)
+    monkeypatch.setattr(topk_mod.dist, "get_world_size", lambda: 4)
+
+    codes = torch.zeros(10, 16)
+    codes[:, 0] = 1.0  # only latent 0 fires
+
+    g = _make_sae(dead_count_global=True)
+    g.stats_last_nonzero.zero_()
+    g._update_dead_latent_stats(codes)
+    assert int(g.stats_last_nonzero[0]) == 0  # fired -> reset
+    assert int(g.stats_last_nonzero[1]) == 10 * 4  # inactive -> tokens x world_size
+
+    loc = _make_sae(dead_count_global=False)
+    loc.stats_last_nonzero.zero_()
+    loc._update_dead_latent_stats(codes)
+    assert int(loc.stats_last_nonzero[1]) == 10  # inactive -> local micro-batch only
+
+
+def test_opted_in_options_round_trip_through_config():
+    """Opted-in (non-default) options serialize in the checkpoint config so a reload keeps them."""
+    cfg = _make_sae(aggregate_loss=True, dead_count_global=True)._get_config()
+    assert cfg["aggregate_loss"] is True
+    assert cfg["dead_count_global"] is True
