Skip to content

Commit a4e0ae7

Browse files
committed
Add Ulysses attention support
1 parent 7f698e3 commit a4e0ae7

11 files changed

Lines changed: 517 additions & 14 deletions

File tree

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,3 +181,6 @@ wandb
181181
# Gemini CLI
182182
.gemini/
183183
gha-creds-*.json
184+
185+
# JAX cache
186+
.jax_cache/

README.md

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -572,6 +572,26 @@ To generate images, run the following command:
572572
* For Wan2.2 T2V, use `base_wan_27b.yml`.
573573
* For Wan2.2 I2V, use `base_wan_i2v_27b.yml`.
574574

575+
### Ulysses Attention
576+
577+
MaxDiffusion supports Ulysses attention for WAN TPU inference. Enable it by setting `attention="ulysses"`.
578+
579+
Internally, this follows the Ulysses sequence-parallel attention pattern and trades sequence shards for head shards around the local TPU splash kernel. For background, see [DeepSpeed Ulysses: System Optimizations for Enabling Training of Extreme Long Sequence Transformer Models](https://arxiv.org/abs/2309.14509).
580+
581+
To enable Ulysses attention, set the corresponding override in your config YAML or pass it as a command-line override:
582+
583+
```bash
584+
python src/maxdiffusion/generate_wan.py \
585+
src/maxdiffusion/configs/base_wan_i2v_27b.yml \
586+
attention="ulysses" \
587+
ici_context_parallelism=4 \
588+
...
589+
```
590+
591+
Ulysses requires `ici_context_parallelism` greater than 1, and the number of attention heads must be divisible by the context shard count. `flash_block_sizes` tuning is optional and can still be used for hardware-specific tuning.
592+
593+
In our Wan2.2 I2V benchmarks at 40 inference steps, 81 frames, and `720x1280` resolution, Ulysses improved inference time by roughly `10%` compared with flash attention, with about `20s` lower latency on the v6e-8 and v7x-8 TPU setups.
594+
575595
### Caching Mechanisms
576596

577597
Wan 2.x pipelines support several caching strategies to accelerate inference by skipping redundant transformer forward passes. These are **mutually exclusive** — enable only one at a time.
@@ -774,4 +794,4 @@ This script will automatically format your code with `pyink` and help you identi
774794
The full suite of end-to-end tests is in `tests` and `src/maxdiffusion/tests`. We run them with a nightly cadence.
775795
776796
## Profiling
777-
To learn how to enable ML Diagnostics and XProf profiling for your runs, please see our [ML Diagnostics Guide](docs/profiling.md).
797+
To learn how to enable ML Diagnostics and XProf profiling for your runs, please see our [ML Diagnostics Guide](docs/profiling.md).

src/maxdiffusion/common_types.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,3 +84,13 @@
8484
[CROSS_ATTN_Q_LENGTH, CONTEXT],
8585
[CROSS_ATTN_KV_LENGTH, None],
8686
]
87+
88+
### Common axis rules for ulysses attention ###
89+
ULYSSES_ATTENTION_AXIS_RULES = [
90+
[SELF_ATTN_HEAD, None],
91+
[SELF_ATTN_Q_LENGTH, CONTEXT],
92+
[SELF_ATTN_KV_LENGTH, CONTEXT],
93+
[CROSS_ATTN_HEAD, None],
94+
[CROSS_ATTN_Q_LENGTH, CONTEXT],
95+
[CROSS_ATTN_KV_LENGTH, CONTEXT],
96+
]

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ jit_initializers: True
6060
# Set true to load weights from pytorch
6161
from_pt: True
6262
split_head_dim: True
63-
attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring
63+
attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring, ulysses
6464
flash_min_seq_length: 0
6565

6666
# If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens.

src/maxdiffusion/configs/base_wan_1_3b.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ jit_initializers: True
6060
# Set true to load weights from pytorch
6161
from_pt: True
6262
split_head_dim: True
63-
attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring
63+
attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring, ulysses
6464
flash_min_seq_length: 0
6565

6666
# If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens.

src/maxdiffusion/configs/base_wan_27b.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ jit_initializers: True
6060
# Set true to load weights from pytorch
6161
from_pt: True
6262
split_head_dim: True
63-
attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring
63+
attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring, ulysses
6464
flash_min_seq_length: 4096
6565
# If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens.
6666
# Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster.

src/maxdiffusion/configs/base_wan_i2v_14b.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ jit_initializers: True
6060
# Set true to load weights from pytorch
6161
from_pt: True
6262
split_head_dim: True
63-
attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring
63+
attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring, ulysses
6464
flash_min_seq_length: 4096
6565
dropout: 0.0
6666

