Skip to content

Commit ee00dce

Browse files
Merge pull request #404 from AI-Hypercomputer:wan-ulysses-bshd-attention
PiperOrigin-RevId: 931248347
2 parents b2d31df + c104db5 commit ee00dce

16 files changed

Lines changed: 695 additions & 114 deletions

src/maxdiffusion/common_types.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,3 +95,15 @@
9595
[CROSS_ATTN_Q_LENGTH, CONTEXT],
9696
[CROSS_ATTN_KV_LENGTH, CONTEXT],
9797
]
98+
99+
### Common axis rules for 2D Ulysses + ring attention ###
100+
# Public configs shard sequence on `context`; attention code privately reshapes
101+
# that axis into hidden ring and Ulysses axes for the hybrid kernel.
102+
ULYSSES_RING_ATTENTION_AXIS_RULES = [
103+
[SELF_ATTN_HEAD, None],
104+
[SELF_ATTN_Q_LENGTH, CONTEXT],
105+
[SELF_ATTN_KV_LENGTH, CONTEXT],
106+
[CROSS_ATTN_HEAD, None],
107+
[CROSS_ATTN_Q_LENGTH, CONTEXT],
108+
[CROSS_ATTN_KV_LENGTH, CONTEXT],
109+
]

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,9 +83,11 @@ jit_initializers: True
8383
# Set true to load weights from pytorch
8484
from_pt: True
8585
split_head_dim: True
86-
attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_custom
86+
attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_custom, ulysses_ring
8787
use_base2_exp: True
8888
use_experimental_scheduler: True
89+
# For attention=ulysses_ring, hidden Ulysses shard count; ring shards are context / this.
90+
ulysses_shards: -1
8991
flash_min_seq_length: 4096
9092
dropout: 0.0
9193

src/maxdiffusion/configs/base_wan_1_3b.yml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,9 +80,11 @@ jit_initializers: True
8080
# Set true to load weights from pytorch
8181
from_pt: True
8282
split_head_dim: True
83-
attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring, ulysses
83+
attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_custom, ulysses_ring
8484
use_base2_exp: True
8585
use_experimental_scheduler: True
86+
# For attention=ulysses_ring, hidden Ulysses shard count; ring shards are context / this.
87+
ulysses_shards: -1
8688
flash_min_seq_length: 0
8789

8890
# If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens.

src/maxdiffusion/configs/base_wan_27b.yml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,9 +83,11 @@ jit_initializers: True
8383
# Set true to load weights from pytorch
8484
from_pt: True
8585
split_head_dim: True
86-
attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_custom
86+
attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_custom, ulysses_ring
8787
use_base2_exp: True
8888
use_experimental_scheduler: True
89+
# For attention=ulysses_ring, hidden Ulysses shard count; ring shards are context / this.
90+
ulysses_shards: -1
8991
flash_min_seq_length: 4096
9092
dropout: 0.0
9193

src/maxdiffusion/configs/base_wan_animate.yml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,9 +81,11 @@ jit_initializers: True
8181
# Set true to load weights from pytorch
8282
from_pt: True
8383
split_head_dim: True
84-
attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_custom
84+
attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_custom, ulysses_ring
8585
use_base2_exp: True
8686
use_experimental_scheduler: True
87+
# For attention=ulysses_ring, hidden Ulysses shard count; ring shards are context / this.
88+
ulysses_shards: -1
8789
flash_min_seq_length: 4096
8890
# If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens.
8991
# Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster.

src/maxdiffusion/configs/base_wan_i2v_14b.yml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,9 +83,11 @@ jit_initializers: True
8383
# Set true to load weights from pytorch
8484
from_pt: True
8585
split_head_dim: True
86-
attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_custom
86+
attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_custom, ulysses_ring
8787
use_base2_exp: True
8888
use_experimental_scheduler: True
89+
# For attention=ulysses_ring, hidden Ulysses shard count; ring shards are context / this.
90+
ulysses_shards: -1
8991
flash_min_seq_length: 4096
9092
dropout: 0.0
9193

src/maxdiffusion/configs/base_wan_i2v_27b.yml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,9 +83,11 @@ jit_initializers: True
8383
# Set true to load weights from pytorch
8484
from_pt: True
8585
split_head_dim: True
86-
attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_custom
86+
attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_custom, ulysses_ring
8787
use_base2_exp: True
8888
use_experimental_scheduler: True
89+
# For attention=ulysses_ring, hidden Ulysses shard count; ring shards are context / this.
90+
ulysses_shards: -1
8991
flash_min_seq_length: 4096
9092
dropout: 0.0
9193

src/maxdiffusion/max_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -637,7 +637,7 @@ def get_flash_block_sizes(config):
637637
"""Create custom flash attention BlockSizes."""
638638
flash_block_sizes = None
639639
if len(config.flash_block_sizes.keys()) > 0:
640-
attention_is_tokamax = "tokamax" in config.attention
640+
attention_is_tokamax = "tokamax" in config.attention or config.attention == "ulysses_ring"
641641
user_block_sizes: Dict[str, int] = config.flash_block_sizes
642642
# The custom splash kernel reads flash_block_sizes via getattr and needs
643643
# fields the JAX BlockSizes dataclass cannot hold. Return a frozen, hashable

0 commit comments

Comments
 (0)