Skip to content

Commit 673b2a9

Browse files
committed
Add Ulysses attention support
1 parent a3747b1 commit 673b2a9

11 files changed

Lines changed: 519 additions & 14 deletions

File tree

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,3 +181,6 @@ wandb
181181
# Gemini CLI
182182
.gemini/
183183
gha-creds-*.json
184+
185+
# JAX cache
186+
.jax_cache/

README.md

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -572,6 +572,26 @@ To generate images, run the following command:
572572
* For Wan2.2 T2V, use `base_wan_27b.yml`.
573573
* For Wan2.2 I2V, use `base_wan_i2v_27b.yml`.
574574

575+
### Ulysses Attention
576+
577+
MaxDiffusion supports Ulysses attention for WAN TPU inference. Enable it by setting `attention="ulysses"`.
578+
579+
Internally, this follows the Ulysses sequence-parallel attention pattern and trades sequence shards for head shards around the local TPU splash kernel. For background, see [DeepSpeed Ulysses: System Optimizations for Enabling Training of Extreme Long Sequence Transformer Models](https://arxiv.org/abs/2309.14509).
580+
581+
To enable Ulysses attention, set the corresponding override in your config YAML or pass it as a command-line override:
582+
583+
```bash
584+
python src/maxdiffusion/generate_wan.py \
585+
src/maxdiffusion/configs/base_wan_i2v_27b.yml \
586+
attention="ulysses" \
587+
ici_context_parallelism=4 \
588+
...
589+
```
590+
591+
Ulysses requires `ici_context_parallelism` greater than 1, and the number of attention heads must be divisible by the context shard count. `flash_block_sizes` tuning is optional and can still be used for hardware-specific tuning.
592+
593+
In our Wan2.2 I2V benchmarks at 40 inference steps, 81 frames, and `720x1280` resolution, Ulysses improved inference time by roughly `10%` compared with flash attention, with about `20s` lower latency on the v6e-8 and v7x-8 TPU setups.
594+
575595
### Caching Mechanisms
576596

577597
Wan 2.x pipelines support several caching strategies to accelerate inference by skipping redundant transformer forward passes. These are **mutually exclusive** — enable only one at a time.
@@ -772,4 +792,3 @@ This script will automatically format your code with `pyink` and help you identi
772792
773793
774794
The full suite of end-to-end tests is in `tests` and `src/maxdiffusion/tests`. We run them with a nightly cadence.
775-

src/maxdiffusion/common_types.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,3 +84,13 @@
8484
[CROSS_ATTN_Q_LENGTH, CONTEXT],
8585
[CROSS_ATTN_KV_LENGTH, None],
8686
]
87+
88+
### Common axis rules for ulysses attention ###
89+
ULYSSES_ATTENTION_AXIS_RULES = [
90+
[SELF_ATTN_HEAD, None],
91+
[SELF_ATTN_Q_LENGTH, CONTEXT],
92+
[SELF_ATTN_KV_LENGTH, CONTEXT],
93+
[CROSS_ATTN_HEAD, None],
94+
[CROSS_ATTN_Q_LENGTH, CONTEXT],
95+
[CROSS_ATTN_KV_LENGTH, CONTEXT],
96+
]

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ jit_initializers: True
6060
# Set true to load weights from pytorch
6161
from_pt: True
6262
split_head_dim: True
63-
attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring
63+
attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring, ulysses
6464
flash_min_seq_length: 0
6565

6666
# If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens.

src/maxdiffusion/configs/base_wan_1_3b.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ jit_initializers: True
6060
# Set true to load weights from pytorch
6161
from_pt: True
6262
split_head_dim: True
63-
attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring
63+
attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring, ulysses
6464
flash_min_seq_length: 0
6565

6666
# If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens.

src/maxdiffusion/configs/base_wan_27b.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ jit_initializers: True
6060
# Set true to load weights from pytorch
6161
from_pt: True
6262
split_head_dim: True
63-
attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring
63+
attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring, ulysses
6464
flash_min_seq_length: 4096
6565
# If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens.
6666
# Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster.

src/maxdiffusion/configs/base_wan_i2v_14b.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ jit_initializers: True
6060
# Set true to load weights from pytorch
6161
from_pt: True
6262
split_head_dim: True
63-
attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring
63+
attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring, ulysses
6464
flash_min_seq_length: 4096
6565
dropout: 0.0
6666

