@@ -884,7 +884,7 @@ def apply_attention(
884884 previous_chunk : Any = None ,
885885 bidirectional_mask : Any = None ,
886886 sinks : Array | None = None ,
887- index_mask : Array | None = None ,
887+ indexer_mask : Array | None = None ,
888888 record_max_logits : bool = False ,
889889 * ,
890890 qk_product_einsum : Callable [..., Array ],
@@ -929,7 +929,7 @@ def apply_attention(
929929 previous_chunk ,
930930 bidirectional_mask = bidirectional_mask ,
931931 sinks = sinks ,
932- index_mask = index_mask ,
932+ indexer_mask = indexer_mask ,
933933 record_max_logits = record_max_logits ,
934934 qk_product_einsum = qk_product_einsum ,
935935 wv_product_einsum = wv_product_einsum ,
@@ -1134,7 +1134,7 @@ def tpu_flash_attention(
11341134 decoder_segment_ids : Array | None ,
11351135 attn_logits_soft_cap : float | None = None ,
11361136 sinks : Array | None = None ,
1137- index_mask : Array | None = None ,
1137+ indexer_mask : Array | None = None ,
11381138 record_max_logits : bool = False ,
11391139 ) -> tuple [Array , Array ]:
11401140 """TPU Flash Attention."""
@@ -1161,12 +1161,12 @@ def tpu_flash_attention(
11611161 axis_names_splash_kernel = self ._logical_to_mesh_axes (self .flash_axis_names_splash_kernel_ep )
11621162 axis_names_q = self ._logical_to_mesh_axes (self .flash_axis_names_q_ep )
11631163 axis_names_kv = self ._logical_to_mesh_axes (self .flash_axis_names_kv_ep )
1164- index_mask_axis_names = self ._logical_to_mesh_axes ((BATCH_NO_EXP , Q_LENGTH , KV_LENGTH ))
1164+ indexer_mask_axis_names = self ._logical_to_mesh_axes ((BATCH_NO_EXP , Q_LENGTH , KV_LENGTH ))
11651165 else :
11661166 axis_names_splash_kernel = self ._logical_to_mesh_axes (self .flash_axis_names_splash_kernel )
11671167 axis_names_q = self ._logical_to_mesh_axes (self .flash_axis_names_q )
11681168 axis_names_kv = self ._logical_to_mesh_axes (self .flash_axis_names_kv )
1169- index_mask_axis_names = self ._logical_to_mesh_axes ((BATCH , Q_LENGTH , KV_LENGTH ))
1169+ indexer_mask_axis_names = self ._logical_to_mesh_axes ((BATCH , Q_LENGTH , KV_LENGTH ))
11701170
11711171 global global_block_q , global_block_kv , global_block_kv_compute , global_block_q_dkv , global_block_kv_dkv
11721172 global global_block_kv_dkv_compute , global_block_q_dq , global_block_kv_dq , global_use_fused_bwd_kernel
@@ -1376,7 +1376,7 @@ def wrap_splash_kernel(multi_head_mask, shard_head_size=1):
13761376 None , # no sharding for cp_size
13771377 None , # no sharding for load_balanced_context_parallel
13781378 sink_axis_names , # sharding align with query heads
1379- index_mask_axis_names ,
1379+ indexer_mask_axis_names ,
13801380 ),
13811381 out_specs = out_specs ,
13821382 check_vma = False ,
@@ -1392,7 +1392,7 @@ def wrap_flash_attention(
13921392 cp_size ,
13931393 load_balanced_context_parallel ,
13941394 sinks ,
1395- index_mask ,
1395+ indexer_mask ,
13961396 ):
13971397 # If load_balanced_context_parallel is enabled, reorder the key and value tensors
13981398 # to ensure that they are contiguous in memory.
@@ -1421,11 +1421,11 @@ def wrap_flash_attention(
14211421 decoder_segment_ids_tuple = None
14221422
14231423 if self .config .use_tokamax_splash :
1424- if self .config .use_sparse_indexer and index_mask is not None :
1424+ if self .config .use_sparse_indexer and indexer_mask is not None :
14251425 # Construct the splash kernel call with dynamic mask
1426- def dynamic_mask_splash_kernel (q , k , v , segment , sinks , index_mask ):
1426+ def dynamic_mask_splash_kernel (q , k , v , segment , sinks , indexer_mask ):
14271427 splash_kernel = tokamax_splash_kernel .make_dynamic_splash_mha (
1428- mask = index_mask ,
1428+ mask = indexer_mask ,
14291429 config = sa_config ,
14301430 )
14311431 kernel = partial (splash_kernel , max_logit_value = max_logit_value )
@@ -1438,13 +1438,13 @@ def dynamic_mask_splash_kernel(q, k, v, segment, sinks, index_mask):
14381438
14391439 # Iterate over batch dimension for (query, key, value, segment, sinks, mask)
14401440 attn_fn = jax .vmap (dynamic_mask_splash_kernel , (0 , 0 , 0 , 0 , None , 0 ))
1441- index_mask = jnp .isclose (index_mask , 0.0 )
1441+ indexer_mask = jnp .isclose (indexer_mask , 0.0 )
14421442
14431443 if record_max_logits :
1444- attention_output , max_logits = attn_fn (query , key , value , decoder_segment_ids_tuple , sinks , index_mask )
1444+ attention_output , max_logits = attn_fn (query , key , value , decoder_segment_ids_tuple , sinks , indexer_mask )
14451445 return attention_output , max_logits
14461446 else :
1447- attention_output , _ = attn_fn (query , key , value , decoder_segment_ids_tuple , sinks , index_mask )
1447+ attention_output , _ = attn_fn (query , key , value , decoder_segment_ids_tuple , sinks , indexer_mask )
14481448 return attention_output , None
14491449 else :
14501450 kernel = partial (splash_kernel , max_logit_value = max_logit_value )
@@ -1509,7 +1509,7 @@ def _maybe_shard_with_pspec(inputs, pspec: jax.sharding.PartitionSpec | None):
15091509 decoder_segment_ids_q = _maybe_shard_with_pspec (decoder_segment_ids , segment_axis_names_q )
15101510 decoder_segment_ids_kv = _maybe_shard_with_pspec (decoder_segment_ids , segment_axis_names_kv )
15111511 sinks = _maybe_shard_with_pspec (sinks , sink_axis_names )
1512- index_mask = _maybe_shard_with_pspec (index_mask , index_mask_axis_names )
1512+ indexer_mask = _maybe_shard_with_pspec (indexer_mask , indexer_mask_axis_names )
15131513
15141514 ret = wrap_flash_attention (
15151515 query ,
@@ -1522,7 +1522,7 @@ def _maybe_shard_with_pspec(inputs, pspec: jax.sharding.PartitionSpec | None):
15221522 cp_size ,
15231523 load_balanced_context_parallel ,
15241524 sinks ,
1525- index_mask ,
1525+ indexer_mask ,
15261526 )
15271527
15281528 x , max_logits = ret
@@ -1766,7 +1766,7 @@ def apply_attention_dot(
17661766 previous_chunk : Any = None ,
17671767 bidirectional_mask : Any = None ,
17681768 sinks : Array | None = None ,
1769- index_mask : Array | None = None ,
1769+ indexer_mask : Array | None = None ,
17701770 record_max_logits : bool = False ,
17711771 * ,
17721772 qk_product_einsum : Callable [..., Array ],
@@ -1846,11 +1846,11 @@ def apply_attention_dot(
18461846
18471847 # Apply the indexer mask (DeepSeek sparse attention).
18481848 # The indexer mask contains 0.0 for kept tokens and a large negative value for masked tokens.
1849- if index_mask is not None :
1850- # index_mask : from [b, q_len, kv_len] to [b, 1, 1, q_len, kv_len]
1851- index_mask = index_mask [:, None , None , :, :]
1849+ if indexer_mask is not None :
1850+ # indexer_mask : from [b, q_len, kv_len] to [b, 1, 1, q_len, kv_len]
1851+ indexer_mask = indexer_mask [:, None , None , :, :]
18521852 # attn_weights: [b, n_kv, n_q // n_kv, q_len, kv_len]
1853- attn_weights = apply_mask_to_logits (attn_weights , index_mask )
1853+ attn_weights = apply_mask_to_logits (attn_weights , indexer_mask )
18541854
18551855 if self .is_partition_in_decode (q_seq_len ):
18561856 attn_mask = partitioning .with_sharding_constraint (attn_mask , (KV_LENGTH , HEAD , None , None , None ))
@@ -2035,7 +2035,7 @@ def __call__(
20352035 previous_chunk = None ,
20362036 bidirectional_mask = None ,
20372037 sinks = None ,
2038- index_mask : Optional [Array ] = None ,
2038+ indexer_mask : Optional [Array ] = None ,
20392039 slot : Optional [int ] = None ,
20402040 page_state : Optional [page_manager .PageState ] = None ,
20412041 record_max_logits : bool = False ,
@@ -2059,7 +2059,7 @@ def __call__(
20592059 previous_chunk = previous_chunk ,
20602060 bidirectional_mask = bidirectional_mask ,
20612061 sinks = sinks ,
2062- index_mask = index_mask ,
2062+ indexer_mask = indexer_mask ,
20632063 record_max_logits = record_max_logits ,
20642064 qk_product_einsum = self .AqtEinsum_0 ,
20652065 wv_product_einsum = self .AqtEinsum_1 ,
0 commit comments