@@ -104,7 +104,7 @@ def apply_rotary_pos_emb(
104104
105105class GlmMoeDsaIndexer (nn .Module ):
106106 """
107- Dynamic Sparse Attention (DSA) indexer for selecting top-k tokens.
107+ DeepSeek Sparse Attention (DSA) indexer for selecting top-k tokens.
108108
109109 The Indexer has its own lightweight projections (wq_b, wk) separate from the
110110 main MLA attention. It uses non-interleaved (NeoX/Llama) RoPE, unlike the main attention
@@ -139,7 +139,7 @@ def __init__(self, config: "GlmMoeDsaConfig", layer_idx: int):
139139 self .softmax_scale = self .head_dim ** - 0.5
140140
141141 # Indexer maintains its own key cache (not in DynamicCache, which is sized for attention layers only)
142- self ._cached_keys : torch . Tensor | None = None
142+ self .register_buffer ( " _cached_keys" , None , persistent = False )
143143
144144 @torch .no_grad ()
145145 def forward (
@@ -268,7 +268,7 @@ def eager_attention_forward(
268268
269269class GlmMoeDsaAttention (nn .Module ):
270270 """
271- Multi-head Latent Attention (MLA) with Dynamic Sparse Attention (DSA) indexer.
271+ Multi-head Latent Attention (MLA) with DeepSeek Sparse Attention (DSA) indexer.
272272
273273 This follows the same architecture as DeepSeek V3.2's MLA:
274274 - Query: x → q_a_proj → RMSNorm → q_b_proj → split(q_nope, q_pe) → RoPE(q_pe)
@@ -335,14 +335,23 @@ def __init__(self, config: GlmMoeDsaConfig, layer_idx: int):
335335
336336 self .indexer = GlmMoeDsaIndexer (config , layer_idx )
337337
338+ # See https://arxiv.org/abs/2603.12201 for more details.
339+ # skip_topk: when True, this layer skips the indexer and reuses the previous layer's top-k indices (if provided).
340+ # next_skip_topk: when True, the next layer skips its indexer and reuses this layer's top-k indices.
341+ self .skip_topk = config .indexer_types [layer_idx ] == "shared"
342+ self .next_skip_topk = (
343+ config .indexer_types [layer_idx + 1 ] == "shared" if layer_idx < len (config .indexer_types ) - 1 else False
344+ )
345+
338346 def forward (
339347 self ,
340348 hidden_states : torch .Tensor ,
341349 position_embeddings : tuple [torch .Tensor , torch .Tensor ],
342350 attention_mask : torch .Tensor | None ,
343351 past_key_values : Cache | None = None ,
352+ prev_topk_indices : torch .Tensor | None = None ,
344353 ** kwargs : Unpack [FlashAttentionKwargs ],
345- ) -> tuple [torch .Tensor , torch .Tensor | None , tuple [ torch .Tensor ] | None ]:
354+ ) -> tuple [torch .Tensor , torch .Tensor , torch .Tensor | None ]:
346355 batch_size , seq_length = hidden_states .shape [:- 1 ]
347356 cos , sin = position_embeddings
348357
@@ -385,20 +394,23 @@ def forward(
385394
386395 # ===== Indexer (DSA sparse mask) =====
387396 # attention_mask is 4D [B, 1, S, T] for eager attention and 2D otherwise, but the indexer works with a 3D [B, S, T] mask
388- indexer_mask = (
389- attention_mask [:, 0 , :, :]
390- if attention_mask is not None and attention_mask .dim () == 4
391- else attention_mask .unsqueeze (1 )
392- if attention_mask is not None
393- else None
394- )
395- topk_indices = self .indexer (
396- hidden_states ,
397- q_resid ,
398- position_embeddings ,
399- indexer_mask ,
400- use_cache = past_key_values is not None ,
401- ) # [B, S, topk]
397+ if not self .skip_topk or prev_topk_indices is None :
398+ indexer_mask = (
399+ attention_mask [:, 0 , :, :]
400+ if attention_mask is not None and attention_mask .dim () == 4
401+ else attention_mask .unsqueeze (1 )
402+ if attention_mask is not None
403+ else None
404+ )
405+ topk_indices = self .indexer (
406+ hidden_states ,
407+ q_resid ,
408+ position_embeddings ,
409+ indexer_mask ,
410+ use_cache = past_key_values is not None ,
411+ ) # [B, S, topk]
412+ else :
413+ topk_indices = prev_topk_indices # [B, S, topk]
402414
403415 # Build combined DSA + causal mask: -inf everywhere except selected top-k positions
404416 total_len = key_states .shape [2 ]
@@ -445,7 +457,7 @@ def forward(
445457
446458 attn_output = attn_output .reshape (batch_size , seq_length , - 1 ).contiguous ()
447459 attn_output = self .o_proj (attn_output )
448- return attn_output , attn_weights
460+ return attn_output , attn_weights , topk_indices if self . next_skip_topk else None
449461
450462
451463class GlmMoeDsaMLP (nn .Module ):
@@ -602,18 +614,20 @@ def forward(
602614 past_key_values : Cache | None = None ,
603615 use_cache : bool | None = False ,
604616 position_embeddings : tuple [torch .Tensor , torch .Tensor ] | None = None ,
617+ prev_topk_indices : torch .Tensor | None = None ,
605618 ** kwargs : Unpack [TransformersKwargs ],
606- ) -> torch .Tensor :
619+ ) -> tuple [ torch .Tensor , torch . Tensor | None ] :
607620 residual = hidden_states
608621 hidden_states = self .input_layernorm (hidden_states )
609622 # Self Attention
610- hidden_states , _ = self .self_attn (
623+ hidden_states , _ , topk_indices = self .self_attn (
611624 hidden_states = hidden_states ,
612625 attention_mask = attention_mask ,
613626 position_ids = position_ids ,
614627 past_key_values = past_key_values ,
615628 use_cache = use_cache ,
616629 position_embeddings = position_embeddings ,
630+ prev_topk_indices = prev_topk_indices ,
617631 ** kwargs ,
618632 )
619633 hidden_states = residual + hidden_states
@@ -623,7 +637,7 @@ def forward(
623637 hidden_states = self .post_attention_layernorm (hidden_states )
624638 hidden_states = self .mlp (hidden_states )
625639 hidden_states = residual + hidden_states
626- return hidden_states
640+ return hidden_states , topk_indices
627641
628642
629643@auto_docstring
@@ -784,14 +798,16 @@ def forward(
784798 hidden_states = inputs_embeds
785799 position_embeddings = self .rotary_emb (hidden_states , position_ids = position_ids )
786800
801+ topk_indices = None
787802 for decoder_layer in self .layers [: self .config .num_hidden_layers ]:
788- hidden_states = decoder_layer (
803+ hidden_states , topk_indices = decoder_layer (
789804 hidden_states ,
790805 attention_mask = causal_mask ,
791806 position_embeddings = position_embeddings ,
792807 position_ids = position_ids ,
793808 past_key_values = past_key_values ,
794809 use_cache = use_cache ,
810+ prev_topk_indices = topk_indices ,
795811 ** kwargs ,
796812 )
797813
0 commit comments