Skip to content

Commit 78e9fa0

Browse files
jwilber and root authored
feat(sae): Triton sparse decoder, tensor-parallel training, streaming (NVIDIA-BioNeMo#1613)
Update the SAE to support larger models: Triton kernels, producer/consumer streaming for training, and tensor parallelism (TP). Co-authored-by: root <root@nvidia-lepton128.cm.cluster>
1 parent a37afa9 commit 78e9fa0

27 files changed

Lines changed: 2697 additions & 21 deletions

bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/__init__.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848
save_cluster_labels,
4949
save_feature_atlas,
5050
)
51-
from .architectures import MoESAE, ReLUSAE, SparseAutoencoder, TopKSAE
51+
from .architectures import MoESAE, ReLUSAE, ShardedTopKSAE, SparseAutoencoder, TopKSAE
5252
from .autointerp import (
5353
DEFAULT_PROMPT_TEMPLATE,
5454
TOKEN_PROMPT_TEMPLATE,
@@ -80,8 +80,10 @@
8080
evaluate_sae,
8181
evaluate_sparsity,
8282
)
83+
from .kernels import HAS_TRITON, TritonDecoderAutograd
8384
from .perf_logger import PerfLogger
8485
from .process_group_manager import ProcessGroupManager
86+
from .streaming import StreamingActivationDataset, StreamingConfig, make_streaming_dataloader
8587
from .training import ParallelConfig, Trainer, TrainingConfig, WandbConfig
8688
from .utils import get_device, set_seed
8789

@@ -90,6 +92,7 @@
9092

