Skip to content

Commit bca7eee

Browse files
louzongzhi and vasqu authored
Add IndexCache support for GLM5 DSA (huggingface#45424)
* Add IndexCache support for GLM5 DSA
* Refactor: Make IndexCache layer scheduling explicit in Config

  Moves index_topk_pattern generation from Attention.__init__ to Config.__post_init__ as suggested. Layers now simply check `config.index_topk_pattern[layer_idx]` instead of computing skip conditions, matching the mlp_layer_types pattern for consistent explicit configuration.
* fix
* oof, typo
* remove the exception as it's now hidden behind kwargs for BC

---------

Co-authored-by: vasqu <antonprogamer@gmail.com>
1 parent 5617561 commit bca7eee

3 files changed

Lines changed: 193 additions & 46 deletions

File tree

src/transformers/models/glm_moe_dsa/configuration_glm_moe_dsa.py

Lines changed: 17 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -39,6 +39,8 @@ class GlmMoeDsaConfig(PreTrainedConfig):
3939
Head dimension for the indexer projections (DSA).
4040
index_n_heads (`int | None`, *optional*, defaults to 32):
4141
Number of heads for the indexer projections (DSA).
42+
indexer_types (`list[str]`, *optional*):
43+
Indexer mode for each layer (`"full"` or `"shared"`). Defaults to first layer full, then every `index_topk_freq`-th layer full, rest shared.
4244
4345
```python
4446
>>> from transformers import GlmMoeDsaConfig, GlmMoeDsaModel
@@ -117,6 +119,7 @@ class GlmMoeDsaConfig(PreTrainedConfig):
117119
index_topk: int = 2048
118120
index_head_dim: int = 128
119121
index_n_heads: int = 32
122+
indexer_types: list[str] | None = None
120123

121124
def __post_init__(self, **kwargs):
122125
self.qk_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim
@@ -126,6 +129,20 @@ def __post_init__(self, **kwargs):
126129
self.mlp_layer_types = ["dense"] * min(3, self.num_hidden_layers) + ["sparse"] * (
127130
self.num_hidden_layers - 3
128131
)
132+
133+
# Indexer layer types
134+
if self.indexer_types is None:
135+
pattern = kwargs.pop("index_topk_pattern", None)
136+
freq = kwargs.pop("index_topk_freq", 1)
137+
if pattern is not None:
138+
self.indexer_types = (
139+
[{"F": "full", "S": "shared"}[c] for c in pattern] if isinstance(pattern, str) else list(pattern)
140+
)
141+
else:
142+
# First layer full, then every freq-th layer full, rest shared
143+
self.indexer_types = [
144+
"full" if (max(i - 1, 0) % freq) == 0 else "shared" for i in range(self.num_hidden_layers)
145+
]
129146
super().__post_init__(**kwargs)
130147

131148

src/transformers/models/glm_moe_dsa/modeling_glm_moe_dsa.py

Lines changed: 39 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -104,7 +104,7 @@ def apply_rotary_pos_emb(
104104

105105
class GlmMoeDsaIndexer(nn.Module):
106106
"""
107-
Dynamic Sparse Attention (DSA) indexer for selecting top-k tokens.
107+
DeepSeek Sparse Attention (DSA) indexer for selecting top-k tokens.
108108
109109
The Indexer has its own lightweight projections (wq_b, wk) separate from the
110110
main MLA attention. It uses non-interleaved (NeoX/Llama) RoPE, unlike the main attention
@@ -139,7 +139,7 @@ def __init__(self, config: "GlmMoeDsaConfig", layer_idx: int):
139139
self.softmax_scale = self.head_dim**-0.5
140140

141141
# Indexer maintains its own key cache (not in DynamicCache, which is sized for attention layers only)
142-
self._cached_keys: torch.Tensor | None = None
142+
self.register_buffer("_cached_keys", None, persistent=False)
143143

144144
@torch.no_grad()
145145
def forward(
@@ -268,7 +268,7 @@ def eager_attention_forward(
268268

269269
class GlmMoeDsaAttention(nn.Module):
270270
"""
271-
Multi-head Latent Attention (MLA) with Dynamic Sparse Attention (DSA) indexer.
271+
Multi-head Latent Attention (MLA) with DeepSeek Sparse Attention (DSA) indexer.
272272
273273
This follows the same architecture as DeepSeek V3.2's MLA:
274274
- Query: x → q_a_proj → RMSNorm → q_b_proj → split(q_nope, q_pe) → RoPE(q_pe)
@@ -335,14 +335,23 @@ def __init__(self, config: GlmMoeDsaConfig, layer_idx: int):
335335

336336
self.indexer = GlmMoeDsaIndexer(config, layer_idx)
337337

338+
# Refer: https://arxiv.org/abs/2603.12201 for more details.
339+
# skip_topk: when True, this layer will skip computation and reuse previous layer's topk indices.
340+
# next_skip_topk: when True, the next layer will skip computation and reuse this layer's topk indices.
341+
self.skip_topk = config.indexer_types[layer_idx] == "shared"
342+
self.next_skip_topk = (
343+
config.indexer_types[layer_idx + 1] == "shared" if layer_idx < len(config.indexer_types) - 1 else False
344+
)
345+
338346
def forward(
339347
self,
340348
hidden_states: torch.Tensor,
341349
position_embeddings: tuple[torch.Tensor, torch.Tensor],
342350
attention_mask: torch.Tensor | None,
343351
past_key_values: Cache | None = None,
352+
prev_topk_indices: torch.Tensor | None = None,
344353
**kwargs: Unpack[FlashAttentionKwargs],
345-
) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
354+
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
346355
batch_size, seq_length = hidden_states.shape[:-1]
347356
cos, sin = position_embeddings
348357

@@ -385,20 +394,23 @@ def forward(
385394

386395
# ===== Indexer (DSA sparse mask) =====
387396
# attention_mask is [B, 1, S, T] (4D) for eager and (2D) otherwise but indexer works with [B, S, T] (3D)
388-
indexer_mask = (
389-
attention_mask[:, 0, :, :]
390-
if attention_mask is not None and attention_mask.dim() == 4
391-
else attention_mask.unsqueeze(1)
392-
if attention_mask is not None
393-
else None
394-
)
395-
topk_indices = self.indexer(
396-
hidden_states,
397-
q_resid,
398-
position_embeddings,
399-
indexer_mask,
400-
use_cache=past_key_values is not None,
401-
) # [B, S, topk]
397+
if not self.skip_topk or prev_topk_indices is None:
398+
indexer_mask = (
399+
attention_mask[:, 0, :, :]
400+
if attention_mask is not None and attention_mask.dim() == 4
401+
else attention_mask.unsqueeze(1)
402+
if attention_mask is not None
403+
else None
404+
)
405+
topk_indices = self.indexer(
406+
hidden_states,
407+
q_resid,
408+
position_embeddings,
409+
indexer_mask,
410+
use_cache=past_key_values is not None,
411+
) # [B, S, topk]
412+
else:
413+
topk_indices = prev_topk_indices # [B, S, topk]
402414

403415
# Build combined DSA + causal mask: -inf everywhere except selected top-k positions
404416
total_len = key_states.shape[2]
@@ -445,7 +457,7 @@ def forward(
445457

446458
attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous()
447459
attn_output = self.o_proj(attn_output)
448-
return attn_output, attn_weights
460+
return attn_output, attn_weights, topk_indices if self.next_skip_topk else None
449461

450462

451463
class GlmMoeDsaMLP(nn.Module):
@@ -602,18 +614,20 @@ def forward(
602614
past_key_values: Cache | None = None,
603615
use_cache: bool | None = False,
604616
position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
617+
prev_topk_indices: torch.Tensor | None = None,
605618
**kwargs: Unpack[TransformersKwargs],
606-
) -> torch.Tensor:
619+
) -> tuple[torch.Tensor, torch.Tensor | None]:
607620
residual = hidden_states
608621
hidden_states = self.input_layernorm(hidden_states)
609622
# Self Attention
610-
hidden_states, _ = self.self_attn(
623+
hidden_states, _, topk_indices = self.self_attn(
611624
hidden_states=hidden_states,
612625
attention_mask=attention_mask,
613626
position_ids=position_ids,
614627
past_key_values=past_key_values,
615628
use_cache=use_cache,
616629
position_embeddings=position_embeddings,
630+
prev_topk_indices=prev_topk_indices,
617631
**kwargs,
618632
)
619633
hidden_states = residual + hidden_states
@@ -623,7 +637,7 @@ def forward(
623637
hidden_states = self.post_attention_layernorm(hidden_states)
624638
hidden_states = self.mlp(hidden_states)
625639
hidden_states = residual + hidden_states
626-
return hidden_states
640+
return hidden_states, topk_indices
627641

628642

629643
@auto_docstring
@@ -784,14 +798,16 @@ def forward(
784798
hidden_states = inputs_embeds
785799
position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
786800

801+
topk_indices = None
787802
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
788-
hidden_states = decoder_layer(
803+
hidden_states, topk_indices = decoder_layer(
789804
hidden_states,
790805
attention_mask=causal_mask,
791806
position_embeddings=position_embeddings,
792807
position_ids=position_ids,
793808
past_key_values=past_key_values,
794809
use_cache=use_cache,
810+
prev_topk_indices=topk_indices,
795811
**kwargs,
796812
)
797813

0 commit comments

Comments
 (0)