src/maxdiffusion/configs/base_wan_i2v_27b.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ jit_initializers: True
6060
# Set true to load weights from pytorch
6161
from_pt: True
6262
split_head_dim: True
63-
attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring
63+
attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring, ulysses
6464
flash_min_seq_length: 4096
6565
dropout: 0.0
6666

src/maxdiffusion/models/attention_flax.py

Lines changed: 154 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -443,6 +443,144 @@ def ring_scan_body(carry, _):
443443
return x
444444

445445

446+
# ---------------------------------------------------------------------------
447+
# Ulysses sequence-parallel attention
448+
# ---------------------------------------------------------------------------
449+
450+
451+
def _ulysses_attention(
452+
query: jax.Array,
453+
key: jax.Array,
454+
value: jax.Array,
455+
heads: int,
456+
mesh: Mesh,
457+
axis_names_q: AxisNames,
458+
axis_names_kv: AxisNames,
459+
flash_block_sizes: BlockSizes,
460+
dtype: jnp.dtype = jnp.float32,
461+
mask_padding_tokens: bool = True,
462+
residual_checkpoint_name: str | None = None,
463+
attention_mask: jax.Array = None,
464+
) -> jax.Array:
465+
"""Ulysses sequence-parallel attention.
466+
467+
Tensors arrive sequence-sharded on the context axis. Inside a shard_map the
468+
all-to-all collectives trade sequence shards for head shards, run local
469+
splash attention on the full sequence with a subset of heads, then all-to-all
470+
back.
471+
"""
472+
axis_name = "context"
473+
num_shards = mesh.shape[axis_name]
474+
475+
# Reshape to [b, h, s, d] and pad sequence for even context-axis splitting.
476+
query, orig_q_seq_len = _reshape_data_for_flash(query, heads, num_shards)
477+
key, _ = _reshape_data_for_flash(key, heads, num_shards)
478+
value, _ = _reshape_data_for_flash(value, heads, num_shards)
479+
num_heads = query.shape[1]
480+
# Ulysses only redistributes existing heads across the context mesh; unlike
481+
# the earlier draft, we fail fast instead of padding synthetic heads.
482+
if num_heads % num_shards != 0:
483+
raise ValueError(
484+
"Ulysses attention requires the number of heads to be divisible by the context shard count, "
485+
f"got heads={num_heads} and context_shards={num_shards}."
486+
)
487+
block_sizes = _select_flash_block_sizes(query, key, flash_block_sizes, dtype, "flash")
488+
489+
q_axis_names = nn.logical_to_mesh_axes(axis_names_q)
490+
kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv)
491+
492+
@functools.partial(
493+
jax.shard_map,
494+
mesh=mesh,
495+
in_specs=(q_axis_names, kv_axis_names, kv_axis_names),
496+
out_specs=q_axis_names,
497+
check_vma=False,
498+
)
499+
def wrap_ulysses_attention(query, key, value):
500+
# Swap sharding modes: each device gives up a slice of sequence and gathers
501+
# a slice of heads, so the local splash kernel sees the full sequence.
502+
query = jax.lax.all_to_all(query, axis_name=axis_name, split_axis=1, concat_axis=2, tiled=True)
503+
key = jax.lax.all_to_all(key, axis_name=axis_name, split_axis=1, concat_axis=2, tiled=True)
504+
value = jax.lax.all_to_all(value, axis_name=axis_name, split_axis=1, concat_axis=2, tiled=True)
505+
506+
# Run the same local splash kernel as standard TPU flash attention, but now
507+
# on full-sequence / fewer-heads tensors produced by the all-to-all above.
508+
uses_fused_kernel = block_sizes.use_fused_bwd_kernel
509+
block_q_sizes = (block_sizes.block_q, block_sizes.block_q_dkv)
510+
block_kv_sizes = (block_sizes.block_kv, block_sizes.block_kv_dkv)
511+
if uses_fused_kernel:
512+
block_q_sizes += (block_sizes.block_q_dkv,)
513+
block_kv_sizes += (block_sizes.block_kv_dkv,)
514+
else:
515+
block_q_sizes += (block_sizes.block_q_dq,)
516+
block_kv_sizes += (block_sizes.block_kv_dq,)
517+
518+
block_q = max(*block_q_sizes)
519+
query, kv_size, query_seq_len = _pad_data_for_flash(query, heads, block_q)
520+
block_kv = max(*block_kv_sizes)
521+
key, _, key_seq_len = _pad_data_for_flash(key, heads, block_kv)
522+
value, _, _ = _pad_data_for_flash(value, heads, block_kv)
523+
524+
mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]))
525+
multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])
526+
527+
q_padded_len = query.shape[2]
528+
q_indices = jax.lax.broadcasted_iota(jnp.int32, (q_padded_len,), 0)
529+
q_segment_ids = (q_indices < query_seq_len).astype(jnp.int32)
530+
531+
kv_padded_len = key.shape[2]
532+
kv_indices = jax.lax.broadcasted_iota(jnp.int32, (kv_padded_len,), 0)
533+
kv_segment_ids = (kv_indices < key_seq_len).astype(jnp.int32)
534+
535+
# Reuse the standard flash-attention masking convention by zeroing invalid
536+
# KV positions in the segment ids passed down to splash.
537+
if attention_mask is not None:
538+
mask_len = min(key_seq_len, attention_mask.shape[1])
539+
kv_mask_for_batch = attention_mask[0, :mask_len]
540+
if key_seq_len > mask_len:
541+
extra_valid = jnp.ones((key_seq_len - mask_len,), dtype=jnp.int32)
542+
kv_mask_for_batch = jnp.concatenate([kv_mask_for_batch, extra_valid], axis=0)
543+
if kv_padded_len > key_seq_len:
544+
padding = jnp.zeros((kv_padded_len - key_seq_len,), dtype=jnp.int32)
545+
kv_mask_padded = jnp.concatenate([kv_mask_for_batch, padding], axis=0)
546+
else:
547+
kv_mask_padded = kv_mask_for_batch
548+
kv_segment_ids = (kv_segment_ids * kv_mask_padded).astype(jnp.int32)
549+
550+
segment_ids = splash_attention_kernel.SegmentIds(q=q_segment_ids, kv=kv_segment_ids)
551+
if not mask_padding_tokens:
552+
segment_ids = None
553+
554+
splash_kernel = splash_attention_kernel.make_splash_mha(
555+
mask=multi_head_mask,
556+
head_shards=1,
557+
q_seq_shards=1,
558+
block_sizes=block_sizes,
559+
save_residuals=False,
560+
residual_checkpoint_name=residual_checkpoint_name,
561+
)
562+
vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0, None))
563+
attention_output = vmapped_splash(query, key, value, segment_ids)
564+
attention_output = attention_output[:, :, :query_seq_len, :kv_size].astype(query.dtype)
565+
566+
# Restore the original layout expected by the rest of the model:
567+
# head-sharded / full-sequence -> sequence-sharded / full-heads.
568+
attention_output = jax.lax.all_to_all(attention_output, axis_name=axis_name, split_axis=2, concat_axis=1, tiled=True)
569+
return attention_output
570+
571+
devices_in_data_context = mesh.shape["data"] * num_shards
572+
if not (query.shape[0] / devices_in_data_context).is_integer():
573+
max_logging.log(
574+
"Warning, batch dimension should be shardable among the devices in data and context"
575+
f" axis, batch dimension: {query.shape[0]}, devices_in_data_context: {devices_in_data_context}"
576+
)
577+
x = wrap_ulysses_attention(query, key, value)
578+
x = x[:, :, :orig_q_seq_len, :]
579+
x = _reshape_heads_to_head_dim(x)
580+
581+
return x
582+
583+
446584
def _apply_attention_dot(
447585
query: Array,
448586
key: Array,
@@ -563,7 +701,7 @@ def _apply_attention(
563701
seq_len_idx = 1
564702
if query.ndim == 4:
565703
seq_len_idx = 2
566-
if attention_kernel in ["flash", "tokamax_flash"]:
704+
if attention_kernel in ["flash", "tokamax_flash", "ulysses"]:
567705
can_use_flash_attention = (
568706
query.shape[seq_len_idx] >= flash_min_seq_length
569707
and key.shape[seq_len_idx] >= flash_min_seq_length
@@ -575,6 +713,21 @@ def _apply_attention(
575713
return _apply_attention_dot(
576714
query, key, value, dtype, heads, dim_head, scale, split_head_dim, float32_qk_product, use_memory_efficient_attention
577715
)
716+
elif attention_kernel == "ulysses":
717+
return _ulysses_attention(
718+
query,
719+
key * scale,
720+
value,
721+
heads,
722+
mesh,
723+
axis_names_q,
724+
axis_names_kv,
725+
flash_block_sizes,
726+
dtype,
727+
mask_padding_tokens=mask_padding_tokens,
728+
residual_checkpoint_name=residual_checkpoint_name,
729+
attention_mask=attention_mask,
730+
)
578731
elif attention_kernel in ["flash", "tokamax_flash"]:
579732
return _tpu_flash_attention(
580733
query,

src/maxdiffusion/pyconfig.py

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,17 @@
2727
from . import max_logging
2828
from . import max_utils
2929
from .models.wan.wan_utils import CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH, WAN_21_FUSION_X_MODEL_NAME_OR_PATH
30-
from maxdiffusion.common_types import LENGTH, KV_LENGTH, WAN2_1, WAN2_2, LTX2_VIDEO, RING_ATTENTION_AXIS_RULES, SEQUENCE_PARALLEL_AXIS_RULES
30+
from maxdiffusion.common_types import (
31+
CONTEXT,
32+
LENGTH,
33+
KV_LENGTH,
34+
WAN2_1,
35+
WAN2_2,
36+
LTX2_VIDEO,
37+
RING_ATTENTION_AXIS_RULES,
38+
SEQUENCE_PARALLEL_AXIS_RULES,
39+
ULYSSES_ATTENTION_AXIS_RULES,
40+
)
3141

3242
_ALLOWED_MODEL_NAMES = {WAN2_1, WAN2_2, LTX2_VIDEO}
3343
_ALLOWED_TRAINING_MODEL_NAMES = {WAN2_1}
@@ -200,25 +210,37 @@ def user_init(raw_keys):
200210

201211
raw_keys["logical_axis_rules"] = _lists_to_tuples(raw_keys["logical_axis_rules"])
202212
# Verify qkv is sharded across sequence.
203-
if raw_keys["attention"] == "ring" or raw_keys["attention_sharding_uniform"]:
213+
attention = raw_keys["attention"]
214+
uses_ring_attention = attention == "ring"
215+
uses_ulysses_attention = attention == "ulysses"
216+
uses_uniform_sequence_sharding = raw_keys["attention_sharding_uniform"]
217+
if uses_ring_attention or uses_ulysses_attention or uses_uniform_sequence_sharding:
204218
max_logging.log(
205-
f"Adding sequence sharding to q and kv if not already present because {raw_keys['attention']}=='ring' or {raw_keys['attention_sharding_uniform']} is set."
219+
"Adding sequence sharding to q and kv if not already present because "
220+
f"{attention=} requires it or attention_sharding_uniform={uses_uniform_sequence_sharding} is set."
206221
)
207222
logical_axis_rules = list(raw_keys["logical_axis_rules"])
208223
max_logging.log(f"Initial logical axis rules: {logical_axis_rules}")
209224
new_rules = []
210-
q_seq_sharding = (LENGTH, "context")
211-
kv_seq_sharding = (KV_LENGTH, "context")
225+
q_seq_sharding = (LENGTH, CONTEXT)
226+
kv_seq_sharding = (KV_LENGTH, CONTEXT)
212227
if q_seq_sharding not in logical_axis_rules:
213228
logical_axis_rules.append(q_seq_sharding)
229+
max_logging.log(f"Adding sequence length axis rule {q_seq_sharding}")
214230
if kv_seq_sharding not in logical_axis_rules:
215231
logical_axis_rules.append(kv_seq_sharding)
216-
if raw_keys["attention"] == "ring":
232+
max_logging.log(f"Adding key/value sequence axis rule {kv_seq_sharding}")
233+
if uses_ring_attention:
217234
for ring_attention_axis_rule in RING_ATTENTION_AXIS_RULES:
218235
if ring_attention_axis_rule not in logical_axis_rules:
219236
max_logging.log(f"Adding ring attention axis rule {ring_attention_axis_rule}")
220237
new_rules.append(ring_attention_axis_rule)
221-
else: # attention =flash but sequence parallel sharding requested for both self and cross attention
238+
elif uses_ulysses_attention:
239+
for ulysses_attention_axis_rule in ULYSSES_ATTENTION_AXIS_RULES:
240+
if ulysses_attention_axis_rule not in logical_axis_rules:
241+
max_logging.log(f"Adding ulysses attention axis rule {ulysses_attention_axis_rule}")
242+
new_rules.append(ulysses_attention_axis_rule)
243+
else: # attention=flash but sequence parallel sharding requested for both self and cross attention
222244
for seq_parallel_axis_rule in SEQUENCE_PARALLEL_AXIS_RULES:
223245
if seq_parallel_axis_rule not in logical_axis_rules:
224246
max_logging.log(f"Adding sequence parallel attention axis rule {seq_parallel_axis_rule}")

0 commit comments

Comments
 (0)