@@ -74,15 +74,26 @@ def __init__(
7474 dead_tokens_threshold : int = 10_000_000 ,
7575 init_encoder_from_decoder : bool = True ,
7676 init_pre_bias : bool = True ,
77+ decoder_impl : str = "dense" ,
7778 ):
78- """Initialize the Top-K SAE with encoder, decoder, and optional auxiliary loss."""
79+ """Initialize the Top-K SAE with encoder, decoder, and optional auxiliary loss.
80+
81+ ``decoder_impl`` selects the decode path: "dense" (default) builds the dense
82+ [batch, hidden_dim] code tensor and runs a full decoder matmul; "triton"
83+ decodes directly from the top-k (indices, values) via a sparse kernel
84+ (O(batch*k*d), no dense code tensor), enabling much larger hidden_dim. Weights
85+ are identical, so checkpoints are interchangeable between the two.
86+ """
7987 super ().__init__ (input_dim , hidden_dim )
8088 self .top_k = top_k
8189 self .init_pre_bias = init_pre_bias
8290 self .normalize_input = normalize_input
8391 self .auxk = auxk
8492 self .auxk_coef = auxk_coef
8593 self .dead_tokens_threshold = dead_tokens_threshold
94+ if decoder_impl not in ("dense" , "triton" ):
95+ raise ValueError (f"decoder_impl must be 'dense' or 'triton', got { decoder_impl !r} " )
96+ self .decoder_impl = decoder_impl
8697
8798 # Pre-bias (subtracted from normalized input, added to output before denorm)
8899 self .pre_bias = nn .Parameter (torch .zeros (input_dim ))
@@ -208,9 +219,40 @@ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
208219 top_k_vals , top_k_indices = torch .topk (codes_relu , self .top_k , dim = - 1 )
209220 codes = torch .zeros_like (codes_relu ).scatter (- 1 , top_k_indices , top_k_vals )
210221
211- recon = self .decode (codes , info )
222+ if self .decoder_impl == "triton" :
223+ recon = self ._decode_topk_triton (top_k_vals , top_k_indices , info )
224+ else :
225+ recon = self .decode (codes , info )
212226 return recon , codes
213227
228+ def _decode_topk_triton (
229+ self ,
230+ top_k_vals : torch .Tensor ,
231+ top_k_indices : torch .Tensor ,
232+ info : Optional [Dict [str , torch .Tensor ]] = None ,
233+ denormalize : bool = True ,
234+ ) -> torch .Tensor :
235+ """Decode from top-k (values, indices) via the sparse Triton kernel.
236+
237+ Returns reconstruction with pre_bias added; denormalized to input scale when
238+ ``denormalize`` (set False to get the normalized-space recon for aux loss).
239+ """
240+ from ..kernels import TritonDecoderAutograd
241+
242+ recon = TritonDecoderAutograd .apply (top_k_indices .contiguous (), top_k_vals .contiguous (), self .decoder .weight )
243+ recon = recon + self .pre_bias
244+ if denormalize and self .normalize_input and info is not None :
245+ recon = self ._denormalize (recon , info )
246+ return recon
247+
248+ def _update_dead_latent_stats_from_indices (self , top_k_indices : torch .Tensor , n_tokens : int ) -> None :
249+ """Update stats_last_nonzero from top-k indices (no dense [batch, hidden] tensor)."""
250+ active_mask = torch .zeros_like (self .stats_last_nonzero , dtype = torch .bool )
251+ active_mask [top_k_indices .reshape (- 1 )] = True
252+ self .stats_last_nonzero = torch .where (
253+ active_mask , torch .zeros_like (self .stats_last_nonzero ), self .stats_last_nonzero + n_tokens
254+ )
255+
214256 def forward_with_aux (self , x : torch .Tensor ) -> Dict [str , torch .Tensor ]:
215257 """Forward pass with auxiliary info for auxk loss computation.
216258
@@ -257,8 +299,9 @@ def _compute_auxk_loss(
257299 x : torch .Tensor ,
258300 recon : torch .Tensor ,
259301 pre_act : torch .Tensor ,
260- codes : torch .Tensor ,
302+ codes : Optional [ torch .Tensor ] ,
261303 norm_info : Optional [Dict [str , torch .Tensor ]] = None ,
304+ recon_norm : Optional [torch .Tensor ] = None ,
262305 ) -> torch .Tensor :
263306 """Compute auxiliary loss for dead latents.
264307
@@ -293,8 +336,10 @@ def _compute_auxk_loss(
293336 if self .normalize_input and norm_info is not None :
294337 # Normalize x to match the space where encoding happened
295338 x_norm = (x - norm_info ["mu" ]) / norm_info ["std" ]
296- # Reuse codes from forward pass instead of re-encoding
297- recon_norm = self .decoder (codes ) + self .pre_bias
339+ # Reuse codes from forward pass instead of re-encoding (or a precomputed
340+ # normalized recon, e.g. from the sparse/triton decode path).
341+ if recon_norm is None :
342+ recon_norm = self .decoder (codes ) + self .pre_bias
298343 residual = x_norm - recon_norm .detach ()
299344 else :
300345 residual = x - recon .detach () + self .pre_bias .detach ()
@@ -375,6 +420,9 @@ def loss(self, x: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
375420 - aux (if auxk enabled): auxiliary loss value
376421 - dead_pct (if auxk enabled): percentage of dead latents
377422 """
423+ if self .decoder_impl == "triton" :
424+ return self ._loss_triton (x )
425+
378426 # Forward pass with auxiliary info
379427 info = self .forward_with_aux (x )
380428 recon = info ["recon" ]
@@ -422,3 +470,53 @@ def loss(self, x: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
422470 result ["aux" ] = aux_loss
423471
424472 return result
473+
474+ def _loss_triton (self , x : torch .Tensor ) -> Dict [str , torch .Tensor ]:
475+ """loss() using the sparse Triton decoder.
476+
477+ Numerically equivalent to the dense loss() but never materializes the dense
478+ [batch, hidden_dim] code tensor or runs the full decoder matmul: it decodes
479+ from the top-k (values, indices) and derives dead-latent stats / L0 from the
480+ indices. This is what lets hidden_dim scale to ~1M+.
481+ """
482+ pre_act , info = self .encode_pre_act (x )
483+ codes_relu = torch .relu (pre_act )
484+ top_k_vals , top_k_indices = torch .topk (codes_relu , self .top_k , dim = - 1 )
485+
486+ # Sparse decode in normalized space (pre_bias added); denormalize for the main loss.
487+ recon_norm = self ._decode_topk_triton (top_k_vals , top_k_indices , info , denormalize = False )
488+ recon = self ._denormalize (recon_norm , info ) if (self .normalize_input and info ) else recon_norm
489+
490+ # Dead-latent stats from indices (no dense codes tensor).
491+ self ._update_dead_latent_stats_from_indices (top_k_indices , x .shape [0 ])
492+
493+ # Primary reconstruction loss (FVU), centered by pre_bias -- matches dense loss().
494+ mse = (recon - x ).pow (2 ).mean (dim = - 1 )
495+ x_var = (x - self .pre_bias ).pow (2 ).mean (dim = - 1 )
496+ recon_loss = (mse / (x_var + 1e-8 )).mean ()
497+
498+ # For TopK, L0 == count of nonzero top-k values.
499+ l0 = (top_k_vals != 0 ).float ().sum (dim = - 1 ).mean ()
500+
501+ with torch .no_grad ():
502+ raw_mse = (recon - x ).pow (2 ).mean ()
503+ total_var = torch .var (x , dim = 0 ).sum ()
504+ residual_var = torch .var (recon - x , dim = 0 ).sum ()
505+ var_explained = 1.0 - (residual_var / (total_var + 1e-8 ))
506+
507+ result = {
508+ "total" : recon_loss ,
509+ "fvu" : 1.0 - var_explained ,
510+ "sparsity" : l0 ,
511+ "mse" : raw_mse ,
512+ "variance_explained" : var_explained ,
513+ }
514+ dead_pct = (self .stats_last_nonzero > self .dead_tokens_threshold ).float ().mean () * 100
515+ result ["dead_pct" ] = dead_pct
516+
517+ if self .auxk is not None :
518+ aux_loss = self ._compute_auxk_loss (x , recon , pre_act , codes = None , norm_info = info , recon_norm = recon_norm )
519+ result ["total" ] = recon_loss + self .auxk_coef * aux_loss
520+ result ["aux" ] = aux_loss
521+
522+ return result
0 commit comments