 from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel
 from tokamax._src.ops.experimental.tpu.splash_attention import splash_attention_mask as tokamax_splash_attention_mask
 from tokamax._src.ops.experimental.tpu.splash_attention import splash_attention_kernel as tokamax_splash_attention_kernel
-from tokamax._src.ops.experimental.tpu.splash_attention import ring_attention_kernel as tokamax_ring_attention_kernel
 from einops import rearrange
 from .. import common_types, max_logging
 
@@ -305,92 +304,62 @@ def wrap_flash_attention(query, key, value):
           mask=mask,
           q_seq_shards=1,  # number of shards along the seq_len axis
           config=convert_to_tokamax_splash_config(block_sizes, residual_checkpoint_name=residual_checkpoint_name),
-          save_residuals=True if "ring" in attention_kernel else False,
-      )
-    elif attention_kernel == "tokamax_ring":
-      mask = tokamax_splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]),)
-      splash_kernel = tokamax_ring_attention_kernel.make_ring_attention(
-          mask=mask,
-          is_mqa=False,
-          config=convert_to_tokamax_splash_config(block_sizes, residual_checkpoint_name=residual_checkpoint_name),
-          save_residuals=True,
-          ring_axis="fsdp",
+          save_residuals=True if attention_kernel == "ring" else False,
       )
     else:
       splash_kernel = splash_attention_kernel.make_splash_mha(
           mask=multi_head_mask,
           head_shards=1,  # number of shards along the heads axis
           q_seq_shards=1,  # number of shards along the seq_len axis
           block_sizes=block_sizes,
-          save_residuals=True if "ring" in attention_kernel else False,
+          save_residuals=True if attention_kernel == "ring" else False,
           residual_checkpoint_name=residual_checkpoint_name
       )
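+    # vmap the splash kernel over the batch dimension; segment_ids is shared across the batch (in_axes=None).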
+    vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0, None))
 
-    if attention_kernel == "tokamax_ring":
-      # For tokamax_ring, use the kernel directly without vmap
-      # The ring attention kernel handles the ring topology internally
-      if not mask_padding_tokens:
-        segment_ids = None
-      attention_output = splash_kernel(
-          fwd_mask_info=None,
-          dkv_mask_info=None,
-          q=query,
-          k=key,
-          v=value,
-          segment_ids=segment_ids,
-          is_mqa=False,
-          config=convert_to_tokamax_splash_config(block_sizes, residual_checkpoint_name=residual_checkpoint_name),
-          mask_value=-jnp.inf,
-          mask_function=None,
-          fwd_mask_sparsity=1.0,
-          save_residuals=True,
-      )
+    if not mask_padding_tokens:
+      segment_ids = None
+    if attention_kernel in ["flash", "tokamax_flash"]:
+      attention_output = vmapped_splash(query, key, value, segment_ids)
     else:
-      vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0, None))
-
-      if not mask_padding_tokens:
-        segment_ids = None
-      if attention_kernel in ["flash", "tokamax_flash"]:
-        attention_output = vmapped_splash(query, key, value, segment_ids)
-      else:
-        if num_fsdp_shards > 1:
-          out, (lse,) = vmapped_splash(query, key, value, segment_ids)
-          m = lse.astype(jnp.float32)
-          l = jnp.exp(lse - m)
-          o = out.astype(jnp.float32) * l[..., None]
+      if num_fsdp_shards > 1:
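+        # Ring attention over the "fsdp" mesh axis: attend to the local KV shard first,
+        # then rotate KV shards around the ring, merging partial results with a
+        # numerically stable, flash-attention-style log-sum-exp accumulation.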
+        out, (lse,) = vmapped_splash(query, key, value, segment_ids)
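+        # m: running max log-sum-exp, l: running softmax normalizer, o: running
+        # unnormalized output; all seeded from the local shard and kept in float32.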
+        m = lse.astype(jnp.float32)
+        l = jnp.exp(lse - m)
+        o = out.astype(jnp.float32) * l[..., None]
 
-          perm = [(j, (j + 1) % num_fsdp_shards) for j in range(num_fsdp_shards)]
+        perm = [(j, (j + 1) % num_fsdp_shards) for j in range(num_fsdp_shards)]
 
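+        # Rotate KV one hop around the ring so each device holds a neighbor's shard.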
-          k1 = jax.lax.ppermute(key, axis_name="fsdp", perm=perm)
-          v1 = jax.lax.ppermute(value, axis_name="fsdp", perm=perm)
+        k1 = jax.lax.ppermute(key, axis_name="fsdp", perm=perm)
+        v1 = jax.lax.ppermute(value, axis_name="fsdp", perm=perm)
 
-          def ring_scan_body(carry, _):
-            m, l, o, k_current, v_current = carry
-            k_next = jax.lax.ppermute(k_current, axis_name="fsdp", perm=perm)
-            v_next = jax.lax.ppermute(v_current, axis_name="fsdp", perm=perm)
+        def ring_scan_body(carry, _):
+          m, l, o, k_current, v_current = carry
+          k_next = jax.lax.ppermute(k_current, axis_name="fsdp", perm=perm)
+          v_next = jax.lax.ppermute(v_current, axis_name="fsdp", perm=perm)
 
-            out_chunk, (lse_chunk,) = vmapped_splash(query, k_current, v_current, segment_ids)
+          out_chunk, (lse_chunk,) = vmapped_splash(query, k_current, v_current, segment_ids)
 
-            m_chunk = lse_chunk.astype(jnp.float32)
-            m_old = m
-            m = jnp.maximum(m_old, m_chunk)
+          m_chunk = lse_chunk.astype(jnp.float32)
+          m_old = m
+          m = jnp.maximum(m_old, m_chunk)
 
-            exp_m_diff = jnp.exp(m_old - m)
-            exp_m_chunk_diff = jnp.exp(m_chunk - m)
+          exp_m_diff = jnp.exp(m_old - m)
+          exp_m_chunk_diff = jnp.exp(m_chunk - m)
 
-            l = l * exp_m_diff + jnp.exp(lse_chunk - m)
-            o = o * exp_m_diff[..., None]
-            o += exp_m_chunk_diff[..., None] * out_chunk.astype(jnp.float32)
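+          # Rescale the previous normalizer and output by exp(m_old - m), then add
+          # this chunk's contribution scaled by exp(lse_chunk - m).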
+          l = l * exp_m_diff + jnp.exp(lse_chunk - m)
+          o = o * exp_m_diff[..., None]
+          o += exp_m_chunk_diff[..., None] * out_chunk.astype(jnp.float32)
 
-            # Return the updated state for the next iteration
-            return (m, l, o, k_next, v_next), None
+          # Return the updated state for the next iteration
+          return (m, l, o, k_next, v_next), None
 
-          initial_carry = (m, l, o, k1, v1)
-          (m_final, l_final, o_final, _, _), _ = jax.lax.scan(ring_scan_body, initial_carry, None, length=num_fsdp_shards - 1)
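+        # num_fsdp_shards - 1 scan steps visit every remote KV shard once; the
+        # rotated KV pair left in the final carry is discarded.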
+        initial_carry = (m, l, o, k1, v1)
+        (m_final, l_final, o_final, _, _), _ = jax.lax.scan(ring_scan_body, initial_carry, None, length=num_fsdp_shards - 1)
 
-          attention_output = o_final / l_final[..., None]
-        else:
-          raise ValueError("ring attention requires fsdp > 1")
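+        # Final normalization: divide the accumulated output by the accumulated normalizer.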
+        attention_output = o_final / l_final[..., None]
+      else:
+        raise ValueError("ring attention requires fsdp > 1")
 
     return attention_output[:, :, :query_seq_len, :kv_size].astype(query.dtype)
 
@@ -566,7 +535,7 @@ def _apply_attention(
         mask_padding_tokens=mask_padding_tokens,
         residual_checkpoint_name=residual_checkpoint_name,
     )
-  elif "ring" in attention_kernel:
+  elif attention_kernel == "ring":
     return _tpu_flash_attention(
         query, key * scale, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype, attention_kernel,
         mask_padding_tokens=mask_padding_tokens,
@@ -577,7 +546,6 @@ def _apply_attention(
     raise ValueError(f"Unexpected attention kernel {attention_kernel=}.")
 
 
-
 def _query_chunk_attention(query, key, value, precision, key_chunk_size: int = 4096):
   """Multi-head dot product attention with a limited number of queries."""
   num_kv, num_heads, k_features = key.shape[-3:]