@@ -190,6 +190,49 @@ def _pad_data_for_flash(tensor, heads, flash_block_size, num_shards: int = 1):
190190 return tensor , kv_size , seq_len
191191
192192
193+ def _flash_sequence_length (tensor : Array ) -> int :
194+ if tensor .ndim == 3 :
195+ return tensor .shape [1 ]
196+ if tensor .ndim == 4 :
197+ return tensor .shape [2 ]
198+ raise ValueError (f"Flash attention expects rank-3 or rank-4 inputs, got rank { tensor .ndim } ." )
199+
200+
def _select_flash_block_sizes(
    query: Array,
    key: Array,
    flash_block_sizes: BlockSizes,
    dtype: jnp.dtype,
    attention_kernel: str,
) -> BlockSizes:
    """Choose TPU flash-attention block sizes for the given query/key shapes.

    For self-attention (query and key sequence lengths equal), explicitly
    configured ``flash_block_sizes`` are honored as-is. For cross-attention
    the configured sizes may not divide the KV sequence, so KV-aware sizes
    are derived instead: the KV cap is the key length rounded up to a
    multiple of 128, and every KV block is clamped to the actual length.

    Args:
      query: Query tensor, rank 3 or 4 (sequence axis per _flash_sequence_length).
      key: Key tensor, rank 3 or 4.
      flash_block_sizes: Optional user-configured block sizes; may be falsy.
      dtype: Computation dtype; bfloat16 permits 1024-wide query blocks.
      attention_kernel: Kernel name; "tokamax_flash" selects the fused
        backward kernel, which requires the *_dq block sizes to be None.

    Returns:
      A splash_attention_kernel.BlockSizes instance.
    """
    query_seq_len = _flash_sequence_length(query)
    key_seq_len = _flash_sequence_length(key)

    # bfloat16 activations leave enough VMEM headroom for wider query blocks.
    q_max_block_size = 1024 if dtype == jnp.bfloat16 else 512
    if key_seq_len != query_seq_len:
        # Cross-attention: round the KV length up to the 128-lane tile size.
        kv_max_block_size = ((key_seq_len + 127) // 128) * 128
    else:
        kv_max_block_size = q_max_block_size

    # Keep configured block sizes for self-attention, but let
    # cross-attention derive safe KV-aware sizes when q_len != kv_len.
    if flash_block_sizes and key_seq_len == query_seq_len:
        return flash_block_sizes

    block_size_q = flash_block_sizes.block_q if flash_block_sizes else q_max_block_size
    kv_block = min(kv_max_block_size, key_seq_len)
    # NOTE(review): these two clamp against the *query* length, mirroring the
    # pre-refactor behavior — confirm this asymmetry is intentional.
    q_clamped_kv_block = min(kv_max_block_size, query_seq_len)
    use_fused_bwd = attention_kernel == "tokamax_flash"
    return splash_attention_kernel.BlockSizes(
        block_q=block_size_q,
        block_kv_compute=kv_block,
        block_kv=kv_block,
        block_q_dkv=block_size_q,
        block_kv_dkv=kv_block,
        block_kv_dkv_compute=q_clamped_kv_block,
        # The fused backward kernel derives its own dq blocking.
        block_q_dq=None if use_fused_bwd else block_size_q,
        block_kv_dq=None if use_fused_bwd else q_clamped_kv_block,
        use_fused_bwd_kernel=use_fused_bwd,
    )
234+
235+
193236def convert_to_tokamax_splash_config (
194237 block_sizes : BlockSizes ,
195238 q_layout : tokamax_splash_attention_kernel .QKVLayout = tokamax_splash_attention_kernel .QKVLayout .HEAD_DIM_MINOR ,
@@ -244,28 +287,7 @@ def _tpu_flash_attention(
244287) -> jax .Array :
245288 """TPU Flash Attention"""
246289
247- q_max_block_size = 1024 if dtype == jnp .bfloat16 else 512
248- # This is the case for cross-attn.
249- if key .shape [1 ] != query .shape [1 ]:
250- kv_max_block_size = ((key .shape [1 ] + 127 ) // 128 ) * 128
251- else :
252- kv_max_block_size = q_max_block_size
253- # ensure that for cross attention we override the block sizes.
254- if flash_block_sizes and key .shape [1 ] == query .shape [1 ]:
255- block_sizes = flash_block_sizes
256- else :
257- block_size_q = flash_block_sizes .block_q if flash_block_sizes else q_max_block_size
258- block_sizes = splash_attention_kernel .BlockSizes (
259- block_q = block_size_q ,
260- block_kv_compute = min (kv_max_block_size , key .shape [2 ]),
261- block_kv = min (kv_max_block_size , key .shape [2 ]),
262- block_q_dkv = block_size_q ,
263- block_kv_dkv = min (kv_max_block_size , key .shape [2 ]),
264- block_kv_dkv_compute = min (kv_max_block_size , query .shape [2 ]),
265- block_q_dq = None if attention_kernel == "tokamax_flash" else block_size_q ,
266- block_kv_dq = None if attention_kernel == "tokamax_flash" else min (kv_max_block_size , query .shape [2 ]),
267- use_fused_bwd_kernel = True if attention_kernel == "tokamax_flash" else False ,
268- )
290+ block_sizes = _select_flash_block_sizes (query , key , flash_block_sizes , dtype , attention_kernel )
269291 num_context_shards = mesh .shape ["context" ]
270292 query , orig_q_seq_len = _reshape_data_for_flash (query , heads , num_context_shards )
271293 key , _ = _reshape_data_for_flash (key , heads , num_context_shards )
@@ -717,8 +739,8 @@ def __init__(
717739 dtype = dtype ,
718740 param_dtype = weights_dtype ,
719741 precision = precision ,
720- kernel_init = nnx .with_partitioning (nnx .initializers .lecun_normal (), ("embed" , None )),
721- bias_init = nnx .with_partitioning (nnx .initializers .zeros , (None ,)),
742+ kernel_init = nnx .with_partitioning (nnx .initializers .lecun_normal (), ("embed" , "mlp" )),
743+ bias_init = nnx .with_partitioning (nnx .initializers .zeros , ("mlp" ,)),
722744 )
723745 self .act = get_activation (activation_fn )
724746 self .net_2 = nnx .Linear (
@@ -729,8 +751,8 @@ def __init__(
729751 dtype = dtype ,
730752 param_dtype = weights_dtype ,
731753 precision = precision ,
732- kernel_init = nnx .with_partitioning (nnx .initializers .lecun_normal (), ("embed " , "mlp " )),
733- bias_init = nnx .with_partitioning (nnx .initializers .zeros , ("mlp " ,)),
754+ kernel_init = nnx .with_partitioning (nnx .initializers .lecun_normal (), ("mlp " , "embed " )),
755+ bias_init = nnx .with_partitioning (nnx .initializers .zeros , ("embed " ,)),
734756 )
735757
736758 def __call__ (self , hidden_states : Array ) -> Array :
@@ -979,7 +1001,7 @@ def __init__(
9791001 precision = precision ,
9801002 bias_init = nnx .with_partitioning (
9811003 nnx .initializers .zeros ,
982- ("embed " ,),
1004+ ("heads " ,),
9831005 ),
9841006 )
9851007
@@ -993,7 +1015,7 @@ def __init__(
9931015 precision = precision ,
9941016 bias_init = nnx .with_partitioning (
9951017 nnx .initializers .zeros ,
996- ("embed " ,),
1018+ ("heads " ,),
9971019 ),
9981020 )
9991021
@@ -1007,7 +1029,7 @@ def __init__(
10071029 precision = precision ,
10081030 bias_init = nnx .with_partitioning (
10091031 nnx .initializers .zeros ,
1010- ("embed " ,),
1032+ ("heads " ,),
10111033 ),
10121034 )
10131035
@@ -1021,7 +1043,7 @@ def __init__(
10211043 precision = precision ,
10221044 bias_init = nnx .with_partitioning (
10231045 nnx .initializers .zeros ,
1024- ("heads " ,),
1046+ ("embed " ,),
10251047 ),
10261048 )
10271049
@@ -1333,11 +1355,13 @@ def setup(self):
13331355 precision = self .precision ,
13341356 )
13351357
1358+ proj_attn_kernel_axes = ("heads" , "embed" )
1359+
13361360 self .proj_attn = nn .Dense (
13371361 self .query_dim ,
1338- kernel_init = nn .with_logical_partitioning (nn .initializers .lecun_normal (), kernel_axes ),
1362+ kernel_init = nn .with_logical_partitioning (nn .initializers .lecun_normal (), proj_attn_kernel_axes ),
13391363 use_bias = True ,
1340- bias_init = nn .with_logical_partitioning (nn .initializers .zeros , ("heads " ,)),
1364+ bias_init = nn .with_logical_partitioning (nn .initializers .zeros , ("embed " ,)),
13411365 dtype = self .dtype ,
13421366 param_dtype = self .weights_dtype ,
13431367 name = "i_proj" ,
@@ -1346,9 +1370,9 @@ def setup(self):
13461370
13471371 self .encoder_proj_attn = nn .Dense (
13481372 self .query_dim ,
1349- kernel_init = nn .with_logical_partitioning (nn .initializers .lecun_normal (), kernel_axes ),
1373+ kernel_init = nn .with_logical_partitioning (nn .initializers .lecun_normal (), proj_attn_kernel_axes ),
13501374 use_bias = True ,
1351- bias_init = nn .with_logical_partitioning (nn .initializers .zeros , ("heads " ,)),
1375+ bias_init = nn .with_logical_partitioning (nn .initializers .zeros , ("embed " ,)),
13521376 dtype = self .dtype ,
13531377 param_dtype = self .weights_dtype ,
13541378 name = "e_proj" ,
0 commit comments