Merged
3 changes: 3 additions & 0 deletions .gitignore
@@ -181,3 +181,6 @@ wandb
# Gemini CLI
.gemini/
gha-creds-*.json

# JAX cache
.jax_cache/
22 changes: 21 additions & 1 deletion README.md
@@ -572,6 +572,26 @@ To generate images, run the following command:
* For Wan2.2 T2V, use `base_wan_27b.yml`.
* For Wan2.2 I2V, use `base_wan_i2v_27b.yml`.

### Ulysses Attention

MaxDiffusion supports Ulysses attention for WAN TPU inference. Enable it by setting `attention="ulysses"`.

Internally, this follows the Ulysses sequence-parallel attention pattern and trades sequence shards for head shards around the local TPU splash kernel. For background, see [DeepSpeed Ulysses: System Optimizations for Enabling Training of Extreme Long Sequence Transformer Models](https://arxiv.org/abs/2309.14509).
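The shard swap at the heart of Ulysses can be illustrated with a plain NumPy shape simulation. This is a sketch, not MaxDiffusion code: `simulate_ulysses_all_to_all`, the tensor shapes, and the shard count are all illustrative assumptions; it only mimics the shape semantics of `jax.lax.all_to_all(split_axis=1, concat_axis=2, tiled=True)` used inside the kernel.

```python
import numpy as np

def simulate_ulysses_all_to_all(shards):
    """Simulate the tiled all-to-all used by Ulysses attention.

    Each input shard holds [batch, heads, seq/shards, dim]; each output
    shard holds [batch, heads/shards, seq, dim]. Devices trade sequence
    slices for head slices so local attention sees the full sequence.
    """
    n = len(shards)
    out = []
    for i in range(n):
        # Device i keeps head block i from every device's sequence slice,
        # concatenated along the sequence axis in device order.
        head_blocks = [np.split(s, n, axis=1)[i] for s in shards]
        out.append(np.concatenate(head_blocks, axis=2))
    return out

batch, heads, seq, dim, n_shards = 1, 8, 16, 4, 4
full = np.arange(batch * heads * seq * dim, dtype=np.float32).reshape(batch, heads, seq, dim)
seq_shards = np.split(full, n_shards, axis=2)           # each (1, 8, 4, 4): sequence-sharded
head_shards = simulate_ulysses_all_to_all(seq_shards)   # each (1, 2, 16, 4): head-sharded
print(head_shards[0].shape)  # (1, 2, 16, 4)
```

After local attention runs on the full sequence with fewer heads, a second all-to-all with swapped split/concat axes restores the sequence-sharded layout.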

To enable Ulysses attention, set `attention: 'ulysses'` in your config YAML or pass it as a command-line override:

```bash
python src/maxdiffusion/generate_wan.py \
src/maxdiffusion/configs/base_wan_i2v_27b.yml \
attention="ulysses" \
ici_context_parallelism=4 \
...
```

Ulysses requires `ici_context_parallelism` greater than 1, and the number of attention heads must be divisible by the context shard count. `flash_block_sizes` remains optional and can still be used for hardware-specific tuning.
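The divisibility requirement can be sketched as a small validation helper (a hypothetical illustration, not an actual MaxDiffusion API; the head count of 40 is an example value):

```python
def validate_ulysses_config(num_heads: int, context_shards: int) -> int:
    """Return heads per context shard, or raise if the config is invalid."""
    if context_shards <= 1:
        raise ValueError("Ulysses attention requires ici_context_parallelism > 1.")
    if num_heads % context_shards != 0:
        raise ValueError(
            f"heads={num_heads} not divisible by context_shards={context_shards}"
        )
    return num_heads // context_shards

print(validate_ulysses_config(40, 4))  # 10
```

With 40 heads and `ici_context_parallelism=4`, each context shard runs the local splash kernel over 10 heads and the full sequence; a shard count of 3 would be rejected.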

In our Wan2.2 I2V benchmarks at 40 inference steps, 81 frames, and `720x1280` resolution, Ulysses improved inference time by roughly 10% compared with flash attention, about 20 s lower latency on v6e-8 and v7x-8 TPU setups.

### Caching Mechanisms

Wan 2.x pipelines support several caching strategies to accelerate inference by skipping redundant transformer forward passes. These are **mutually exclusive** — enable only one at a time.
@@ -774,4 +794,4 @@ This script will automatically format your code with `pyink` and help you identi
The full suite of end-to-end tests is in `tests` and `src/maxdiffusion/tests`. We run them on a nightly cadence.

## Profiling
To learn how to enable ML Diagnostics and XProf profiling for your runs, please see our [ML Diagnostics Guide](docs/profiling.md).
To learn how to enable ML Diagnostics and XProf profiling for your runs, please see our [ML Diagnostics Guide](docs/profiling.md).
10 changes: 10 additions & 0 deletions src/maxdiffusion/common_types.py
@@ -84,3 +84,13 @@
[CROSS_ATTN_Q_LENGTH, CONTEXT],
[CROSS_ATTN_KV_LENGTH, None],
]

### Common axis rules for ulysses attention ###
ULYSSES_ATTENTION_AXIS_RULES = [
[SELF_ATTN_HEAD, None],
[SELF_ATTN_Q_LENGTH, CONTEXT],
[SELF_ATTN_KV_LENGTH, CONTEXT],
[CROSS_ATTN_HEAD, None],
[CROSS_ATTN_Q_LENGTH, CONTEXT],
[CROSS_ATTN_KV_LENGTH, CONTEXT],
]
2 changes: 1 addition & 1 deletion src/maxdiffusion/configs/base_wan_14b.yml
@@ -60,7 +60,7 @@ jit_initializers: True
# Set true to load weights from pytorch
from_pt: True
split_head_dim: True
attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring
attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring, ulysses
flash_min_seq_length: 0

# If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens.
2 changes: 1 addition & 1 deletion src/maxdiffusion/configs/base_wan_1_3b.yml
@@ -60,7 +60,7 @@ jit_initializers: True
# Set true to load weights from pytorch
from_pt: True
split_head_dim: True
attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring
attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring, ulysses
flash_min_seq_length: 0

# If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens.
2 changes: 1 addition & 1 deletion src/maxdiffusion/configs/base_wan_27b.yml
@@ -60,7 +60,7 @@ jit_initializers: True
# Set true to load weights from pytorch
from_pt: True
split_head_dim: True
attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring
attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring, ulysses
flash_min_seq_length: 4096
# If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens.
# Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster.
2 changes: 1 addition & 1 deletion src/maxdiffusion/configs/base_wan_i2v_14b.yml
@@ -60,7 +60,7 @@ jit_initializers: True
# Set true to load weights from pytorch
from_pt: True
split_head_dim: True
attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring
attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring, ulysses
flash_min_seq_length: 4096
dropout: 0.0

2 changes: 1 addition & 1 deletion src/maxdiffusion/configs/base_wan_i2v_27b.yml
@@ -60,7 +60,7 @@ jit_initializers: True
# Set true to load weights from pytorch
from_pt: True
split_head_dim: True
attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring
attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring, ulysses
flash_min_seq_length: 4096
dropout: 0.0

155 changes: 154 additions & 1 deletion src/maxdiffusion/models/attention_flax.py
@@ -443,6 +443,144 @@ def ring_scan_body(carry, _):
return x


# ---------------------------------------------------------------------------
# Ulysses sequence-parallel attention
# ---------------------------------------------------------------------------


def _ulysses_attention(
query: jax.Array,
key: jax.Array,
value: jax.Array,
heads: int,
mesh: Mesh,
axis_names_q: AxisNames,
axis_names_kv: AxisNames,
flash_block_sizes: BlockSizes,
dtype: jnp.dtype = jnp.float32,
mask_padding_tokens: bool = True,
residual_checkpoint_name: str | None = None,
attention_mask: jax.Array | None = None,
) -> jax.Array:
"""Ulysses sequence-parallel attention.

Tensors arrive sequence-sharded on the context axis. Inside a shard_map the
all-to-all collectives trade sequence shards for head shards, run local
splash attention on the full sequence with a subset of heads, then all-to-all
back.
"""
axis_name = "context"
num_shards = mesh.shape[axis_name]

# Reshape to [b, h, s, d] and pad sequence for even context-axis splitting.
query, orig_q_seq_len = _reshape_data_for_flash(query, heads, num_shards)
key, _ = _reshape_data_for_flash(key, heads, num_shards)
value, _ = _reshape_data_for_flash(value, heads, num_shards)
num_heads = query.shape[1]
# Ulysses only redistributes existing heads across the context mesh; unlike
# the earlier draft, we fail fast instead of padding synthetic heads.
if num_heads % num_shards != 0:
raise ValueError(
"Ulysses attention requires the number of heads to be divisible by the context shard count, "
f"got heads={num_heads} and context_shards={num_shards}."
)
block_sizes = _select_flash_block_sizes(query, key, flash_block_sizes, dtype, "flash")

q_axis_names = nn.logical_to_mesh_axes(axis_names_q)
kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv)

@functools.partial(
jax.shard_map,
mesh=mesh,
in_specs=(q_axis_names, kv_axis_names, kv_axis_names),
out_specs=q_axis_names,
check_vma=False,
)
def wrap_ulysses_attention(query, key, value):
# Swap sharding modes: each device gives up a slice of sequence and gathers
# a slice of heads, so the local splash kernel sees the full sequence.
query = jax.lax.all_to_all(query, axis_name=axis_name, split_axis=1, concat_axis=2, tiled=True)
key = jax.lax.all_to_all(key, axis_name=axis_name, split_axis=1, concat_axis=2, tiled=True)
value = jax.lax.all_to_all(value, axis_name=axis_name, split_axis=1, concat_axis=2, tiled=True)

# Run the same local splash kernel as standard TPU flash attention, but now
# on full-sequence / fewer-heads tensors produced by the all-to-all above.
uses_fused_kernel = block_sizes.use_fused_bwd_kernel
block_q_sizes = (block_sizes.block_q, block_sizes.block_q_dkv)
block_kv_sizes = (block_sizes.block_kv, block_sizes.block_kv_dkv)
if uses_fused_kernel:
block_q_sizes += (block_sizes.block_q_dkv,)
block_kv_sizes += (block_sizes.block_kv_dkv,)
else:
block_q_sizes += (block_sizes.block_q_dq,)
block_kv_sizes += (block_sizes.block_kv_dq,)

block_q = max(*block_q_sizes)
query, kv_size, query_seq_len = _pad_data_for_flash(query, heads, block_q)
block_kv = max(*block_kv_sizes)
key, _, key_seq_len = _pad_data_for_flash(key, heads, block_kv)
value, _, _ = _pad_data_for_flash(value, heads, block_kv)

mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]))
multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])

q_padded_len = query.shape[2]
q_indices = jax.lax.broadcasted_iota(jnp.int32, (q_padded_len,), 0)
q_segment_ids = (q_indices < query_seq_len).astype(jnp.int32)

kv_padded_len = key.shape[2]
kv_indices = jax.lax.broadcasted_iota(jnp.int32, (kv_padded_len,), 0)
kv_segment_ids = (kv_indices < key_seq_len).astype(jnp.int32)

# Reuse the standard flash-attention masking convention by zeroing invalid
# KV positions in the segment ids passed down to splash.
if attention_mask is not None:
mask_len = min(key_seq_len, attention_mask.shape[1])
kv_mask_for_batch = attention_mask[0, :mask_len]
if key_seq_len > mask_len:
extra_valid = jnp.ones((key_seq_len - mask_len,), dtype=jnp.int32)
kv_mask_for_batch = jnp.concatenate([kv_mask_for_batch, extra_valid], axis=0)
if kv_padded_len > key_seq_len:
padding = jnp.zeros((kv_padded_len - key_seq_len,), dtype=jnp.int32)
kv_mask_padded = jnp.concatenate([kv_mask_for_batch, padding], axis=0)
else:
kv_mask_padded = kv_mask_for_batch
kv_segment_ids = (kv_segment_ids * kv_mask_padded).astype(jnp.int32)

segment_ids = splash_attention_kernel.SegmentIds(q=q_segment_ids, kv=kv_segment_ids)
if not mask_padding_tokens:
segment_ids = None

splash_kernel = splash_attention_kernel.make_splash_mha(
mask=multi_head_mask,
head_shards=1,
q_seq_shards=1,
block_sizes=block_sizes,
save_residuals=False,
residual_checkpoint_name=residual_checkpoint_name,
)
vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0, None))
attention_output = vmapped_splash(query, key, value, segment_ids)
attention_output = attention_output[:, :, :query_seq_len, :kv_size].astype(query.dtype)

# Restore the original layout expected by the rest of the model:
# head-sharded / full-sequence -> sequence-sharded / full-heads.
attention_output = jax.lax.all_to_all(attention_output, axis_name=axis_name, split_axis=2, concat_axis=1, tiled=True)
return attention_output

devices_in_data_context = mesh.shape["data"] * num_shards
if query.shape[0] % devices_in_data_context != 0:
max_logging.log(
"Warning: batch dimension should be shardable across the devices in the data and context"
f" axes; batch dimension: {query.shape[0]}, devices_in_data_context: {devices_in_data_context}"
)
x = wrap_ulysses_attention(query, key, value)
x = x[:, :, :orig_q_seq_len, :]
x = _reshape_heads_to_head_dim(x)

return x


def _apply_attention_dot(
query: Array,
key: Array,
@@ -563,7 +701,7 @@ def _apply_attention(
seq_len_idx = 1
if query.ndim == 4:
seq_len_idx = 2
if attention_kernel in ["flash", "tokamax_flash"]:
if attention_kernel in ["flash", "tokamax_flash", "ulysses"]:
can_use_flash_attention = (
query.shape[seq_len_idx] >= flash_min_seq_length
and key.shape[seq_len_idx] >= flash_min_seq_length
@@ -575,6 +713,21 @@
return _apply_attention_dot(
query, key, value, dtype, heads, dim_head, scale, split_head_dim, float32_qk_product, use_memory_efficient_attention
)
elif attention_kernel == "ulysses":
return _ulysses_attention(
query,
key * scale,
value,
heads,
mesh,
axis_names_q,
axis_names_kv,
flash_block_sizes,
dtype,
mask_padding_tokens=mask_padding_tokens,
residual_checkpoint_name=residual_checkpoint_name,
attention_mask=attention_mask,
)
elif attention_kernel in ["flash", "tokamax_flash"]:
return _tpu_flash_attention(
query,
36 changes: 29 additions & 7 deletions src/maxdiffusion/pyconfig.py
@@ -27,7 +27,17 @@
from . import max_logging
from . import max_utils
from .models.wan.wan_utils import CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH, WAN_21_FUSION_X_MODEL_NAME_OR_PATH
from maxdiffusion.common_types import LENGTH, KV_LENGTH, WAN2_1, WAN2_2, LTX2_VIDEO, RING_ATTENTION_AXIS_RULES, SEQUENCE_PARALLEL_AXIS_RULES
from maxdiffusion.common_types import (
CONTEXT,
LENGTH,
KV_LENGTH,
WAN2_1,
WAN2_2,
LTX2_VIDEO,
RING_ATTENTION_AXIS_RULES,
SEQUENCE_PARALLEL_AXIS_RULES,
ULYSSES_ATTENTION_AXIS_RULES,
)

_ALLOWED_MODEL_NAMES = {WAN2_1, WAN2_2, LTX2_VIDEO}
_ALLOWED_TRAINING_MODEL_NAMES = {WAN2_1}
@@ -200,25 +210,37 @@ def user_init(raw_keys):

raw_keys["logical_axis_rules"] = _lists_to_tuples(raw_keys["logical_axis_rules"])
# Verify qkv is sharded across sequence.
if raw_keys["attention"] == "ring" or raw_keys["attention_sharding_uniform"]:
attention = raw_keys["attention"]
uses_ring_attention = attention == "ring"
uses_ulysses_attention = attention == "ulysses"
uses_uniform_sequence_sharding = raw_keys["attention_sharding_uniform"]
if uses_ring_attention or uses_ulysses_attention or uses_uniform_sequence_sharding:
max_logging.log(
f"Adding sequence sharding to q and kv if not already present because {raw_keys['attention']}=='ring' or {raw_keys['attention_sharding_uniform']} is set."
"Adding sequence sharding to q and kv if not already present because "
f"{attention=} requires it or attention_sharding_uniform={uses_uniform_sequence_sharding} is set."
)
logical_axis_rules = list(raw_keys["logical_axis_rules"])
max_logging.log(f"Initial logical axis rules: {logical_axis_rules}")
new_rules = []
q_seq_sharding = (LENGTH, "context")
kv_seq_sharding = (KV_LENGTH, "context")
q_seq_sharding = (LENGTH, CONTEXT)
kv_seq_sharding = (KV_LENGTH, CONTEXT)
if q_seq_sharding not in logical_axis_rules:
logical_axis_rules.append(q_seq_sharding)
max_logging.log(f"Adding sequence length axis rule {q_seq_sharding}")
if kv_seq_sharding not in logical_axis_rules:
logical_axis_rules.append(kv_seq_sharding)
if raw_keys["attention"] == "ring":
max_logging.log(f"Adding key/value sequence axis rule {kv_seq_sharding}")
if uses_ring_attention:
for ring_attention_axis_rule in RING_ATTENTION_AXIS_RULES:
if ring_attention_axis_rule not in logical_axis_rules:
max_logging.log(f"Adding ring attention axis rule {ring_attention_axis_rule}")
new_rules.append(ring_attention_axis_rule)
else: # attention =flash but sequence parallel sharding requested for both self and cross attention
elif uses_ulysses_attention:
for ulysses_attention_axis_rule in ULYSSES_ATTENTION_AXIS_RULES:
if ulysses_attention_axis_rule not in logical_axis_rules:
max_logging.log(f"Adding ulysses attention axis rule {ulysses_attention_axis_rule}")
new_rules.append(ulysses_attention_axis_rule)
else: # attention=flash but sequence parallel sharding requested for both self and cross attention
for seq_parallel_axis_rule in SEQUENCE_PARALLEL_AXIS_RULES:
if seq_parallel_axis_rule not in logical_axis_rules:
max_logging.log(f"Adding sequence parallel attention axis rule {seq_parallel_axis_rule}")