@@ -723,20 +723,22 @@ def _ulysses_ring_attention(
   num_ring_shards = mesh.shape[ring_axis]
   num_sequence_shards = num_ulysses_shards * num_ring_shards
 
-  query, orig_q_seq_len = _reshape_data_for_flash(query, heads, num_sequence_shards)
-  key, _ = _reshape_data_for_flash(key, heads, num_sequence_shards)
-  value, _ = _reshape_data_for_flash(value, heads, num_sequence_shards)
+  query, orig_q_seq_len = _reshape_data_for_ulysses(query, heads, num_sequence_shards)
+  key, _ = _reshape_data_for_ulysses(key, heads, num_sequence_shards)
+  value, _ = _reshape_data_for_ulysses(value, heads, num_sequence_shards)
 
-  num_heads = query.shape[1]
+  num_heads = query.shape[2]
   if num_heads % num_ulysses_shards != 0:
     raise ValueError(
         "Ulysses ring attention requires the number of heads to be divisible by the Ulysses shard count, "
         f"got heads={num_heads} and ulysses_shards={num_ulysses_shards}."
     )
-  block_sizes = _select_flash_block_sizes(query, key, flash_block_sizes, dtype, "tokamax_ring")
+  block_sizes = _select_flash_block_sizes(
+      _bshd_as_bhsd_shape(query), _bshd_as_bhsd_shape(key), flash_block_sizes, dtype, "tokamax_ring"
+  )
 
-  q_axis_names = nn.logical_to_mesh_axes(axis_names_q)
-  kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv)
+  q_axis_names = nn.logical_to_mesh_axes(_bshd_axis_names(axis_names_q))
+  kv_axis_names = nn.logical_to_mesh_axes(_bshd_axis_names(axis_names_kv))
 
   @functools.partial(
       jax.shard_map,
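With this change the Q/K/V tensors stay in BSHD layout ([batch, seq, heads, head_dim]) up to the shard_map, which is why the head count is now read from `query.shape[2]` and why the block-size selector is handed a BHSD-shaped view. A minimal sketch of what a helper like `_bshd_as_bhsd_shape` could do, assuming it only needs to expose a permuted shape rather than move any data (the body below is an assumption, not the repository's implementation):

```python
import jax


def _bshd_as_bhsd_shape(x):
  # Hypothetical helper: present a [batch, seq, heads, head_dim] array as a
  # shape-only [batch, heads, seq, head_dim] proxy, so BHSD-oriented block-size
  # selection keeps reading the sequence and head dims from their usual positions.
  b, s, h, d = x.shape
  return jax.ShapeDtypeStruct((b, h, s, d), x.dtype)
```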
@@ -746,9 +748,13 @@ def _ulysses_ring_attention(
       check_vma=False,
   )
   def wrap_ulysses_ring_attention(query, key, value):
-    query = jax.lax.all_to_all(query, axis_name=ulysses_axis, split_axis=1, concat_axis=2, tiled=True)
-    key = jax.lax.all_to_all(key, axis_name=ulysses_axis, split_axis=1, concat_axis=2, tiled=True)
-    value = jax.lax.all_to_all(value, axis_name=ulysses_axis, split_axis=1, concat_axis=2, tiled=True)
+    query = jax.lax.all_to_all(query, axis_name=ulysses_axis, split_axis=2, concat_axis=1, tiled=True)
+    key = jax.lax.all_to_all(key, axis_name=ulysses_axis, split_axis=2, concat_axis=1, tiled=True)
+    value = jax.lax.all_to_all(value, axis_name=ulysses_axis, split_axis=2, concat_axis=1, tiled=True)
+
+    query = _bshd_to_bhsd(query)
+    key = _bshd_to_bhsd(key)
+    value = _bshd_to_bhsd(value)
 
     uses_fused_kernel = block_sizes.use_fused_bwd_kernel
     block_q_sizes = (block_sizes.block_q, block_sizes.block_q_dkv)
@@ -809,11 +815,12 @@ def wrap_ulysses_ring_attention(query, key, value):
     attention_output = vmapped_splash(query, key, value, segment_ids)
     attention_output = attention_output[:, :, :query_seq_len, :kv_size].astype(query.dtype)
 
+    attention_output = _bhsd_to_bshd(attention_output)
     return jax.lax.all_to_all(
         attention_output,
         axis_name=ulysses_axis,
-        split_axis=2,
-        concat_axis=1,
+        split_axis=1,
+        concat_axis=2,
         tiled=True,
     )
 
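The layout bookkeeping inside the shard_map is: the first all_to_all on a BSHD shard of shape [batch, seq/N, heads, head_dim] splits the head axis (2) and concatenates the sequence axis (1) across the N Ulysses devices, yielding [batch, seq, heads/N, head_dim]; the tensors are then transposed to BHSD for the splash/ring kernel, and the kernel output is transposed back before the inverse all_to_all (split axis 1, concat axis 2) restores the original sequence sharding. The helper names come from the diff; their bodies are presumably plain axis swaps. A hedged sketch, assuming exactly that:

```python
import jax.numpy as jnp


def _bshd_to_bhsd(x):
  # Swap seq and heads: [batch, seq, heads, head_dim] -> [batch, heads, seq, head_dim].
  return jnp.transpose(x, (0, 2, 1, 3))


def _bhsd_to_bshd(x):
  # The swap is its own inverse: [batch, heads, seq, head_dim] -> [batch, seq, heads, head_dim].
  return jnp.transpose(x, (0, 2, 1, 3))
```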
@@ -824,8 +831,8 @@ def wrap_ulysses_ring_attention(query, key, value):
         f" axis, batch dimension: {query.shape[0]}, devices_in_batch_sharding: {devices_in_batch_sharding}"
     )
   x = wrap_ulysses_ring_attention(query, key, value)
-  x = x[:, :, :orig_q_seq_len, :]
-  x = _reshape_heads_to_head_dim(x)
+  x = x[:, :orig_q_seq_len, :, :]
+  x = x.reshape(x.shape[0], x.shape[1], -1)
 
   return x
 
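Because the tensor leaving `wrap_ulysses_ring_attention` is now BSHD, the padding trim indexes the sequence on axis 1, and the heads are merged with a plain reshape rather than the BHSD-specific `_reshape_heads_to_head_dim`. A small worked example with made-up shapes:

```python
import jax.numpy as jnp

# Illustrative shapes only: batch=2, padded seq=8, heads=4, head_dim=16, original seq=6.
orig_q_seq_len = 6
x = jnp.zeros((2, 8, 4, 16))
x = x[:, :orig_q_seq_len, :, :]            # drop sequence padding on axis 1
x = x.reshape(x.shape[0], x.shape[1], -1)  # merge heads and head_dim: (2, 6, 4, 16) -> (2, 6, 64)
assert x.shape == (2, 6, 64)
```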
@@ -950,7 +957,7 @@ def _apply_attention(
   """Routes to different attention kernels."""
   _check_attention_inputs(query, key, value)
   seq_len_idx = 1
-  if query.ndim == 4 and attention_kernel != "ulysses":
+  if query.ndim == 4 and attention_kernel != "ulysses" and attention_kernel not in ULYSSES_RING_ATTENTION_KERNELS:
     seq_len_idx = 2
   if attention_kernel in ["flash", "tokamax_flash", "ulysses"] or attention_kernel in ULYSSES_RING_ATTENTION_KERNELS:
     can_use_flash_attention = (
@@ -1628,7 +1635,7 @@ def __call__(
 
     if rotary_emb is not None:
       with self.conditional_named_scope("attn_rope"):
-        if self.attention_op.attention_kernel == "ulysses":
+        if self.attention_op.attention_kernel == "ulysses" or self.attention_op.attention_kernel in ULYSSES_RING_ATTENTION_KERNELS:
          query_proj = _unflatten_heads_bshd(query_proj, self.heads)
          key_proj = _unflatten_heads_bshd(key_proj, self.heads)
          value_proj = _unflatten_heads_bshd(value_proj, self.heads)