src/maxdiffusion/configs/base_wan_i2v_27b.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ jit_initializers: True
6060
# Set true to load weights from pytorch
6161
from_pt: True
6262
split_head_dim: True
63-
attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring
63+
attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring, ulysses
6464
flash_min_seq_length: 4096
6565
dropout: 0.0
6666

src/maxdiffusion/models/attention_flax.py

Lines changed: 157 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -443,6 +443,147 @@ def ring_scan_body(carry, _):
443443
return x
444444

445445

446+
# ---------------------------------------------------------------------------
447+
# Ulysses sequence-parallel attention
448+
# ---------------------------------------------------------------------------
449+
450+
451+
def _ulysses_attention(
452+
query: jax.Array,
453+
key: jax.Array,
454+
value: jax.Array,
455+
heads: int,
456+
mesh: Mesh,
457+
axis_names_q: AxisNames,
458+
axis_names_kv: AxisNames,
459+
flash_block_sizes: BlockSizes,
460+
dtype: jnp.dtype = jnp.float32,
461+
mask_padding_tokens: bool = True,
462+
residual_checkpoint_name: str | None = None,
463+
attention_mask: jax.Array = None,
464+
) -> jax.Array:
465+
"""Ulysses sequence-parallel attention.
466+
467+
Tensors arrive sequence-sharded on the context axis. Inside a shard_map the
468+
all-to-all collectives trade sequence shards for head shards, run local
469+
splash attention on the full sequence with a subset of heads, then all-to-all
470+
back.
471+
472+
This function is a self-contained op that reuses _reshape_data_for_flash and
473+
_pad_data_for_flash from the existing flash attention path.
474+
"""
475+
axis_name = "context"
476+
num_shards = mesh.shape[axis_name]
477+
478+
# Reshape to [b, h, s, d] and pad sequence for even context-axis splitting.
479+
query, orig_q_seq_len = _reshape_data_for_flash(query, heads, num_shards)
480+
key, _ = _reshape_data_for_flash(key, heads, num_shards)
481+
value, _ = _reshape_data_for_flash(value, heads, num_shards)
482+
num_heads = query.shape[1]
483+
# Ulysses only redistributes existing heads across the context mesh; unlike
484+
# the earlier draft, we fail fast instead of padding synthetic heads.
485+
if num_heads % num_shards != 0:
486+
raise ValueError(
487+
"Ulysses attention requires the number of heads to be divisible by the context shard count, "
488+
f"got heads={num_heads} and context_shards={num_shards}."
489+
)
490+
block_sizes = _select_flash_block_sizes(query, key, flash_block_sizes, dtype, "flash")
491+
492+
q_axis_names = nn.logical_to_mesh_axes(axis_names_q)
493+
kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv)
494+
495+
@functools.partial(
496+
jax.shard_map,
497+
mesh=mesh,
498+
in_specs=(q_axis_names, kv_axis_names, kv_axis_names),
499+
out_specs=q_axis_names,
500+
check_vma=False,
501+
)
502+
def wrap_ulysses_attention(query, key, value):
503+
# Swap sharding modes: each device gives up a slice of sequence and gathers
504+
# a slice of heads, so the local splash kernel sees the full sequence.
505+
query = jax.lax.all_to_all(query, axis_name=axis_name, split_axis=1, concat_axis=2, tiled=True)
506+
key = jax.lax.all_to_all(key, axis_name=axis_name, split_axis=1, concat_axis=2, tiled=True)
507+
value = jax.lax.all_to_all(value, axis_name=axis_name, split_axis=1, concat_axis=2, tiled=True)
508+
509+
# Run the same local splash kernel as standard TPU flash attention, but now
510+
# on full-sequence / fewer-heads tensors produced by the all-to-all above.
511+
uses_fused_kernel = block_sizes.use_fused_bwd_kernel
512+
block_q_sizes = (block_sizes.block_q, block_sizes.block_q_dkv)
513+
block_kv_sizes = (block_sizes.block_kv, block_sizes.block_kv_dkv)
514+
if uses_fused_kernel:
515+
block_q_sizes += (block_sizes.block_q_dkv,)
516+
block_kv_sizes += (block_sizes.block_kv_dkv,)
517+
else:
518+
block_q_sizes += (block_sizes.block_q_dq,)
519+
block_kv_sizes += (block_sizes.block_kv_dq,)
520+
521+
block_q = max(*block_q_sizes)
522+
query, kv_size, query_seq_len = _pad_data_for_flash(query, heads, block_q)
523+
block_kv = max(*block_kv_sizes)
524+
key, _, key_seq_len = _pad_data_for_flash(key, heads, block_kv)
525+
value, _, _ = _pad_data_for_flash(value, heads, block_kv)
526+
527+
mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]))
528+
multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])
529+
530+
q_padded_len = query.shape[2]
531+
q_indices = jax.lax.broadcasted_iota(jnp.int32, (q_padded_len,), 0)
532+
q_segment_ids = (q_indices < query_seq_len).astype(jnp.int32)
533+
534+
kv_padded_len = key.shape[2]
535+
kv_indices = jax.lax.broadcasted_iota(jnp.int32, (kv_padded_len,), 0)
536+
kv_segment_ids = (kv_indices < key_seq_len).astype(jnp.int32)
537+
538+
# Reuse the standard flash-attention masking convention by zeroing invalid
539+
# KV positions in the segment ids passed down to splash.
540+
if attention_mask is not None:
541+
mask_len = min(key_seq_len, attention_mask.shape[1])
542+
kv_mask_for_batch = attention_mask[0, :mask_len]
543+
if key_seq_len > mask_len:
544+
extra_valid = jnp.ones((key_seq_len - mask_len,), dtype=jnp.int32)
545+
kv_mask_for_batch = jnp.concatenate([kv_mask_for_batch, extra_valid], axis=0)
546+
if kv_padded_len > key_seq_len:
547+
padding = jnp.zeros((kv_padded_len - key_seq_len,), dtype=jnp.int32)
548+
kv_mask_padded = jnp.concatenate([kv_mask_for_batch, padding], axis=0)
549+
else:
550+
kv_mask_padded = kv_mask_for_batch
551+
kv_segment_ids = (kv_segment_ids * kv_mask_padded).astype(jnp.int32)
552+
553+
segment_ids = splash_attention_kernel.SegmentIds(q=q_segment_ids, kv=kv_segment_ids)
554+
if not mask_padding_tokens:
555+
segment_ids = None
556+
557+
splash_kernel = splash_attention_kernel.make_splash_mha(
558+
mask=multi_head_mask,
559+
head_shards=1,
560+
q_seq_shards=1,
561+
block_sizes=block_sizes,
562+
save_residuals=False,
563+
residual_checkpoint_name=residual_checkpoint_name,
564+
)
565+
vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0, None))
566+
attention_output = vmapped_splash(query, key, value, segment_ids)
567+
attention_output = attention_output[:, :, :query_seq_len, :kv_size].astype(query.dtype)
568+
569+
# Restore the original layout expected by the rest of the model:
570+
# head-sharded / full-sequence -> sequence-sharded / full-heads.
571+
attention_output = jax.lax.all_to_all(attention_output, axis_name=axis_name, split_axis=2, concat_axis=1, tiled=True)
572+
return attention_output
573+
574+
devices_in_data_context = mesh.shape["data"] * num_shards
575+
if not (query.shape[0] / devices_in_data_context).is_integer():
576+
max_logging.log(
577+
"Warning, batch dimension should be shardable among the devices in data and context"
578+
f" axis, batch dimension: {query.shape[0]}, devices_in_data_context: {devices_in_data_context}"
579+
)
580+
x = wrap_ulysses_attention(query, key, value)
581+
x = x[:, :, :orig_q_seq_len, :]
582+
x = _reshape_heads_to_head_dim(x)
583+
584+
return x
585+
586+
446587
def _apply_attention_dot(
447588
query: Array,
448589
key: Array,
@@ -563,7 +704,7 @@ def _apply_attention(
563704
seq_len_idx = 1
564705
if query.ndim == 4:
565706
seq_len_idx = 2
566-
if attention_kernel in ["flash", "tokamax_flash"]:
707+
if attention_kernel in ["flash", "tokamax_flash", "ulysses"]:
567708
can_use_flash_attention = (
568709
query.shape[seq_len_idx] >= flash_min_seq_length
569710
and key.shape[seq_len_idx] >= flash_min_seq_length
@@ -575,6 +716,21 @@ def _apply_attention(
575716
return _apply_attention_dot(
576717
query, key, value, dtype, heads, dim_head, scale, split_head_dim, float32_qk_product, use_memory_efficient_attention
577718
)
719+
elif attention_kernel == "ulysses":
720+
return _ulysses_attention(
721+
query,
722+
key * scale,
723+
value,
724+
heads,
725+
mesh,
726+
axis_names_q,
727+
axis_names_kv,
728+
flash_block_sizes,
729+
dtype,
730+
mask_padding_tokens=mask_padding_tokens,
731+
residual_checkpoint_name=residual_checkpoint_name,
732+
attention_mask=attention_mask,
733+
)
578734
elif attention_kernel in ["flash", "tokamax_flash"]:
579735
return _tpu_flash_attention(
580736
query,

src/maxdiffusion/pyconfig.py

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,17 @@
2727
from . import max_logging
2828
from . import max_utils
2929
from .models.wan.wan_utils import CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH, WAN_21_FUSION_X_MODEL_NAME_OR_PATH
30-
from maxdiffusion.common_types import LENGTH, KV_LENGTH, WAN2_1, WAN2_2, LTX2_VIDEO, RING_ATTENTION_AXIS_RULES, SEQUENCE_PARALLEL_AXIS_RULES
30+
from maxdiffusion.common_types import (
31+
CONTEXT,
32+
LENGTH,
33+
KV_LENGTH,
34+
WAN2_1,
35+
WAN2_2,
36+
LTX2_VIDEO,
37+
RING_ATTENTION_AXIS_RULES,
38+
SEQUENCE_PARALLEL_AXIS_RULES,
39+
ULYSSES_ATTENTION_AXIS_RULES,
40+
)
3141

3242
_ALLOWED_MODEL_NAMES = {WAN2_1, WAN2_2, LTX2_VIDEO}
3343
_ALLOWED_TRAINING_MODEL_NAMES = {WAN2_1}
@@ -200,25 +210,37 @@ def user_init(raw_keys):
200210

201211
raw_keys["logical_axis_rules"] = _lists_to_tuples(raw_keys["logical_axis_rules"])
202212
# Verify qkv is sharded across sequence.
203-
if raw_keys["attention"] == "ring" or raw_keys["attention_sharding_uniform"]:
213+
attention = raw_keys["attention"]
214+
uses_ring_attention = attention == "ring"
215+
uses_ulysses_attention = attention == "ulysses"
216+
uses_uniform_sequence_sharding = raw_keys["attention_sharding_uniform"]
217+
if uses_ring_attention or uses_ulysses_attention or uses_uniform_sequence_sharding:
204218
max_logging.log(
205-
f"Adding sequence sharding to q and kv if not already present because {raw_keys['attention']}=='ring' or {raw_keys['attention_sharding_uniform']} is set."
219+
"Adding sequence sharding to q and kv if not already present because "
220+
f"{attention=} requires it or attention_sharding_uniform={uses_uniform_sequence_sharding} is set."
206221
)
207222
logical_axis_rules = list(raw_keys["logical_axis_rules"])
208223
max_logging.log(f"Initial logical axis rules: {logical_axis_rules}")
209224
new_rules = []
210-
q_seq_sharding = (LENGTH, "context")
211-
kv_seq_sharding = (KV_LENGTH, "context")
225+
q_seq_sharding = (LENGTH, CONTEXT)
226+
kv_seq_sharding = (KV_LENGTH, CONTEXT)
212227
if q_seq_sharding not in logical_axis_rules:
213228
logical_axis_rules.append(q_seq_sharding)
229+
max_logging.log(f"Adding sequence length axis rule {q_seq_sharding}")
214230
if kv_seq_sharding not in logical_axis_rules:
215231
logical_axis_rules.append(kv_seq_sharding)
216-
if raw_keys["attention"] == "ring":
232+
max_logging.log(f"Adding key/value sequence axis rule {kv_seq_sharding}")
233+
if uses_ring_attention:
217234
for ring_attention_axis_rule in RING_ATTENTION_AXIS_RULES:
218235
if ring_attention_axis_rule not in logical_axis_rules:
219236
max_logging.log(f"Adding ring attention axis rule {ring_attention_axis_rule}")
220237
new_rules.append(ring_attention_axis_rule)
221-
else: # attention =flash but sequence parallel sharding requested for both self and cross attention
238+
elif uses_ulysses_attention:
239+
for ulysses_attention_axis_rule in ULYSSES_ATTENTION_AXIS_RULES:
240+
if ulysses_attention_axis_rule not in logical_axis_rules:
241+
max_logging.log(f"Adding ulysses attention axis rule {ulysses_attention_axis_rule}")
242+
new_rules.append(ulysses_attention_axis_rule)
243+
else: # attention=flash but sequence parallel sharding requested for both self and cross attention
222244
for seq_parallel_axis_rule in SEQUENCE_PARALLEL_AXIS_RULES:
223245
if seq_parallel_axis_rule not in logical_axis_rules:
224246
max_logging.log(f"Adding sequence parallel attention axis rule {seq_parallel_axis_rule}")

0 commit comments

Comments
 (0)