
Commit b91e784

sagarchapara and claude committed
feat: extend BSHD all-to-all layout to ulysses_ring attention
Applies the same BSHD-native layout used in _ulysses_attention to _ulysses_ring_attention: all-to-all operates on the smaller post-shard head dimension, deferring the BHSD transpose until after the collective. Also routes ulysses_ring through the BSHD RoPE path in FlaxWanAttention.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 04574b4 commit b91e784
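
For context, a minimal sketch of the layout change this commit describes: on a BSHD tensor the all-to-all splits the head axis and gathers the sequence, and the BHSD transpose happens only after the collective. The mesh size, shapes, axis name, and the roundtrip function below are illustrative assumptions, not the maxdiffusion implementation.

# Minimal sketch of the BSHD-native all-to-all, assuming a host with 1 or >= 4 devices
# (e.g. XLA_FLAGS=--xla_force_host_platform_device_count=4 on CPU). Illustrative only.
import functools

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P

ulysses = 4 if jax.device_count() >= 4 else 1
mesh = Mesh(np.array(jax.devices()[:ulysses]), ("ulysses",))

B, S, H, D = 2, 8, 8, 16  # global batch, sequence, heads, head_dim


@functools.partial(
    jax.shard_map,
    mesh=mesh,
    in_specs=P(None, "ulysses", None, None),   # sequence-sharded BSHD input
    out_specs=P(None, "ulysses", None, None),  # sequence-sharded BSHD output
    check_vma=False,
)
def ulysses_roundtrip(x):
  # x arrives as (B, S / ulysses, H, D) per device. The collective splits the
  # head axis (already the smaller dimension after sequence sharding) and
  # gathers the full sequence, so only H / ulysses heads' worth of data moves.
  x = jax.lax.all_to_all(x, axis_name="ulysses", split_axis=2, concat_axis=1, tiled=True)
  # Now (B, S, H / ulysses, D): transpose to BHSD only after the collective.
  x = jnp.transpose(x, (0, 2, 1, 3))
  # ... the ring/flash attention kernel would run here on BHSD data ...
  x = jnp.transpose(x, (0, 2, 1, 3))  # back to BSHD
  # Inverse collective: split the sequence again, re-gather the local heads.
  return jax.lax.all_to_all(x, axis_name="ulysses", split_axis=1, concat_axis=2, tiled=True)


out = ulysses_roundtrip(jnp.zeros((B, S, H, D), dtype=jnp.bfloat16))
assert out.shape == (B, S, H, D)
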

1 file changed

src/maxdiffusion/models/attention_flax.py

Lines changed: 23 additions & 16 deletions
@@ -723,20 +723,22 @@ def _ulysses_ring_attention(
   num_ring_shards = mesh.shape[ring_axis]
   num_sequence_shards = num_ulysses_shards * num_ring_shards

-  query, orig_q_seq_len = _reshape_data_for_flash(query, heads, num_sequence_shards)
-  key, _ = _reshape_data_for_flash(key, heads, num_sequence_shards)
-  value, _ = _reshape_data_for_flash(value, heads, num_sequence_shards)
+  query, orig_q_seq_len = _reshape_data_for_ulysses(query, heads, num_sequence_shards)
+  key, _ = _reshape_data_for_ulysses(key, heads, num_sequence_shards)
+  value, _ = _reshape_data_for_ulysses(value, heads, num_sequence_shards)

-  num_heads = query.shape[1]
+  num_heads = query.shape[2]
   if num_heads % num_ulysses_shards != 0:
     raise ValueError(
         "Ulysses ring attention requires the number of heads to be divisible by the Ulysses shard count, "
         f"got heads={num_heads} and ulysses_shards={num_ulysses_shards}."
     )
-  block_sizes = _select_flash_block_sizes(query, key, flash_block_sizes, dtype, "tokamax_ring")
+  block_sizes = _select_flash_block_sizes(
+      _bshd_as_bhsd_shape(query), _bshd_as_bhsd_shape(key), flash_block_sizes, dtype, "tokamax_ring"
+  )

-  q_axis_names = nn.logical_to_mesh_axes(axis_names_q)
-  kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv)
+  q_axis_names = nn.logical_to_mesh_axes(_bshd_axis_names(axis_names_q))
+  kv_axis_names = nn.logical_to_mesh_axes(_bshd_axis_names(axis_names_kv))

   @functools.partial(
       jax.shard_map,
@@ -746,9 +748,13 @@ def _ulysses_ring_attention(
       check_vma=False,
   )
   def wrap_ulysses_ring_attention(query, key, value):
-    query = jax.lax.all_to_all(query, axis_name=ulysses_axis, split_axis=1, concat_axis=2, tiled=True)
-    key = jax.lax.all_to_all(key, axis_name=ulysses_axis, split_axis=1, concat_axis=2, tiled=True)
-    value = jax.lax.all_to_all(value, axis_name=ulysses_axis, split_axis=1, concat_axis=2, tiled=True)
+    query = jax.lax.all_to_all(query, axis_name=ulysses_axis, split_axis=2, concat_axis=1, tiled=True)
+    key = jax.lax.all_to_all(key, axis_name=ulysses_axis, split_axis=2, concat_axis=1, tiled=True)
+    value = jax.lax.all_to_all(value, axis_name=ulysses_axis, split_axis=2, concat_axis=1, tiled=True)
+
+    query = _bshd_to_bhsd(query)
+    key = _bshd_to_bhsd(key)
+    value = _bshd_to_bhsd(value)

     uses_fused_kernel = block_sizes.use_fused_bwd_kernel
     block_q_sizes = (block_sizes.block_q, block_sizes.block_q_dkv)
@@ -809,11 +815,12 @@ def wrap_ulysses_ring_attention(query, key, value):
     attention_output = vmapped_splash(query, key, value, segment_ids)
     attention_output = attention_output[:, :, :query_seq_len, :kv_size].astype(query.dtype)

+    attention_output = _bhsd_to_bshd(attention_output)
     return jax.lax.all_to_all(
         attention_output,
         axis_name=ulysses_axis,
-        split_axis=2,
-        concat_axis=1,
+        split_axis=1,
+        concat_axis=2,
         tiled=True,
     )

@@ -824,8 +831,8 @@ def wrap_ulysses_ring_attention(query, key, value):
         f" axis, batch dimension: {query.shape[0]}, devices_in_batch_sharding: {devices_in_batch_sharding}"
     )
   x = wrap_ulysses_ring_attention(query, key, value)
-  x = x[:, :, :orig_q_seq_len, :]
-  x = _reshape_heads_to_head_dim(x)
+  x = x[:, :orig_q_seq_len, :, :]
+  x = x.reshape(x.shape[0], x.shape[1], -1)

   return x

@@ -950,7 +957,7 @@ def _apply_attention(
   """Routes to different attention kernels."""
   _check_attention_inputs(query, key, value)
   seq_len_idx = 1
-  if query.ndim == 4 and attention_kernel != "ulysses":
+  if query.ndim == 4 and attention_kernel != "ulysses" and attention_kernel not in ULYSSES_RING_ATTENTION_KERNELS:
     seq_len_idx = 2
   if attention_kernel in ["flash", "tokamax_flash", "ulysses"] or attention_kernel in ULYSSES_RING_ATTENTION_KERNELS:
     can_use_flash_attention = (
@@ -1628,7 +1635,7 @@ def __call__(

     if rotary_emb is not None:
       with self.conditional_named_scope("attn_rope"):
-        if self.attention_op.attention_kernel == "ulysses":
+        if self.attention_op.attention_kernel == "ulysses" or self.attention_op.attention_kernel in ULYSSES_RING_ATTENTION_KERNELS:
           query_proj = _unflatten_heads_bshd(query_proj, self.heads)
           key_proj = _unflatten_heads_bshd(key_proj, self.heads)
           value_proj = _unflatten_heads_bshd(value_proj, self.heads)
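
As a side note on the last hunk: routing ulysses_ring through the BSHD RoPE path relies on the head unflatten keeping the sequence on axis 1. A hypothetical sketch of that shape handling, assuming the projections are flattened as (batch, seq, heads * head_dim); this is not the actual _unflatten_heads_bshd helper from attention_flax.py.

# Hypothetical sketch of the BSHD head unflatten applied before RoPE; assumes a
# (batch, seq, heads * head_dim) flattening and is for illustration only.
import jax.numpy as jnp


def unflatten_heads_bshd(x, heads):
  # (B, S, H * D) -> (B, S, H, D): the sequence stays on axis 1, so a
  # BSHD-aware RoPE can rotate along the sequence axis without first
  # transposing to BHSD, matching the layout the all-to-all above expects.
  b, s, hd = x.shape
  return x.reshape(b, s, heads, hd // heads)


q = jnp.zeros((2, 128, 12 * 64))
assert unflatten_heads_bshd(q, 12).shape == (2, 128, 12, 64)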

0 commit comments
