Commit 6aa5103
sae: config-gated dead-latent/FVU training fixes + inter-shard shuffle & pre-bias init shard sampling (NVIDIA-BioNeMo#1619)
## Why
Training TopK SAEs on Evo2 activations hit a severe **dead-latent**
problem (a large fraction of features never fired, wasting capacity).
`normalize_input` (already merged) fixed most of it; this PR adds the
remaining **training-dynamics fixes we found necessary for Evo2 SAE
training**.
**Every change defaults to the previous behavior and is opt-in** — so
you can reproduce or continue prior training runs **exactly as before**,
and enable each fix only when you want it. The training recipe opts in;
both `topk` options serialize in the checkpoint config so a reloaded SAE
keeps its behavior.
## Changes — default = previous behavior, opt in per flag
**1. Dead-latent inactivity counted in *total* tokens** —
`dead_count_global` (default `False` = previous per-rank count)
The auxk revival fires once a latent has been inactive for
`dead_tokens_threshold` (10M) tokens, but the counter advanced by *this
rank's* micro-batch — so under DDP it ran `world_size`× too slow and
revival kicked in `world_size`× too late (≈80M effective tokens on 8
GPUs). Opt in with `dead_count_global=True` to count total tokens (×
world_size); the `all_reduce(MIN)` still means "fired on any rank ⇒
reset."
**2. Aggregate FVU + auxk loss** — `aggregate_loss` (bool, default
`False` = previous per-token)
The per-token loss ratio `mean_t(mse_t / var_t)` down-weights rare
high-variance tokens, starving the latents that specialize on them
(notably Evo2's heavy-tailed **sink tokens**) → they die. Opt in with
`aggregate_loss=True` for a batch-level ratio (which also matches the
reported `var_exp` metric). This single bool also fixes the **auxk
residual** end-to-end: `False` keeps the previous `x - recon +
pre_bias`; `True` uses the corrected `x - recon` (the true error, not
`pre_bias`-dominated).
**3. Shuffle + blend shards** — `mix_shards` (int, default `1` =
previous)
Shards are written in corpus order (all prokaryota, then all eukaryota).
A contiguous per-rank slice trains a rank on one kingdom then switches
mid-epoch → a visible **FVU cliff**. `mix_shards=1` (default) = previous
behavior (one shard at a time, contiguous slice). Set `mix_shards=N>1`
to **globally shuffle the shard list** before the per-rank split (so
each rank gets a cross-section) **and** buffer/blend N shards per batch
(≈N shards of peak RAM).
**4. Spread the pre-bias-init sample** — `sample(num_shards=…)` (default
`1` = previous single shard)
`pre_bias` is initialized to the geometric median of a sample of
activations (so the SAE starts centered). A single-shard sample biases
it toward whatever is first in corpus order (one kingdom) → mis-centered
init → more dead latents. Set `num_shards>1` to draw the sample across
that many random shards spanning the store (≈one shard of peak RAM —
each sub-sampled then freed).
## How to opt in (what the Evo2 recipe sets)
```python
TopKSAE(..., aggregate_loss=True, dead_count_global=True)
store.get_streaming_dataloader(..., mix_shards=8) # shuffle + blend 8 shards
pre_bias0 = geometric_median(store.sample(n, num_shards=8)) # sample across 8 shards
```
The training recipe (separate PR) exposes these as CLI flags
(`--aggregate-loss`, `--dead-count-global`, `--mix-shards`,
`--presample-shards`).
## Opt-out summary
| behavior | knob | default | opt in |
|---|---|---|---|
| global dead-token count | `dead_count_global` (bool) | `False` |
`True` |
| aggregate FVU + auxk loss | `aggregate_loss` (bool) | `False` | `True`
|
| shard shuffle + blending | `mix_shards` (int) | `1` | `>1` |
| spread pre-init sample | `sample(num_shards=)` | `1` | `>1` |
## Tests — `sae/tests/test_topk.py` (CPU, no GPU)
global-vs-local dead-token counting, the aggregate-FVU formula
(`mse.mean()/var.mean()`), and that the opted-in flags round-trip
through `_get_config()`.
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Signed-off-by: Polina Binder <pbinder@nvidia.com>
Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>1 parent 78e9fa0 commit 6aa5103
3 files changed
Lines changed: 205 additions & 32 deletions
File tree
- bionemo-recipes/interpretability/sparse_autoencoders/sae
- src/sae
- architectures
- tests
Lines changed: 75 additions & 12 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
332 | 332 | | |
333 | 333 | | |
334 | 334 | | |
| 335 | + | |
335 | 336 | | |
336 | 337 | | |
337 | 338 | | |
338 | | - | |
| 339 | + | |
339 | 340 | | |
340 | 341 | | |
341 | 342 | | |
342 | | - | |
| 343 | + | |
343 | 344 | | |
344 | 345 | | |
345 | 346 | | |
346 | 347 | | |
| 348 | + | |
| 349 | + | |
| 350 | + | |
| 351 | + | |
| 352 | + | |
| 353 | + | |
347 | 354 | | |
348 | 355 | | |
349 | 356 | | |
| |||
357 | 364 | | |
358 | 365 | | |
359 | 366 | | |
360 | | - | |
| 367 | + | |
| 368 | + | |
| 369 | + | |
| 370 | + | |
| 371 | + | |
| 372 | + | |
| 373 | + | |
361 | 374 | | |
| 375 | + | |
| 376 | + | |
362 | 377 | | |
363 | 378 | | |
364 | 379 | | |
| |||
368 | 383 | | |
369 | 384 | | |
370 | 385 | | |
| 386 | + | |
371 | 387 | | |
372 | 388 | | |
373 | 389 | | |
374 | 390 | | |
375 | 391 | | |
| 392 | + | |
| 393 | + | |
| 394 | + | |
| 395 | + | |
| 396 | + | |
| 397 | + | |
| 398 | + | |
| 399 | + | |
| 400 | + | |
| 401 | + | |
| 402 | + | |
| 403 | + | |
| 404 | + | |
| 405 | + | |
| 406 | + | |
| 407 | + | |
| 408 | + | |
| 409 | + | |
| 410 | + | |
| 411 | + | |
| 412 | + | |
| 413 | + | |
| 414 | + | |
| 415 | + | |
| 416 | + | |
| 417 | + | |
| 418 | + | |
| 419 | + | |
| 420 | + | |
| 421 | + | |
| 422 | + | |
| 423 | + | |
| 424 | + | |
| 425 | + | |
376 | 426 | | |
377 | 427 | | |
378 | 428 | | |
| |||
491 | 541 | | |
492 | 542 | | |
493 | 543 | | |
| 544 | + | |
494 | 545 | | |
495 | 546 | | |
496 | 547 | | |
497 | 548 | | |
498 | 549 | | |
499 | 550 | | |
| 551 | + | |
| 552 | + | |
| 553 | + | |
500 | 554 | | |
501 | 555 | | |
502 | 556 | | |
| |||
511 | 565 | | |
512 | 566 | | |
513 | 567 | | |
| 568 | + | |
514 | 569 | | |
515 | | - | |
| 570 | + | |
516 | 571 | | |
517 | 572 | | |
518 | 573 | | |
519 | | - | |
520 | 574 | | |
521 | | - | |
522 | | - | |
523 | | - | |
524 | | - | |
525 | | - | |
526 | | - | |
527 | | - | |
| 575 | + | |
| 576 | + | |
| 577 | + | |
| 578 | + | |
| 579 | + | |
| 580 | + | |
| 581 | + | |
| 582 | + | |
| 583 | + | |
| 584 | + | |
| 585 | + | |
| 586 | + | |
| 587 | + | |
| 588 | + | |
| 589 | + | |
| 590 | + | |
528 | 591 | | |
529 | 592 | | |
530 | 593 | | |
| |||
Lines changed: 67 additions & 20 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
22 | 22 | | |
23 | 23 | | |
24 | 24 | | |
| 25 | + | |
25 | 26 | | |
26 | 27 | | |
27 | 28 | | |
| |||
59 | 60 | | |
60 | 61 | | |
61 | 62 | | |
| 63 | + | |
| 64 | + | |
| 65 | + | |
| 66 | + | |
| 67 | + | |
| 68 | + | |
| 69 | + | |
| 70 | + | |
62 | 71 | | |
63 | 72 | | |
64 | 73 | | |
| |||
72 | 81 | | |
73 | 82 | | |
74 | 83 | | |
| 84 | + | |
| 85 | + | |
75 | 86 | | |
76 | 87 | | |
77 | 88 | | |
| |||
94 | 105 | | |
95 | 106 | | |
96 | 107 | | |
| 108 | + | |
| 109 | + | |
| 110 | + | |
| 111 | + | |
| 112 | + | |
| 113 | + | |
97 | 114 | | |
98 | 115 | | |
99 | 116 | | |
| |||
125 | 142 | | |
126 | 143 | | |
127 | 144 | | |
| 145 | + | |
| 146 | + | |
128 | 147 | | |
129 | 148 | | |
130 | 149 | | |
| |||
288 | 307 | | |
289 | 308 | | |
290 | 309 | | |
291 | | - | |
292 | | - | |
| 310 | + | |
| 311 | + | |
| 312 | + | |
| 313 | + | |
| 314 | + | |
| 315 | + | |
| 316 | + | |
| 317 | + | |
| 318 | + | |
| 319 | + | |
| 320 | + | |
293 | 321 | | |
294 | 322 | | |
295 | 323 | | |
| |||
331 | 359 | | |
332 | 360 | | |
333 | 361 | | |
334 | | - | |
335 | | - | |
| 362 | + | |
| 363 | + | |
| 364 | + | |
| 365 | + | |
| 366 | + | |
| 367 | + | |
336 | 368 | | |
337 | | - | |
| 369 | + | |
338 | 370 | | |
339 | 371 | | |
340 | 372 | | |
341 | 373 | | |
342 | 374 | | |
343 | 375 | | |
| 376 | + | |
| 377 | + | |
344 | 378 | | |
345 | | - | |
346 | | - | |
347 | | - | |
348 | | - | |
349 | | - | |
350 | | - | |
351 | | - | |
352 | | - | |
| 379 | + | |
| 380 | + | |
| 381 | + | |
| 382 | + | |
| 383 | + | |
| 384 | + | |
| 385 | + | |
| 386 | + | |
| 387 | + | |
| 388 | + | |
| 389 | + | |
| 390 | + | |
| 391 | + | |
| 392 | + | |
| 393 | + | |
353 | 394 | | |
354 | 395 | | |
355 | 396 | | |
| |||
433 | 474 | | |
434 | 475 | | |
435 | 476 | | |
436 | | - | |
437 | | - | |
438 | | - | |
439 | | - | |
440 | | - | |
441 | | - | |
442 | | - | |
| 477 | + | |
| 478 | + | |
| 479 | + | |
| 480 | + | |
| 481 | + | |
| 482 | + | |
| 483 | + | |
| 484 | + | |
| 485 | + | |
| 486 | + | |
| 487 | + | |
| 488 | + | |
| 489 | + | |
443 | 490 | | |
444 | 491 | | |
445 | 492 | | |
| |||
Lines changed: 63 additions & 0 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
| 1 | + | |
| 2 | + | |
| 3 | + | |
| 4 | + | |
| 5 | + | |
| 6 | + | |
| 7 | + | |
| 8 | + | |
| 9 | + | |
| 10 | + | |
| 11 | + | |
| 12 | + | |
| 13 | + | |
| 14 | + | |
| 15 | + | |
| 16 | + | |
| 17 | + | |
| 18 | + | |
| 19 | + | |
| 20 | + | |
| 21 | + | |
| 22 | + | |
| 23 | + | |
| 24 | + | |
| 25 | + | |
| 26 | + | |
| 27 | + | |
| 28 | + | |
| 29 | + | |
| 30 | + | |
| 31 | + | |
| 32 | + | |
| 33 | + | |
| 34 | + | |
| 35 | + | |
| 36 | + | |
| 37 | + | |
| 38 | + | |
| 39 | + | |
| 40 | + | |
| 41 | + | |
| 42 | + | |
| 43 | + | |
| 44 | + | |
| 45 | + | |
| 46 | + | |
| 47 | + | |
| 48 | + | |
| 49 | + | |
| 50 | + | |
| 51 | + | |
| 52 | + | |
| 53 | + | |
| 54 | + | |
| 55 | + | |
| 56 | + | |
| 57 | + | |
| 58 | + | |
| 59 | + | |
| 60 | + | |
| 61 | + | |
| 62 | + | |
| 63 | + | |
0 commit comments