3838 AttentionType ,
3939 AxisIdxes ,
4040 AxisNames ,
41- BATCH ,
41+ BATCH_ATTN ,
4242 CACHE_BATCH ,
4343 CACHE_BATCH_PREFILL ,
4444 CACHE_HEADS ,
@@ -297,8 +297,8 @@ def attention_op_as_linen(
297297 float32_qk_product : bool = False ,
298298 max_prefill_predict_length : int = - 1 ,
299299 float32_logits : bool = False ,
300- flash_axis_names_q : AxisNames = (BATCH , HEAD , LENGTH , D_KV ),
301- flash_axis_names_kv : AxisNames = (BATCH , HEAD , KV_LENGTH , D_KV ),
300+ flash_axis_names_q : AxisNames = (BATCH_ATTN , HEAD , LENGTH , D_KV ),
301+ flash_axis_names_kv : AxisNames = (BATCH_ATTN , HEAD , KV_LENGTH , D_KV ),
302302 flash_axis_names_splash_kernel : AxisNames = (HEAD , LENGTH ),
303303 prefill_cache_logical_axis_names : AxisNames = (
304304 CACHE_BATCH_PREFILL ,
@@ -394,8 +394,8 @@ def __init__(
394394 float32_qk_product : bool = False ,
395395 max_prefill_predict_length : int = - 1 ,
396396 float32_logits : bool = False ,
397- flash_axis_names_q : AxisNames = (BATCH , HEAD , LENGTH , D_KV ),
398- flash_axis_names_kv : AxisNames = (BATCH , HEAD , KV_LENGTH , D_KV ),
397+ flash_axis_names_q : AxisNames = (BATCH_ATTN , HEAD , LENGTH , D_KV ),
398+ flash_axis_names_kv : AxisNames = (BATCH_ATTN , HEAD , KV_LENGTH , D_KV ),
399399 flash_axis_names_splash_kernel : AxisNames = (HEAD , LENGTH ),
400400 prefill_cache_logical_axis_names : AxisNames = (
401401 CACHE_BATCH_PREFILL ,
@@ -1144,13 +1144,13 @@ def tpu_flash_attention(
11441144 segment_axis_names_kv = None
11451145 sink_axis_names = self ._logical_to_mesh_axes ((HEAD ,))
11461146 if decoder_segment_ids is not None :
1147- segment_axis_names_q = self ._logical_to_mesh_axes ((BATCH , Q_LENGTH ))
1148- segment_axis_names_kv = self ._logical_to_mesh_axes ((BATCH , KV_LENGTH ))
1147+ segment_axis_names_q = self ._logical_to_mesh_axes ((BATCH_ATTN , Q_LENGTH ))
1148+ segment_axis_names_kv = self ._logical_to_mesh_axes ((BATCH_ATTN , KV_LENGTH ))
11491149
11501150 axis_names_splash_kernel = self ._logical_to_mesh_axes (self .flash_axis_names_splash_kernel )
11511151 axis_names_q = self ._logical_to_mesh_axes (self .flash_axis_names_q )
11521152 axis_names_kv = self ._logical_to_mesh_axes (self .flash_axis_names_kv )
1153- indexer_mask_axis_names = self ._logical_to_mesh_axes ((BATCH , Q_LENGTH , KV_LENGTH ))
1153+ indexer_mask_axis_names = self ._logical_to_mesh_axes ((BATCH_ATTN , Q_LENGTH , KV_LENGTH ))
11541154
11551155 global global_block_q , global_block_kv , global_block_kv_compute , global_block_q_dkv , global_block_kv_dkv
11561156 global global_block_kv_dkv_compute , global_block_q_dq , global_block_kv_dq , global_use_fused_bwd_kernel
@@ -1730,7 +1730,7 @@ def compute_local_attention(
17301730 if model_mode == MODEL_MODE_AUTOREGRESSIVE and self .is_partition_in_decode (q_seq_len ):
17311731 local_out = partitioning .with_sharding_constraint (local_out , (DECODE_BATCH , DECODE_LENGTH , HEAD , D_KV ))
17321732 elif model_mode == MODEL_MODE_PREFILL :
1733- local_out = partitioning .with_sharding_constraint (local_out , (BATCH , KV_LENGTH , HEAD , D_KV ))
1733+ local_out = partitioning .with_sharding_constraint (local_out , (BATCH_ATTN , KV_LENGTH , HEAD , D_KV ))
17341734
17351735 if self .reshape_q and q_seq_len == 1 :
17361736 local_max = local_max [:, 0 :1 , :, :]
@@ -1774,7 +1774,7 @@ def apply_attention_dot(
17741774
17751775 # special sharding for decode
17761776 q_seq_len = query .shape [1 ]
1777- prefill_qkv_sharding = (BATCH , PREFILL_LENGTH , HEAD , D_KV )
1777+ prefill_qkv_sharding = (BATCH_ATTN , PREFILL_LENGTH , HEAD , D_KV )
17781778 decode_qkv_sharding = (DECODE_BATCH , DECODE_LENGTH , HEAD , D_KV )
17791779 if self .is_partition_in_decode (q_seq_len ):
17801780 query = partitioning .with_sharding_constraint (query , decode_qkv_sharding )
@@ -1799,7 +1799,9 @@ def apply_attention_dot(
17991799 if self .is_partition_in_decode (q_seq_len ):
18001800 attn_weights = partitioning .with_sharding_constraint (attn_weights , (KV_LENGTH , HEAD , None , None , None ))
18011801 elif model_mode == MODEL_MODE_PREFILL :
1802- attn_weights = partitioning .with_sharding_constraint (attn_weights , (BATCH , HEAD , None , PREFILL_LENGTH , KV_LENGTH ))
1802+ attn_weights = partitioning .with_sharding_constraint (
1803+ attn_weights , (BATCH_ATTN , HEAD , None , PREFILL_LENGTH , KV_LENGTH )
1804+ )
18031805
18041806 if self .attn_logits_soft_cap :
18051807 attn_weights = jnp .tanh (attn_weights / self .attn_logits_soft_cap )
@@ -1846,7 +1848,7 @@ def apply_attention_dot(
18461848 if self .is_partition_in_decode (q_seq_len ):
18471849 attn_mask = partitioning .with_sharding_constraint (attn_mask , (KV_LENGTH , HEAD , None , None , None ))
18481850 elif model_mode == MODEL_MODE_PREFILL :
1849- attn_mask = partitioning .with_sharding_constraint (attn_mask , (BATCH , HEAD , None , PREFILL_LENGTH , KV_LENGTH ))
1851+ attn_mask = partitioning .with_sharding_constraint (attn_mask , (BATCH_ATTN , HEAD , None , PREFILL_LENGTH , KV_LENGTH ))
18501852 if attn_mask is not None :
18511853 attn_weights = apply_mask_to_logits (attn_weights , attn_mask )
18521854