@@ -110,9 +110,9 @@ def __init__(
110110 self .dtype = config .dtype
111111 self .weight_dtype = config .weight_dtype
112112
113- self .n_heads = config .index_n_heads
114- self .head_dim = config .index_head_dim
115- self .index_topk = config .index_topk
113+ self .n_heads = config .indexer_n_heads
114+ self .head_dim = config .indexer_head_dim
115+ self .indexer_topk = config .indexer_topk
116116 self .emb_dim = config .emb_dim
117117 self .rope_head_dim = config .qk_rope_head_dim
118118 self .q_lora_rank = config .q_lora_rank
@@ -180,13 +180,13 @@ def apply_partial_rope(
180180 2. Input Layout: Indexer uses concatenated layout (interleave=False), whereas MLA uses interleaved (interleave=True).
181181
182182 Args:
183- inputs: Input array of shape [batch, seqlen, index_n_heads, index_head_dim ].
183+ inputs: Input array of shape [batch, seqlen, indexer_n_heads, indexer_head_dim ].
184184 inputs_positions: Position array of shape [batch, seqlen].
185185
186186 Returns:
187- Array with partial RoPE applied, with shape [batch, seqlen, index_n_heads, index_head_dim ]
187+ Array with partial RoPE applied, with shape [batch, seqlen, indexer_n_heads, indexer_head_dim ]
188188 """
189- # index_head_dim -> [rope_head_dim, index_head_dim - rope_head_dim]
189+ # indexer_head_dim -> [rope_head_dim, indexer_head_dim - rope_head_dim]
190190 x_pe , x_nope = jnp .split (inputs , [self .rope_head_dim ], axis = - 1 )
191191 # x_pe [B, S, H, rope_head_dim], positions [B, S]
192192 x_pe = self .rotary_embedding (x_pe , position = inputs_positions )
@@ -256,22 +256,50 @@ def __call__(
256256 b: Batch size
257257 t: Query Sequence Length (Target), note t = s here
258258 s: Key/Value Sequence Length (Source)
259- h: Number of Indexer Heads (index_n_heads )
260- d: Indexer Head Dimension (index_head_dim )
259+ h: Number of Indexer Heads (indexer_n_heads )
260+ d: Indexer Head Dimension (indexer_head_dim )
261261 """
262262 # NOTE: If sequence length <= topk, indexer always selects all tokens.
263- if self .config .max_target_length <= self .index_topk :
263+ if self .config .max_target_length <= self .indexer_topk :
264264 return None , None , None
265265
266266 bsz , seqlen , _ = inputs_q .shape # s = t = seqlen
267267
268268 # Query Processing: Project from Latent low_rank_q
269- q = self .wq_b (low_rank_q ) # [b, t, q_lora_rank] -> [b, t, h * d]
269+ if self .config .indexer_sparse_training :
270+ # ==============================================================================
271+ # Gradient Isolation Strategy: Main Model vs. Indexer
272+ # ==============================================================================
273+ # This creates a barrier to train both components independently:
274+ #
275+ # Forward Pass:
276+ # - The Indexer receives a detached copy of the inputs (via `stop_gradient`)
277+ # to independently calculate its scores and `indexer_loss`.
278+ #
279+ # Backward Pass (Main Model):
280+ # - The main model optimizes its weights based solely on the LM loss.
281+ # - The `indexer_mask` in the Attention layer prevents gradients from the main
282+ # loss from flowing into the Indexer's weights.
283+ #
284+ # Backward Pass (Indexer):
285+ # - Gradients from the `indexer_loss` flow back to update the Indexer's weights.
286+ # - The `stop_gradient` applied to the inputs acts as a mathematical wall, dropping
287+ # gradients to 0.0 and preventing the Indexer loss from altering the main model's
288+ # earlier layers.
289+ inputs_q_for_indexer = jax .lax .stop_gradient (inputs_q )
290+ low_rank_q_for_indexer = jax .lax .stop_gradient (low_rank_q )
291+ inputs_kv_for_indexer = jax .lax .stop_gradient (inputs_kv )
292+ else :
293+ inputs_q_for_indexer = inputs_q
294+ low_rank_q_for_indexer = low_rank_q
295+ inputs_kv_for_indexer = inputs_kv
296+
297+ q = self .wq_b (low_rank_q_for_indexer ) # [b, t, q_lora_rank] -> [b, t, h * d]
270298 q = q .reshape (bsz , seqlen , self .n_heads , self .head_dim ) # [b, t, h, d]
271299 q = self .apply_partial_rope (q , inputs_positions = inputs_positions )
272300
273301 # Key Processing: Project from Input
274- k = self .wk (inputs_kv ) # [b, s, embed_dim] -> [b, s, d]
302+ k = self .wk (inputs_kv_for_indexer ) # [b, s, embed_dim] -> [b, s, d]
275303 k = self .k_norm (k )
276304 k = k [:, :, None , :] # [b, s, d] -> [b, s, 1, d]
277305 k = self .apply_partial_rope (k , inputs_positions = inputs_positions )
@@ -283,7 +311,7 @@ def __call__(
283311 logits = jnp .einsum ("bthd, bsd -> btsh" , q , k , precision = self .config .matmul_precision )
284312 logits = jax .nn .relu (logits )
285313 # Compute head weights: project from input, [b, t, embed_dim] -> [b, t, h]
286- weights = self .weights_proj (inputs_q )
314+ weights = self .weights_proj (inputs_q_for_indexer )
287315 # Weight scaling affects indexer_score, but does not affect topk_indices. Keep scaling for numerical stability.
288316 # https://github.com/deepseek-ai/DeepSeek-V3.2-Exp/blob/87e509a2e5a100d221c97df52c6e8be7835f0057/inference/model.py#L478-L480
289317 weights = weights * (self .n_heads ** - 0.5 ) * self .softmax_scale
@@ -295,7 +323,7 @@ def __call__(
295323 indexer_score += attention_mask
296324
297325 # TopK selection based on index score
298- _ , topk_indices = jax .lax .top_k (indexer_score , k = self .index_topk ) # topk_indices [b, t, k]
326+ _ , topk_indices = jax .lax .top_k (indexer_score , k = self .indexer_topk ) # topk_indices [b, t, k]
299327
300328 # Create Sparse Index Mask: 0 and large negatives
301329 indexer_mask = self .generate_mask (topk_indices , seqlen ) # [b, t, s]
@@ -607,8 +635,8 @@ def __init__(
607635 )
608636
609637 # Initialize Indexer
610- self .use_sparse_indexer = config .use_sparse_indexer
611- if self .use_sparse_indexer :
638+ self .use_indexer = config .use_indexer
639+ if self .use_indexer :
612640 # Need two versions of rope.
613641 # MLA applies yarn with interleave layout.
614642 # Indexer applies yarn with concatenate layout.
@@ -989,6 +1017,13 @@ def calculate_indexer_loss(
9891017 Returns:
9901018 The computed KL divergence loss.
9911019 """
1020+ if sparse_loss :
1021+ # Detach main model components from the computational graph.
1022+ # The indexer should match the main model, but the main model should not be influenced
1023+ # by the indexer's learning progress via this loss during the sparse training stage.
1024+ query = jax .lax .stop_gradient (query )
1025+ key = jax .lax .stop_gradient (key )
1026+
9921027 # Compute attention scores: [b, t, h, d] @ [b, s, h, d] -> [b, h, t, s]
9931028 attention_scores = jnp .einsum ("bthd, bshd -> bhts" , query , key , precision = self .config .matmul_precision )
9941029
@@ -1080,7 +1115,7 @@ def __call__(
10801115
10811116 # Indexer Logic
10821117 indexer_mask = None
1083- if self .use_sparse_indexer :
1118+ if self .use_indexer :
10841119 if model_mode != MODEL_MODE_TRAIN :
10851120 raise NotImplementedError ("Sparse indexer has not implemented for inference yet." )
10861121 # generate mask: with 0 and large negative, [b, 1, 1, q_len, kv_len] -> [b, q_len, kv_len]
@@ -1098,14 +1133,14 @@ def __call__(
10981133 attention_mask = attention_mask ,
10991134 )
11001135
1101- if self .config .indexer_loss_scaling_factor > 0.0 :
1136+ if indexer_mask is not None and self .config .indexer_loss_scaling_factor > 0.0 :
11021137 indexer_loss = self .calculate_indexer_loss (
11031138 indexer_score = indexer_score ,
11041139 query = query ,
11051140 key = key ,
11061141 attention_mask = attention_mask ,
11071142 indexer_mask = indexer_mask ,
1108- sparse_loss = self .config .sparse_indexer_loss ,
1143+ sparse_loss = self .config .indexer_sparse_training ,
11091144 scaling_factor = self .config .indexer_loss_scaling_factor ,
11101145 )
11111146 self .sow (nnx .Intermediate , "indexer_loss" , indexer_loss )
0 commit comments