9193
__all__ = [
9294
"DEFAULT_PROMPT_TEMPLATE",
95+
"HAS_TRITON",
9396
"TOKEN_PROMPT_TEMPLATE",
9497
"ActivationStore",
9598
"ActivationStoreConfig",
@@ -117,14 +120,18 @@
117120
"PerfLogger",
118121
"ProcessGroupManager",
119122
"ReLUSAE",
123+
"ShardedTopKSAE",
120124
"SparseAutoencoder",
121125
"SparsityMetrics",
126+
"StreamingActivationDataset",
127+
"StreamingConfig",
122128
"TokenActivationCollector",
123129
"TokenExample",
124130
"TopExample",
125131
"TopKSAE",
126132
"Trainer",
127133
"TrainingConfig",
134+
"TritonDecoderAutograd",
128135
"WandbConfig",
129136
"build_cluster_label_prompt",
130137
"compute_cluster_centroids",
@@ -140,6 +147,7 @@
140147
"get_device",
141148
"launch_dashboard",
142149
"load_activations",
150+
"make_streaming_dataloader",
143151
"save_activations",
144152
"save_cluster_labels",
145153
"save_feature_atlas",

bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/architectures/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,13 @@
1919
from .moe import MoESAE
2020
from .relu_l1 import ReLUSAE
2121
from .topk import TopKSAE
22+
from .topk_tp import ShardedTopKSAE
2223

2324

2425
__all__ = [
2526
"MoESAE",
2627
"ReLUSAE",
28+
"ShardedTopKSAE",
2729
"SparseAutoencoder",
2830
"TopKSAE",
2931
]

bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/architectures/topk.py

Lines changed: 103 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -74,15 +74,26 @@ def __init__(
7474
dead_tokens_threshold: int = 10_000_000,
7575
init_encoder_from_decoder: bool = True,
7676
init_pre_bias: bool = True,
77+
decoder_impl: str = "dense",
7778
):
78-
"""Initialize the Top-K SAE with encoder, decoder, and optional auxiliary loss."""
79+
"""Initialize the Top-K SAE with encoder, decoder, and optional auxiliary loss.
80+
81+
``decoder_impl`` selects the decode path: "dense" (default) builds the dense
82+
[batch, hidden_dim] code tensor and runs a full decoder matmul; "triton"
83+
decodes directly from the top-k (indices, values) via a sparse kernel
84+
(O(batch*k*d), no dense code tensor), enabling much larger hidden_dim. Weights
85+
are identical, so checkpoints are interchangeable between the two.
86+
"""
7987
super().__init__(input_dim, hidden_dim)
8088
self.top_k = top_k
8189
self.init_pre_bias = init_pre_bias
8290
self.normalize_input = normalize_input
8391
self.auxk = auxk
8492
self.auxk_coef = auxk_coef
8593
self.dead_tokens_threshold = dead_tokens_threshold
94+
if decoder_impl not in ("dense", "triton"):
95+
raise ValueError(f"decoder_impl must be 'dense' or 'triton', got {decoder_impl!r}")
96+
self.decoder_impl = decoder_impl
8697

8798
# Pre-bias (subtracted from normalized input, added to output before denorm)
8899
self.pre_bias = nn.Parameter(torch.zeros(input_dim))
@@ -208,9 +219,40 @@ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
208219
top_k_vals, top_k_indices = torch.topk(codes_relu, self.top_k, dim=-1)
209220
codes = torch.zeros_like(codes_relu).scatter(-1, top_k_indices, top_k_vals)
210221

211-
recon = self.decode(codes, info)
222+
if self.decoder_impl == "triton":
223+
recon = self._decode_topk_triton(top_k_vals, top_k_indices, info)
224+
else:
225+
recon = self.decode(codes, info)
212226
return recon, codes
213227

228+
def _decode_topk_triton(
    self,
    top_k_vals: torch.Tensor,
    top_k_indices: torch.Tensor,
    info: Optional[Dict[str, torch.Tensor]] = None,
    denormalize: bool = True,
) -> torch.Tensor:
    """Reconstruct the input from a top-k (values, indices) pair.

    The pair is handed to ``TritonDecoderAutograd`` together with the decoder
    weight, and ``self.pre_bias`` is added to whatever it returns.

    Args:
        top_k_vals: Selected activation values (made contiguous before the call).
        top_k_indices: Latent indices matching ``top_k_vals`` (made contiguous
            before the call).
        info: Normalization info forwarded to ``self._denormalize``; ``None``
            skips denormalization.
        denormalize: When False the result is returned without denormalizing,
            which is what the aux-loss path asks for.

    Returns:
        The reconstruction with ``pre_bias`` added. It is passed through
        ``self._denormalize`` only when ``denormalize`` is set,
        ``self.normalize_input`` is set, and ``info`` is not None.
    """
    # Imported at call time rather than at module load.
    from ..kernels import TritonDecoderAutograd

    indices = top_k_indices.contiguous()
    values = top_k_vals.contiguous()
    biased = TritonDecoderAutograd.apply(indices, values, self.decoder.weight) + self.pre_bias

    wants_input_scale = denormalize and self.normalize_input and info is not None
    if not wants_input_scale:
        return biased
    return self._denormalize(biased, info)
247+
248+
def _update_dead_latent_stats_from_indices(self, top_k_indices: torch.Tensor, n_tokens: int) -> None:
    """Refresh ``self.stats_last_nonzero`` from a tensor of selected latent indices.

    Every latent whose index appears in ``top_k_indices`` has its counter reset
    to zero; every other latent's counter grows by ``n_tokens``. Only a boolean
    mask shaped like ``stats_last_nonzero`` is allocated.

    Args:
        top_k_indices: Integer indices of selected latents; flattened before use.
        n_tokens: Amount added to the counter of each latent that was not selected.
    """
    counters = self.stats_last_nonzero

    # Flag each latent that shows up at least once among the selected indices.
    selected = torch.zeros_like(counters, dtype=torch.bool)
    selected[top_k_indices.reshape(-1)] = True

    aged = counters + n_tokens
    reset = torch.zeros_like(counters)
    self.stats_last_nonzero = torch.where(selected, reset, aged)
255+
214256
def forward_with_aux(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
215257
"""Forward pass with auxiliary info for auxk loss computation.
216258
@@ -257,8 +299,9 @@ def _compute_auxk_loss(
257299
x: torch.Tensor,
258300
recon: torch.Tensor,
259301
pre_act: torch.Tensor,
260-
codes: torch.Tensor,
302+
codes: Optional[torch.Tensor],
261303
norm_info: Optional[Dict[str, torch.Tensor]] = None,
304+
recon_norm: Optional[torch.Tensor] = None,
262305
) -> torch.Tensor:
263306
"""Compute auxiliary loss for dead latents.
264307
@@ -293,8 +336,10 @@ def _compute_auxk_loss(
293336
if self.normalize_input and norm_info is not None:
294337
# Normalize x to match the space where encoding happened
295338
x_norm = (x - norm_info["mu"]) / norm_info["std"]
296-
# Reuse codes from forward pass instead of re-encoding
297-
recon_norm = self.decoder(codes) + self.pre_bias
339+
# Reuse codes from forward pass instead of re-encoding (or a precomputed
340+
# normalized recon, e.g. from the sparse/triton decode path).
341+
if recon_norm is None:
342+
recon_norm = self.decoder(codes) + self.pre_bias
298343
residual = x_norm - recon_norm.detach()
299344
else:
300345
residual = x - recon.detach() + self.pre_bias.detach()
@@ -375,6 +420,9 @@ def loss(self, x: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
375420
- aux (if auxk enabled): auxiliary loss value
376421
- dead_pct (if auxk enabled): percentage of dead latents
377422
"""
423+
if self.decoder_impl == "triton":
424+
return self._loss_triton(x)
425+
378426
# Forward pass with auxiliary info
379427
info = self.forward_with_aux(x)
380428
recon = info["recon"]
@@ -422,3 +470,53 @@ def loss(self, x: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
422470
result["aux"] = aux_loss
423471

424472
return result
473+
474+
def _loss_triton(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
    """Compute the training loss via the sparse top-k decode path.

    Works directly from the (values, indices) pair returned by ``torch.topk``:
    the reconstruction comes from ``_decode_topk_triton``, dead-latent counters
    are updated from the selected indices, and L0 is counted from the selected
    values. No dense code tensor is built in this method. It is intended to
    mirror the dense ``loss()`` -- NOTE(review): the dense implementation is
    not visible from here; confirm the two paths stay in sync.

    Args:
        x: Input batch. ``x.shape[0]`` is used as the token count when aging
            the dead-latent counters.

    Returns:
        Dict with keys ``total``, ``fvu``, ``sparsity``, ``mse``,
        ``variance_explained`` and ``dead_pct``. When ``self.auxk`` is not None,
        ``aux`` is added and ``total`` becomes
        ``recon_loss + self.auxk_coef * aux``.
    """
    pre_act, info = self.encode_pre_act(x)
    top_k_vals, top_k_indices = torch.topk(torch.relu(pre_act), self.top_k, dim=-1)

    # Decode once without denormalizing (the aux loss needs this form), then
    # denormalize for the main loss. The guard matches the one used in
    # _decode_topk_triton: an explicit ``is not None`` check.
    recon_norm = self._decode_topk_triton(top_k_vals, top_k_indices, info, denormalize=False)
    if self.normalize_input and info is not None:
        recon = self._denormalize(recon_norm, info)
    else:
        recon = recon_norm

    # Dead-latent stats: only latents that actually fired count as active.
    # Top-k runs on ReLU output, so a row with fewer than ``top_k`` positive
    # activations is padded with zero-valued selections; passing those indices
    # through would reset counters of latents that did not fire and hide them
    # from the dead-latent threshold. The L0 computation below already excludes
    # zeros. NOTE: boolean-mask indexing yields a 1-D tensor and implies a
    # host sync on CUDA.
    fired_indices = top_k_indices[top_k_vals > 0]
    self._update_dead_latent_stats_from_indices(fired_indices, x.shape[0])

    # Primary reconstruction loss: per-row MSE divided by the per-row mean
    # squared deviation of x from pre_bias.
    mse = (recon - x).pow(2).mean(dim=-1)
    x_var = (x - self.pre_bias).pow(2).mean(dim=-1)
    recon_loss = (mse / (x_var + 1e-8)).mean()

    # For TopK, L0 is the number of nonzero selected values per row.
    l0 = (top_k_vals != 0).float().sum(dim=-1).mean()

    # Reporting-only metrics; kept out of the autograd graph.
    with torch.no_grad():
        raw_mse = (recon - x).pow(2).mean()
        total_var = torch.var(x, dim=0).sum()
        residual_var = torch.var(recon - x, dim=0).sum()
        var_explained = 1.0 - (residual_var / (total_var + 1e-8))

    result = {
        "total": recon_loss,
        "fvu": 1.0 - var_explained,
        "sparsity": l0,
        "mse": raw_mse,
        "variance_explained": var_explained,
    }
    # Percentage of latents whose counter exceeds the dead threshold.
    dead_pct = (self.stats_last_nonzero > self.dead_tokens_threshold).float().mean() * 100
    result["dead_pct"] = dead_pct

    if self.auxk is not None:
        # codes=None: the precomputed recon_norm is supplied instead of a dense
        # code tensor.
        aux_loss = self._compute_auxk_loss(x, recon, pre_act, codes=None, norm_info=info, recon_norm=recon_norm)
        result["total"] = recon_loss + self.auxk_coef * aux_loss
        result["aux"] = aux_loss

    return result

0 commit comments

Comments
 (0)