@@ -272,6 +272,7 @@ def convert_to_tokamax_splash_config(
272272 attn_logits_soft_cap : float | None = None ,
273273 fuse_reciprocal : bool = True ,
274274 use_base2_exp : bool = False ,
275+ use_experimental_scheduler : bool = False ,
275276 max_logit_const : float | None = None ,
276277 interpret : bool = False ,
277278 dq_reduction_steps : int | None = None ,
@@ -294,6 +295,7 @@ def convert_to_tokamax_splash_config(
294295 attn_logits_soft_cap = attn_logits_soft_cap ,
295296 fuse_reciprocal = fuse_reciprocal ,
296297 use_base2_exp = use_base2_exp ,
298+ use_experimental_scheduler = use_experimental_scheduler ,
297299 max_logit_const = max_logit_const ,
298300 interpret = interpret ,
299301 dq_reduction_steps = dq_reduction_steps ,
@@ -314,6 +316,8 @@ def _tpu_flash_attention(
314316 mask_padding_tokens : bool = True ,
315317 residual_checkpoint_name : str | None = None ,
316318 attention_mask : jax .Array = None ,
319+ use_base2_exp : bool = False ,
320+ use_experimental_scheduler : bool = False ,
317321) -> jax .Array :
318322 """TPU Flash Attention"""
319323
@@ -399,7 +403,12 @@ def wrap_flash_attention(query, key, value):
399403 splash_kernel = tokamax_splash_attention_kernel .make_splash_mha (
400404 mask = mask ,
401405 q_seq_shards = 1 , # the sizes of the axis is sharding over seq_len
402- config = convert_to_tokamax_splash_config (block_sizes , residual_checkpoint_name = residual_checkpoint_name ),
406+ config = convert_to_tokamax_splash_config (
407+ block_sizes ,
408+ residual_checkpoint_name = residual_checkpoint_name ,
409+ use_base2_exp = use_base2_exp ,
410+ use_experimental_scheduler = use_experimental_scheduler ,
411+ ),
403412 save_residuals = False ,
404413 )
405414 elif attention_kernel == "tokamax_ring" :
@@ -409,7 +418,12 @@ def wrap_flash_attention(query, key, value):
409418 splash_kernel = tokamax_ring_attention_kernel .make_ring_attention (
410419 mask = mask ,
411420 is_mqa = False ,
412- config = convert_to_tokamax_splash_config (block_sizes , residual_checkpoint_name = residual_checkpoint_name ),
421+ config = convert_to_tokamax_splash_config (
422+ block_sizes ,
423+ residual_checkpoint_name = residual_checkpoint_name ,
424+ use_base2_exp = use_base2_exp ,
425+ use_experimental_scheduler = use_experimental_scheduler ,
426+ ),
413427 save_residuals = False ,
414428 ring_axis = "context" ,
415429 rotate_segment_ids = False , # We don't rotate segment ids in tokamax ring attention because our segment ids is for padding each kv shard has same segment ids
@@ -473,13 +487,13 @@ def ring_scan_body(carry, _):
473487 raise ValueError ("ring attention requires context > 1" )
474488 return attention_output [:, :, :query_seq_len , :kv_size ].astype (query .dtype )
475489
476- devices_in_data_context = mesh .shape ["data" ] * mesh .shape ["context" ]
490+ devices_in_batch_sharding = mesh .shape ["data" ] * ( mesh .shape ["fsdp" ] if "fsdp" in mesh . shape else 1 )
477491 # This warning might show up when doing model eval for example, when calculating model flops
478492 # and that is expected.
479- if not (query .shape [0 ] / devices_in_data_context ).is_integer ():
493+ if not (query .shape [0 ] / devices_in_batch_sharding ).is_integer ():
480494 max_logging .log (
481- "Warning, batch dimension should be shardable among the devices in data and context "
482- f" axis, batch dimension: { query .shape [0 ]} , devices_in_data_context : { devices_in_data_context } "
495+ "Warning, batch dimension should be shardable among the devices in data and fsdp "
496+ f" axis, batch dimension: { query .shape [0 ]} , devices_in_batch_sharding : { devices_in_batch_sharding } "
483497 )
484498 x = wrap_flash_attention (query , key , value )
485499 # Trim back to original sequence length after context-axis padding.
@@ -614,11 +628,11 @@ def wrap_ulysses_attention(query, key, value):
614628 attention_output = jax .lax .all_to_all (attention_output , axis_name = axis_name , split_axis = 2 , concat_axis = 1 , tiled = True )
615629 return attention_output
616630
617- devices_in_data_context = mesh .shape ["data" ] * num_shards
618- if not (query .shape [0 ] / devices_in_data_context ).is_integer ():
631+ devices_in_batch_sharding = mesh .shape ["data" ] * ( mesh . shape [ "fsdp" ] if "fsdp" in mesh . shape else 1 )
632+ if not (query .shape [0 ] / devices_in_batch_sharding ).is_integer ():
619633 max_logging .log (
620- "Warning, batch dimension should be shardable among the devices in data and context "
621- f" axis, batch dimension: { query .shape [0 ]} , devices_in_data_context : { devices_in_data_context } "
634+ "Warning, batch dimension should be shardable among the devices in data and fsdp "
635+ f" axis, batch dimension: { query .shape [0 ]} , devices_in_batch_sharding : { devices_in_batch_sharding } "
622636 )
623637 x = wrap_ulysses_attention (query , key , value )
624638 x = x [:, :, :orig_q_seq_len , :]
@@ -741,6 +755,8 @@ def _apply_attention(
741755 mask_padding_tokens : bool = True ,
742756 residual_checkpoint_name : str | None = None ,
743757 attention_mask : Array = None ,
758+ use_base2_exp : bool = False ,
759+ use_experimental_scheduler : bool = False ,
744760):
745761 """Routes to different attention kernels."""
746762 _check_attention_inputs (query , key , value )
@@ -789,6 +805,8 @@ def _apply_attention(
789805 mask_padding_tokens = mask_padding_tokens ,
790806 residual_checkpoint_name = residual_checkpoint_name ,
791807 attention_mask = attention_mask ,
808+ use_base2_exp = use_base2_exp ,
809+ use_experimental_scheduler = use_experimental_scheduler ,
792810 )
793811 elif "ring" in attention_kernel :
794812 return _tpu_flash_attention (
@@ -983,8 +1001,12 @@ def __init__(
9831001 quant : Quant = None ,
9841002 mask_padding_tokens : bool = True ,
9851003 residual_checkpoint_name : str | None = None ,
1004+ use_base2_exp : bool = False ,
1005+ use_experimental_scheduler : bool = False ,
9861006 ):
9871007 self .dpa_layer = None
1008+ self .use_base2_exp = use_base2_exp
1009+ self .use_experimental_scheduler = use_experimental_scheduler
9881010 if attention_kernel == "cudnn_flash_te" :
9891011 from transformer_engine .jax .flax .transformer import DotProductAttention # pytype: disable=import-error
9901012
@@ -1045,6 +1067,8 @@ def apply_attention(self, query: Array, key: Array, value: Array, attention_mask
10451067 mask_padding_tokens = self .mask_padding_tokens ,
10461068 residual_checkpoint_name = self .residual_checkpoint_name ,
10471069 attention_mask = attention_mask ,
1070+ use_base2_exp = self .use_base2_exp if hasattr (self , "use_base2_exp" ) else False ,
1071+ use_experimental_scheduler = self .use_experimental_scheduler if hasattr (self , "use_experimental_scheduler" ) else False ,
10481072 )
10491073
10501074
@@ -1063,6 +1087,8 @@ class AttentionOp(nn.Module):
10631087 flash_block_sizes : BlockSizes = None
10641088 dtype : DType = jnp .float32
10651089 quant : Quant = None
1090+ use_base2_exp : bool = False
1091+ use_experimental_scheduler : bool = False
10661092
10671093 def setup (self ):
10681094 self .dpa_layer = None
@@ -1108,6 +1134,8 @@ def apply_attention(self, query: Array, key: Array, value: Array, attention_mask
11081134 flash_block_sizes = self .flash_block_sizes ,
11091135 dpa_layer = self .dpa_layer ,
11101136 attention_mask = attention_mask ,
1137+ use_base2_exp = self .use_base2_exp ,
1138+ use_experimental_scheduler = self .use_experimental_scheduler ,
11111139 )
11121140
11131141
@@ -1144,6 +1172,8 @@ def __init__(
11441172 enable_jax_named_scopes : bool = False ,
11451173 added_kv_proj_dim : Optional [int ] = None , # New for I2V
11461174 image_seq_len : Optional [int ] = None , # New for I2V
1175+ use_base2_exp : bool = False ,
1176+ use_experimental_scheduler : bool = False ,
11471177 ):
11481178 if attention_kernel in {"flash" , "cudnn_flash_te" } and mesh is None :
11491179 raise ValueError (f"The flash attention kernel requires a value for mesh, but mesh is { self .mesh } " )
@@ -1186,6 +1216,8 @@ def __init__(
11861216 quant = quant ,
11871217 mask_padding_tokens = mask_padding_tokens ,
11881218 residual_checkpoint_name = residual_checkpoint_name ,
1219+ use_base2_exp = use_base2_exp ,
1220+ use_experimental_scheduler = use_experimental_scheduler ,
11891221 )
11901222 # None axes corresponds to the stacked weights across all blocks
11911223 # because of the use of nnx.vmap and nnx.scan.
0 commit comments