Skip to content

Commit 5f6d2ac

Browse files
committed
Add Indexer training loss
1 parent f11f550 commit 5f6d2ac

8 files changed

Lines changed: 271 additions & 36 deletions

File tree

src/maxtext/common/metric_logger.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -338,6 +338,9 @@ def record_eval_metrics(self, step, metrics=None, eval_step_count=None):
338338
self.cumulative_eval_metrics["scalar"]["eval/moe_lb_loss"] += float(
339339
metrics["scalar"].get("evaluation/moe_lb_loss", 0.0)
340340
)
341+
self.cumulative_eval_metrics["scalar"]["eval/indexer_loss"] += float(
342+
metrics["scalar"].get("evaluation/indexer_loss", 0.0)
343+
)
341344
self.cumulative_eval_metrics["scalar"]["eval/mtp_loss"] += float(metrics["scalar"].get("evaluation/mtp_loss", 0.0))
342345
self.cumulative_eval_metrics["scalar"]["eval/mtp_acceptance_rate_percent"] += float(
343346
metrics["scalar"].get("evaluation/mtp_acceptance_rate_percent", 0.0)
@@ -355,6 +358,9 @@ def record_eval_metrics(self, step, metrics=None, eval_step_count=None):
355358
self.cumulative_eval_metrics["scalar"]["eval/avg_moe_lb_loss"] = (
356359
self.cumulative_eval_metrics["scalar"]["eval/moe_lb_loss"] / eval_step_count
357360
)
361+
self.cumulative_eval_metrics["scalar"]["eval/avg_indexer_loss"] = (
362+
self.cumulative_eval_metrics["scalar"]["eval/indexer_loss"] / eval_step_count
363+
)
358364
self.cumulative_eval_metrics["scalar"]["eval/avg_mtp_loss"] = (
359365
self.cumulative_eval_metrics["scalar"]["eval/mtp_loss"] / eval_step_count
360366
)

src/maxtext/configs/base.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -359,6 +359,8 @@ use_sparse_indexer: False
359359
index_head_dim: 128
360360
index_n_heads: 64
361361
index_topk: 2048
362+
sparse_indexer_loss: False # Whether to use the sparse loss for indexer training.
363+
indexer_loss_scaling_factor: 0.0 # Scaling factor for the indexer KL divergence loss
362364

363365
# MLA parameters
364366
q_lora_rank: 0

src/maxtext/configs/types.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -534,6 +534,8 @@ class AttentionIndexer(BaseModel):
534534
index_head_dim: NonNegativeInt = Field(128, description="Head dim for indexer query and key.")
535535
index_n_heads: NonNegativeInt = Field(64, description="Number of query heads in indexer.")
536536
index_topk: NonNegativeInt = Field(2048, description="Number of tokens selected by the query token in indexer.")
537+
sparse_indexer_loss: bool = Field(False, description="Use sparse loss for indexer training.")
538+
indexer_loss_scaling_factor: float = Field(0.0, description="Scaling factor for the indexer KL divergence loss.")
537539

538540

539541
class Llama4Attention(BaseModel):

src/maxtext/layers/attention_mla.py

Lines changed: 84 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@
7373
from maxtext.inference import paged_attention
7474
from maxtext.inference.kvcache import KVQuant
7575
from maxtext.utils.sharding import create_sharding
76+
from maxtext.utils.globals import EPS
7677

7778

7879
class Indexer(nnx.Module):
@@ -246,10 +247,10 @@ def __call__(
246247
the inputs and configuration.
247248
248249
Returns:
249-
index_mask: A sparse mask [b, t, s] with 0.0 for top-k selected tokens
250+
indexer_mask: A sparse mask [b, t, s] with 0.0 for top-k selected tokens
250251
and large negative values otherwise.
251252
topk_indices: Indices of the top-k selected tokens [b, t, k].
252-
index_score: The computed relevance scores [b, t, s].
253+
indexer_score: The computed relevance scores [b, t, s].
253254
254255
Notation:
255256
b: Batch size
@@ -283,27 +284,27 @@ def __call__(
283284
logits = jax.nn.relu(logits)
284285
# Compute head weights: project from input, [b, t, embed_dim] -> [b, t, h]
285286
weights = self.weights_proj(inputs_q)
286-
# Weights scaling affects index_score, but does not affect topk_indices. Keep scaling for numerical stability.
287+
# Weights scaling affects indexer_score, but does not affect topk_indices. Keep scaling for numerical stability.
287288
# https://github.com/deepseek-ai/DeepSeek-V3.2-Exp/blob/87e509a2e5a100d221c97df52c6e8be7835f0057/inference/model.py#L478-L480
288289
weights = weights * (self.n_heads**-0.5) * self.softmax_scale
289290
# Aggregate head-wise logits: logits @ weights
290-
index_score = jnp.einsum("btsh, bth -> bts", logits, weights, precision=self.config.matmul_precision) # [b, t, s]
291+
indexer_score = jnp.einsum("btsh, bth -> bts", logits, weights, precision=self.config.matmul_precision) # [b, t, s]
291292

292293
# Apply attention mask before TopK
293294
if attention_mask is not None:
294-
index_score += attention_mask
295+
indexer_score += attention_mask
295296

296297
# TopK selection based on index score
297-
_, topk_indices = jax.lax.top_k(index_score, k=self.index_topk) # topk_indices [b, t, k]
298+
_, topk_indices = jax.lax.top_k(indexer_score, k=self.index_topk) # topk_indices [b, t, k]
298299

299300
# Create Sparse Index Mask: 0 and large negatives
300-
index_mask = self.generate_mask(topk_indices, seqlen) # [b, t, s]
301+
indexer_mask = self.generate_mask(topk_indices, seqlen) # [b, t, s]
301302

302303
# Re-apply attention mask after TopK: in case number of unmasked tokens < TopK
303304
if attention_mask is not None:
304-
index_mask += attention_mask
305+
indexer_mask += attention_mask
305306

306-
return index_mask, topk_indices, index_score
307+
return indexer_mask, topk_indices, indexer_score
307308

308309

309310
def mla_as_linen(
@@ -951,6 +952,61 @@ def mla_kv_projection(self, inputs: Array, inputs_positions: Array, decoder_segm
951952

952953
return key, value, cached_values
953954

955+
def calculate_indexer_loss(
956+
self,
957+
indexer_score: Array,
958+
query: Array,
959+
key: Array,
960+
attention_mask: Optional[Array | None],
961+
indexer_mask: Array,
962+
sparse_loss: bool,
963+
scaling_factor: float,
964+
) -> Array:
965+
"""Calculates the indexer KL divergence loss.
966+
967+
This loss trains the indexer to predict which tokens are important by matching
968+
the distribution of true attention scores from the main model.
969+
970+
# Ref: DeepSeek-V3.2 - https://arxiv.org/pdf/2512.02556
971+
972+
Args:
973+
indexer_score: Scores predicted by indexer [batch, q_len, kv_len].
974+
query: Query tensor from main model [batch, q_len, heads, dim].
975+
key: Key tensor from main model [batch, kv_len, heads, dim].
976+
attention_mask: Attention mask [batch, q_len, kv_len] or None.
977+
indexer_mask: Indexer mask [batch, q_len, kv_len].
978+
sparse_loss: Whether to use sparse loss.
979+
scaling_factor: The scaling factor for the loss.
980+
981+
Returns:
982+
The computed KL divergence loss.
983+
"""
984+
# Compute attention scores: [b, t, h, d] @ [b, s, h, d] -> [b, h, t, s]
985+
attention_scores = jnp.einsum("bthd, bshd -> bhts", query, key, precision=self.config.matmul_precision)
986+
if attention_mask is not None:
987+
attention_scores = attention_scores + attention_mask[:, None, :, :]
988+
# indexer_score already has attention_mask added in Indexer.__call__
989+
990+
if sparse_loss:
991+
attention_scores = attention_scores + indexer_mask[:, None, :, :]
992+
indexer_score = indexer_score + indexer_mask
993+
994+
# Use float32 for softmax numerical stability.
995+
attention_probs = jax.nn.softmax(attention_scores.astype(jnp.float32), axis=-1)
996+
indexer_probs = jax.nn.softmax(indexer_score.astype(jnp.float32), axis=-1)
997+
998+
# Aggregate heads: [b, h, t, s] -> [b, t, s]
999+
attention_probs = jnp.sum(attention_probs, axis=1)
1000+
# L1 normalize aggregated target distribution
1001+
attention_probs = attention_probs / (jnp.sum(attention_probs, axis=-1, keepdims=True) + EPS)
1002+
1003+
# KL Divergence: KL(attention || indexer)
1004+
kl_per_token = jax.scipy.special.kl_div(attention_probs + EPS, indexer_probs + EPS)
1005+
# kl_per_token = attention_probs * (jnp.log(attention_probs + EPS) - jnp.log(indexer_probs + EPS))
1006+
indexer_loss = jnp.mean(jnp.sum(kl_per_token, axis=-1))
1007+
1008+
return indexer_loss * scaling_factor
1009+
9541010
def __call__(
9551011
self,
9561012
inputs_q: Array,
@@ -1013,23 +1069,37 @@ def __call__(
10131069
value = checkpoint_name(value, "value_proj")
10141070

10151071
# Indexer Logic
1016-
index_mask = None
1072+
indexer_mask = None
10171073
if self.use_sparse_indexer:
10181074
if model_mode != MODEL_MODE_TRAIN:
10191075
raise NotImplementedError("Sparse indexer has not implemented for inference yet.")
10201076
# generate mask: with 0 and large negative, [b, 1, 1, q_len, kv_len] -> [b, q_len, kv_len]
10211077
attention_mask = self.attention_op.generate_attention_mask(
10221078
query, key, decoder_segment_ids, model_mode, previous_chunk, bidirectional_mask
1023-
).squeeze(axis=(1, 2))
1024-
# apply indexer, index_mask [b, q_len, kv_len]
1025-
index_mask, _, _ = self.indexer(
1079+
)
1080+
if attention_mask is not None:
1081+
attention_mask = attention_mask.squeeze(axis=(1, 2))
1082+
# apply indexer, indexer_mask [b, q_len, kv_len]
1083+
indexer_mask, _, indexer_score = self.indexer(
10261084
inputs_q=inputs_q,
10271085
low_rank_q=low_rank_q,
10281086
inputs_kv=inputs_kv,
10291087
inputs_positions=inputs_positions,
10301088
attention_mask=attention_mask,
10311089
)
10321090

1091+
if self.config.indexer_loss_scaling_factor > 0.0:
1092+
indexer_loss = self.calculate_indexer_loss(
1093+
indexer_score=indexer_score,
1094+
query=query,
1095+
key=key,
1096+
attention_mask=attention_mask,
1097+
indexer_mask=indexer_mask,
1098+
sparse_loss=self.config.sparse_indexer_loss,
1099+
scaling_factor=self.config.indexer_loss_scaling_factor,
1100+
)
1101+
self.sow(nnx.Intermediate, "indexer_loss", indexer_loss)
1102+
10331103
# Check if we need QK Clip stats
10341104
use_qk_clip = self.model_mode == MODEL_MODE_TRAIN and self.config.use_qk_clip
10351105

@@ -1047,7 +1117,7 @@ def __call__(
10471117
decoder_segment_ids,
10481118
model_mode,
10491119
cached_values,
1050-
index_mask=index_mask,
1120+
indexer_mask=indexer_mask,
10511121
record_max_logits=use_qk_clip,
10521122
)
10531123

src/maxtext/layers/attention_op.py

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -884,7 +884,7 @@ def apply_attention(
884884
previous_chunk: Any = None,
885885
bidirectional_mask: Any = None,
886886
sinks: Array | None = None,
887-
index_mask: Array | None = None,
887+
indexer_mask: Array | None = None,
888888
record_max_logits: bool = False,
889889
*,
890890
qk_product_einsum: Callable[..., Array],
@@ -929,7 +929,7 @@ def apply_attention(
929929
previous_chunk,
930930
bidirectional_mask=bidirectional_mask,
931931
sinks=sinks,
932-
index_mask=index_mask,
932+
indexer_mask=indexer_mask,
933933
record_max_logits=record_max_logits,
934934
qk_product_einsum=qk_product_einsum,
935935
wv_product_einsum=wv_product_einsum,
@@ -1134,7 +1134,7 @@ def tpu_flash_attention(
11341134
decoder_segment_ids: Array | None,
11351135
attn_logits_soft_cap: float | None = None,
11361136
sinks: Array | None = None,
1137-
index_mask: Array | None = None,
1137+
indexer_mask: Array | None = None,
11381138
record_max_logits: bool = False,
11391139
) -> tuple[Array, Array]:
11401140
"""TPU Flash Attention."""
@@ -1161,12 +1161,12 @@ def tpu_flash_attention(
11611161
axis_names_splash_kernel = self._logical_to_mesh_axes(self.flash_axis_names_splash_kernel_ep)
11621162
axis_names_q = self._logical_to_mesh_axes(self.flash_axis_names_q_ep)
11631163
axis_names_kv = self._logical_to_mesh_axes(self.flash_axis_names_kv_ep)
1164-
index_mask_axis_names = self._logical_to_mesh_axes((BATCH_NO_EXP, Q_LENGTH, KV_LENGTH))
1164+
indexer_mask_axis_names = self._logical_to_mesh_axes((BATCH_NO_EXP, Q_LENGTH, KV_LENGTH))
11651165
else:
11661166
axis_names_splash_kernel = self._logical_to_mesh_axes(self.flash_axis_names_splash_kernel)
11671167
axis_names_q = self._logical_to_mesh_axes(self.flash_axis_names_q)
11681168
axis_names_kv = self._logical_to_mesh_axes(self.flash_axis_names_kv)
1169-
index_mask_axis_names = self._logical_to_mesh_axes((BATCH, Q_LENGTH, KV_LENGTH))
1169+
indexer_mask_axis_names = self._logical_to_mesh_axes((BATCH, Q_LENGTH, KV_LENGTH))
11701170

11711171
global global_block_q, global_block_kv, global_block_kv_compute, global_block_q_dkv, global_block_kv_dkv
11721172
global global_block_kv_dkv_compute, global_block_q_dq, global_block_kv_dq, global_use_fused_bwd_kernel
@@ -1376,7 +1376,7 @@ def wrap_splash_kernel(multi_head_mask, shard_head_size=1):
13761376
None, # no sharding for cp_size
13771377
None, # no sharding for load_balanced_context_parallel
13781378
sink_axis_names, # sharding align with query heads
1379-
index_mask_axis_names,
1379+
indexer_mask_axis_names,
13801380
),
13811381
out_specs=out_specs,
13821382
check_vma=False,
@@ -1392,7 +1392,7 @@ def wrap_flash_attention(
13921392
cp_size,
13931393
load_balanced_context_parallel,
13941394
sinks,
1395-
index_mask,
1395+
indexer_mask,
13961396
):
13971397
# If load_balanced_context_parallel is enabled, reorder the key and value tensors
13981398
# to ensure that they are contiguous in memory.
@@ -1421,11 +1421,11 @@ def wrap_flash_attention(
14211421
decoder_segment_ids_tuple = None
14221422

14231423
if self.config.use_tokamax_splash:
1424-
if self.config.use_sparse_indexer and index_mask is not None:
1424+
if self.config.use_sparse_indexer and indexer_mask is not None:
14251425
# Construct the splash kernel call with dynamic mask
1426-
def dynamic_mask_splash_kernel(q, k, v, segment, sinks, index_mask):
1426+
def dynamic_mask_splash_kernel(q, k, v, segment, sinks, indexer_mask):
14271427
splash_kernel = tokamax_splash_kernel.make_dynamic_splash_mha(
1428-
mask=index_mask,
1428+
mask=indexer_mask,
14291429
config=sa_config,
14301430
)
14311431
kernel = partial(splash_kernel, max_logit_value=max_logit_value)
@@ -1438,13 +1438,13 @@ def dynamic_mask_splash_kernel(q, k, v, segment, sinks, index_mask):
14381438

14391439
# Iterate over batch dimension for (query, key, value, segment, sinks, mask)
14401440
attn_fn = jax.vmap(dynamic_mask_splash_kernel, (0, 0, 0, 0, None, 0))
1441-
index_mask = jnp.isclose(index_mask, 0.0)
1441+
indexer_mask = jnp.isclose(indexer_mask, 0.0)
14421442

14431443
if record_max_logits:
1444-
attention_output, max_logits = attn_fn(query, key, value, decoder_segment_ids_tuple, sinks, index_mask)
1444+
attention_output, max_logits = attn_fn(query, key, value, decoder_segment_ids_tuple, sinks, indexer_mask)
14451445
return attention_output, max_logits
14461446
else:
1447-
attention_output, _ = attn_fn(query, key, value, decoder_segment_ids_tuple, sinks, index_mask)
1447+
attention_output, _ = attn_fn(query, key, value, decoder_segment_ids_tuple, sinks, indexer_mask)
14481448
return attention_output, None
14491449
else:
14501450
kernel = partial(splash_kernel, max_logit_value=max_logit_value)
@@ -1509,7 +1509,7 @@ def _maybe_shard_with_pspec(inputs, pspec: jax.sharding.PartitionSpec | None):
15091509
decoder_segment_ids_q = _maybe_shard_with_pspec(decoder_segment_ids, segment_axis_names_q)
15101510
decoder_segment_ids_kv = _maybe_shard_with_pspec(decoder_segment_ids, segment_axis_names_kv)
15111511
sinks = _maybe_shard_with_pspec(sinks, sink_axis_names)
1512-
index_mask = _maybe_shard_with_pspec(index_mask, index_mask_axis_names)
1512+
indexer_mask = _maybe_shard_with_pspec(indexer_mask, indexer_mask_axis_names)
15131513

15141514
ret = wrap_flash_attention(
15151515
query,
@@ -1522,7 +1522,7 @@ def _maybe_shard_with_pspec(inputs, pspec: jax.sharding.PartitionSpec | None):
15221522
cp_size,
15231523
load_balanced_context_parallel,
15241524
sinks,
1525-
index_mask,
1525+
indexer_mask,
15261526
)
15271527

15281528
x, max_logits = ret
@@ -1766,7 +1766,7 @@ def apply_attention_dot(
17661766
previous_chunk: Any = None,
17671767
bidirectional_mask: Any = None,
17681768
sinks: Array | None = None,
1769-
index_mask: Array | None = None,
1769+
indexer_mask: Array | None = None,
17701770
record_max_logits: bool = False,
17711771
*,
17721772
qk_product_einsum: Callable[..., Array],
@@ -1846,11 +1846,11 @@ def apply_attention_dot(
18461846

18471847
# Apply index mask, DeepSeek sparse attention
18481848
# index mask contains 0.0 for kept tokens and large negative for masked tokens.
1849-
if index_mask is not None:
1850-
# index_mask: from [b, q_len, kv_len] to [b, 1, 1, q_len, kv_len]
1851-
index_mask = index_mask[:, None, None, :, :]
1849+
if indexer_mask is not None:
1850+
# indexer_mask: from [b, q_len, kv_len] to [b, 1, 1, q_len, kv_len]
1851+
indexer_mask = indexer_mask[:, None, None, :, :]
18521852
# attn_weights: [b, n_kv, n_q // n_kv, q_len, kv_len]
1853-
attn_weights = apply_mask_to_logits(attn_weights, index_mask)
1853+
attn_weights = apply_mask_to_logits(attn_weights, indexer_mask)
18541854

18551855
if self.is_partition_in_decode(q_seq_len):
18561856
attn_mask = partitioning.with_sharding_constraint(attn_mask, (KV_LENGTH, HEAD, None, None, None))
@@ -2035,7 +2035,7 @@ def __call__(
20352035
previous_chunk=None,
20362036
bidirectional_mask=None,
20372037
sinks=None,
2038-
index_mask: Optional[Array] = None,
2038+
indexer_mask: Optional[Array] = None,
20392039
slot: Optional[int] = None,
20402040
page_state: Optional[page_manager.PageState] = None,
20412041
record_max_logits: bool = False,
@@ -2059,7 +2059,7 @@ def __call__(
20592059
previous_chunk=previous_chunk,
20602060
bidirectional_mask=bidirectional_mask,
20612061
sinks=sinks,
2062-
index_mask=index_mask,
2062+
indexer_mask=indexer_mask,
20632063
record_max_logits=record_max_logits,
20642064
qk_product_einsum=self.AqtEinsum_0,
20652065
wv_product_einsum=self.AqtEinsum_1,

0 commit comments

Comments
 (0)