Commit d6638b6

Add Ulysses ring split override
1 parent f319797 commit d6638b6

13 files changed

Lines changed: 78 additions & 4 deletions

README.md

Lines changed: 4 additions & 0 deletions
@@ -603,13 +603,17 @@ To generate images, run the following command:
 Public configs still shard sequence only over the `context` mesh axis. The attention kernel privately reshapes
 that context axis into hidden ring and Ulysses axes, runs the Ulysses all-to-all over the hidden Ulysses axis,
 and reuses Tokamax ring attention over the hidden ring axis.
+By default, the split is selected automatically. For tuning, set
+`ulysses_ring_ulysses_parallelism=<ulysses_shards>`; ring shards are derived as
+`ici_context_parallelism / ulysses_ring_ulysses_parallelism`.
 
 ```bash
 python src/maxdiffusion/generate_wan.py \
 src/maxdiffusion/configs/base_wan_i2v_27b.yml \
 attention="ulysses_ring" \
 dcn_context_parallelism=<num_slices> \
 ici_context_parallelism=<context_shards_per_slice> \
+ulysses_ring_ulysses_parallelism=<optional_ulysses_shards> \
 ...
 ```

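
When picking an explicit override, it can help to list the splits that pass the checks visible in this commit. The sketch below is not part of the change: `valid_ulysses_ring_splits` is a hypothetical helper that mirrors only the two divisibility checks shown in `_choose_internal_ulysses_shards`, and the shard and head counts in the usage line are illustrative. Whether later stages accept a degenerate split (one Ulysses shard or one ring shard) is not shown in this diff.

```python
def valid_ulysses_ring_splits(context_shards: int, heads: int) -> list[tuple[int, int]]:
  """List (ulysses_shards, ring_shards) pairs that pass the override checks.

  Hypothetical helper, not a maxdiffusion API. It assumes the only
  constraints are the two divisibility checks shown in the diff.
  """
  if context_shards <= 1:
    raise ValueError(f"context_shards must be > 1, got {context_shards}.")
  splits = []
  for ulysses_shards in range(1, context_shards + 1):
    divides_context = context_shards % ulysses_shards == 0
    divides_heads = heads % ulysses_shards == 0
    if divides_context and divides_heads:
      # Ring shards are derived as context shards / Ulysses shards.
      splits.append((ulysses_shards, context_shards // ulysses_shards))
  return splits


# Illustrative values only: 8 context shards per slice, 40 attention heads.
print(valid_ulysses_ring_splits(8, 40))  # [(1, 8), (2, 4), (4, 2), (8, 1)]
```

Each pair maps to `ulysses_ring_ulysses_parallelism=<first value>` with the second value being the derived ring shard count.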

docs/tpu_multihost_wan_bench.md

Lines changed: 1 addition & 1 deletion
@@ -188,7 +188,7 @@ run_case ulysses_ring_dp2_cp8 ulysses_ring 2 8 1
 
 ## Topology Note
 
-TPU v7x exposes dual chiplets as two JAX devices. For `ulysses_ring`, expose only the total sequence sharding through `context`; the attention kernel derives a private ring and Ulysses split from that axis.
+TPU v7x exposes dual chiplets as two JAX devices. For `ulysses_ring`, expose only the total sequence sharding through `context`; the attention kernel derives a private ring and Ulysses split from that axis. To tune that split explicitly, set `ulysses_ring_ulysses_parallelism`; ring shards are derived as `ici_context_parallelism / ulysses_ring_ulysses_parallelism`.
 - `4x4` uses tensor `4`, so the dual-chip pairing is still inside the Ulysses side.
 
 The plain `ring` baseline has no Ulysses group, so it cannot preserve that property by construction.

docs/tpu_wan_bench_guide.md

Lines changed: 1 addition & 1 deletion
@@ -143,7 +143,7 @@ Set in `src/maxdiffusion/configs/base_wan_27b.yml` or overridden on the command
 **Parallelism rule**: product of all ICI axes must equal 8 (chips per host):
 - `ici_dp × ici_fsdp × ici_cp × ici_tp = 8`
 
-For `ulysses_ring`, set the desired total sequence shards with `ici_context_parallelism`; the internal ring and Ulysses split is selected by the attention kernel.
+For `ulysses_ring`, set the desired total sequence shards with `ici_context_parallelism`; the internal ring and Ulysses split is selected by the attention kernel. To tune it manually, set `ulysses_ring_ulysses_parallelism=<ulysses_shards>` and the ring shard count is derived as `ici_context_parallelism / ulysses_ring_ulysses_parallelism`.
 
 ---

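
The parallelism rule and the derived ring count can be checked before launching a run. This is a minimal sketch, assuming the four ICI axis sizes are available as plain integers; `check_ici_layout` is a hypothetical helper, not a maxdiffusion API, and its default of 8 chips per host is taken from the rule quoted in the guide.

```python
import math


def check_ici_layout(ici_dp, ici_fsdp, ici_cp, ici_tp, ulysses_shards=-1, chips_per_host=8):
  """Validate the ICI product rule and derive the ring shard count.

  Hypothetical helper. Returns the derived ring shard count, or None when
  ulysses_shards <= 0 and the split is left to the attention kernel.
  """
  product = math.prod((ici_dp, ici_fsdp, ici_cp, ici_tp))
  if product != chips_per_host:
    raise ValueError(f"ICI axes multiply to {product}, expected {chips_per_host}.")
  if ulysses_shards <= 0:
    return None  # automatic selection, nothing to derive here
  if ici_cp % ulysses_shards != 0:
    raise ValueError(f"ulysses_shards={ulysses_shards} does not divide ici_cp={ici_cp}.")
  # Ring shards = ici_context_parallelism / ulysses_ring_ulysses_parallelism.
  return ici_cp // ulysses_shards


# Illustrative layout: dp=2, fsdp=1, cp=4, tp=1 with 2 Ulysses shards.
print(check_ici_layout(2, 1, 4, 1, ulysses_shards=2))  # 2
```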

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 3 additions & 0 deletions
@@ -67,6 +67,9 @@ split_head_dim: True
 attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_custom, ulysses_ring
 use_base2_exp: True
 use_experimental_scheduler: True
+# Optional for attention=ulysses_ring. -1 auto-selects the hidden split; otherwise
+# this many context shards are used for Ulysses and ring shards are context / this.
+ulysses_ring_ulysses_parallelism: -1
 flash_min_seq_length: 4096
 dropout: 0.0


src/maxdiffusion/configs/base_wan_1_3b.yml

Lines changed: 3 additions & 0 deletions
@@ -63,6 +63,9 @@ split_head_dim: True
 attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_custom, ulysses_ring
 use_base2_exp: True
 use_experimental_scheduler: True
+# Optional for attention=ulysses_ring. -1 auto-selects the hidden split; otherwise
+# this many context shards are used for Ulysses and ring shards are context / this.
+ulysses_ring_ulysses_parallelism: -1
 flash_min_seq_length: 0
 
 # If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens.

src/maxdiffusion/configs/base_wan_27b.yml

Lines changed: 3 additions & 0 deletions
@@ -67,6 +67,9 @@ split_head_dim: True
 attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_custom, ulysses_ring
 use_base2_exp: True
 use_experimental_scheduler: True
+# Optional for attention=ulysses_ring. -1 auto-selects the hidden split; otherwise
+# this many context shards are used for Ulysses and ring shards are context / this.
+ulysses_ring_ulysses_parallelism: -1
 flash_min_seq_length: 4096
 dropout: 0.0


src/maxdiffusion/configs/base_wan_animate.yml

Lines changed: 3 additions & 0 deletions
@@ -65,6 +65,9 @@ split_head_dim: True
 attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_custom, ulysses_ring
 use_base2_exp: True
 use_experimental_scheduler: True
+# Optional for attention=ulysses_ring. -1 auto-selects the hidden split; otherwise
+# this many context shards are used for Ulysses and ring shards are context / this.
+ulysses_ring_ulysses_parallelism: -1
 flash_min_seq_length: 4096
 # If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens.
 # Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster.

src/maxdiffusion/configs/base_wan_i2v_14b.yml

Lines changed: 3 additions & 0 deletions
@@ -67,6 +67,9 @@ split_head_dim: True
 attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_custom, ulysses_ring
 use_base2_exp: True
 use_experimental_scheduler: True
+# Optional for attention=ulysses_ring. -1 auto-selects the hidden split; otherwise
+# this many context shards are used for Ulysses and ring shards are context / this.
+ulysses_ring_ulysses_parallelism: -1
 flash_min_seq_length: 4096
 dropout: 0.0


src/maxdiffusion/configs/base_wan_i2v_27b.yml

Lines changed: 3 additions & 0 deletions
@@ -67,6 +67,9 @@ split_head_dim: True
 attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_custom, ulysses_ring
 use_base2_exp: True
 use_experimental_scheduler: True
+# Optional for attention=ulysses_ring. -1 auto-selects the hidden split; otherwise
+# this many context shards are used for Ulysses and ring shards are context / this.
+ulysses_ring_ulysses_parallelism: -1
 flash_min_seq_length: 4096
 dropout: 0.0


src/maxdiffusion/models/attention_flax.py

Lines changed: 36 additions & 2 deletions
@@ -206,11 +206,28 @@ def _replace_mesh_axis_names(axis_names, old_axis: str, new_axes: tuple[str, ...
   )


-def _choose_internal_ulysses_shards(context_shards: int, heads: int) -> int:
+def _choose_internal_ulysses_shards(
+    context_shards: int,
+    heads: int,
+    requested_ulysses_shards: int = -1,
+) -> int:
   """Choose a hidden Ulysses split inside the public context axis."""
   if context_shards <= 1:
     raise ValueError(f"Ulysses ring attention requires context_shards > 1, got {context_shards}.")

+  if requested_ulysses_shards and requested_ulysses_shards > 0:
+    if context_shards % requested_ulysses_shards != 0:
+      raise ValueError(
+          "Ulysses ring attention requires the requested Ulysses shard count to divide the context shard count, "
+          f"got context_shards={context_shards} and ulysses_shards={requested_ulysses_shards}."
+      )
+    if heads % requested_ulysses_shards != 0:
+      raise ValueError(
+          "Ulysses ring attention requires the number of heads to be divisible by the requested Ulysses shard "
+          f"count, got heads={heads} and ulysses_shards={requested_ulysses_shards}."
+      )
+    return requested_ulysses_shards
+
   balanced_limit = int(math.sqrt(context_shards))
   balanced_candidates = [
       factor
@@ -844,6 +861,7 @@ def _ulysses_ring_attention(
     ring_axis: str = INTERNAL_RING_AXIS,
     use_base2_exp: bool = False,
     use_experimental_scheduler: bool = False,
+    ulysses_ring_ulysses_parallelism: int = -1,
 ) -> jax.Array:
   """2D context-parallel attention using a private Ulysses x ring mesh.

@@ -857,7 +875,11 @@ def _ulysses_ring_attention(
     raise ValueError(f"Ulysses ring attention requires mesh axis {context_axis!r}, got mesh axes {mesh.shape}.")

   num_context_shards = mesh.shape[context_axis]
-  num_ulysses_shards = _choose_internal_ulysses_shards(num_context_shards, heads)
+  num_ulysses_shards = _choose_internal_ulysses_shards(
+      num_context_shards,
+      heads,
+      requested_ulysses_shards=ulysses_ring_ulysses_parallelism,
+  )
   num_ring_shards = num_context_shards // num_ulysses_shards
   internal_mesh = _create_internal_ulysses_ring_mesh(
       mesh,
@@ -1166,6 +1188,7 @@ def ulysses_ring_kernel(q, k, v, context):
       attention_mask=context["attention_mask"],
       use_base2_exp=context["use_base2_exp"],
       use_experimental_scheduler=context["use_experimental_scheduler"],
+      ulysses_ring_ulysses_parallelism=context["ulysses_ring_ulysses_parallelism"],
   )


@@ -1279,6 +1302,7 @@ def _apply_attention(
     attention_mask: Array = None,
     use_base2_exp: bool = False,
     use_experimental_scheduler: bool = False,
+    ulysses_ring_ulysses_parallelism: int = -1,
 ):
   """Routes to different attention kernels using a module-level registry."""

@@ -1316,6 +1340,7 @@ def _apply_attention(
       "scale": scale,
       "use_base2_exp": use_base2_exp,
       "use_experimental_scheduler": use_experimental_scheduler,
+      "ulysses_ring_ulysses_parallelism": ulysses_ring_ulysses_parallelism,
       "dim_head": dim_head,
       "split_head_dim": split_head_dim,
       "float32_qk_product": float32_qk_product,
@@ -1521,10 +1546,12 @@ def __init__(
       residual_checkpoint_name: str | None = None,
       use_base2_exp: bool = False,
       use_experimental_scheduler: bool = False,
+      ulysses_ring_ulysses_parallelism: int = -1,
   ):
     self.dpa_layer = None
     self.use_base2_exp = use_base2_exp
     self.use_experimental_scheduler = use_experimental_scheduler
+    self.ulysses_ring_ulysses_parallelism = ulysses_ring_ulysses_parallelism
     if attention_kernel == "cudnn_flash_te":
       from transformer_engine.jax.flax.transformer import DotProductAttention  # pytype: disable=import-error

@@ -1587,6 +1614,9 @@ def apply_attention(self, query: Array, key: Array, value: Array, attention_mask
         attention_mask=attention_mask,
         use_base2_exp=self.use_base2_exp if hasattr(self, "use_base2_exp") else False,
         use_experimental_scheduler=self.use_experimental_scheduler if hasattr(self, "use_experimental_scheduler") else False,
+        ulysses_ring_ulysses_parallelism=(
+            self.ulysses_ring_ulysses_parallelism if hasattr(self, "ulysses_ring_ulysses_parallelism") else -1
+        ),
     )


@@ -1607,6 +1637,7 @@ class AttentionOp(nn.Module):
   quant: Quant = None
   use_base2_exp: bool = False
   use_experimental_scheduler: bool = False
+  ulysses_ring_ulysses_parallelism: int = -1

   def setup(self):
     self.dpa_layer = None
@@ -1654,6 +1685,7 @@ def apply_attention(self, query: Array, key: Array, value: Array, attention_mask
         attention_mask=attention_mask,
         use_base2_exp=self.use_base2_exp,
         use_experimental_scheduler=self.use_experimental_scheduler,
+        ulysses_ring_ulysses_parallelism=self.ulysses_ring_ulysses_parallelism,
     )


@@ -1692,6 +1724,7 @@ def __init__(
       image_seq_len: Optional[int] = None,  # New for I2V
       use_base2_exp: bool = False,
       use_experimental_scheduler: bool = False,
+      ulysses_ring_ulysses_parallelism: int = -1,
   ):
     if attention_kernel in {"flash", "cudnn_flash_te"} and mesh is None:
       raise ValueError(f"The flash attention kernel requires a value for mesh, but mesh is {self.mesh}")
@@ -1740,6 +1773,7 @@ def __init__(
         residual_checkpoint_name=residual_checkpoint_name,
         use_base2_exp=use_base2_exp,
         use_experimental_scheduler=use_experimental_scheduler,
+        ulysses_ring_ulysses_parallelism=ulysses_ring_ulysses_parallelism,
     )
     # None axes corresponds to the stacked weights across all blocks
     # because of the use of nnx.vmap and nnx.scan.
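
The override branch added to `_choose_internal_ulysses_shards` can be exercised in isolation. The sketch below copies the two checks from the diff but is not the library function: `resolve_split`, `auto_select`, and `always_two` are hypothetical names, and `auto_select` stands in for the automatic heuristic, which this diff shows only in part. Because the guard is `requested_ulysses_shards and requested_ulysses_shards > 0`, the values `-1`, `0`, and `None` all fall through to automatic selection.

```python
from typing import Callable, Optional


def resolve_split(
    context_shards: int,
    heads: int,
    requested_ulysses_shards: Optional[int],
    auto_select: Callable[[int, int], int],
) -> tuple[int, int]:
  """Return (ulysses_shards, ring_shards) using the override when it is positive.

  Hypothetical stand-alone version of the override branch; auto_select is a
  placeholder for the heuristic that is not fully visible in the diff.
  """
  if context_shards <= 1:
    raise ValueError(f"context_shards must be > 1, got {context_shards}.")
  if requested_ulysses_shards and requested_ulysses_shards > 0:
    if context_shards % requested_ulysses_shards != 0:
      raise ValueError("requested Ulysses shards must divide the context shard count")
    if heads % requested_ulysses_shards != 0:
      raise ValueError("heads must be divisible by the requested Ulysses shards")
    ulysses_shards = requested_ulysses_shards
  else:
    ulysses_shards = auto_select(context_shards, heads)
  return ulysses_shards, context_shards // ulysses_shards


def always_two(context_shards: int, heads: int) -> int:
  """Demo-only placeholder heuristic; the real selection logic differs."""
  return 2


print(resolve_split(8, 40, 4, always_two))   # (4, 2): explicit override
print(resolve_split(8, 40, -1, always_two))  # (2, 4): falls back to auto_select
print(resolve_split(8, 40, 0, always_two))   # (2, 4): 0 is also treated as auto
```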
