Skip to content

Commit 5c4d053

Browse files
committed
Add explicit Ulysses ring attention sharding
1 parent 19d4e4d commit 5c4d053

14 files changed

Lines changed: 564 additions & 57 deletions

src/maxdiffusion/common_types.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,3 +95,15 @@
9595
[CROSS_ATTN_Q_LENGTH, CONTEXT],
9696
[CROSS_ATTN_KV_LENGTH, CONTEXT],
9797
]
98+
99+
### Common axis rules for 2D Ulysses + ring attention ###
100+
# Public configs shard sequence on `context`; attention code privately reshapes
101+
# that axis into hidden ring and Ulysses axes for the hybrid kernel.
102+
ULYSSES_RING_ATTENTION_AXIS_RULES = [
103+
[SELF_ATTN_HEAD, None],
104+
[SELF_ATTN_Q_LENGTH, CONTEXT],
105+
[SELF_ATTN_KV_LENGTH, CONTEXT],
106+
[CROSS_ATTN_HEAD, None],
107+
[CROSS_ATTN_Q_LENGTH, CONTEXT],
108+
[CROSS_ATTN_KV_LENGTH, None],
109+
]

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,11 @@ jit_initializers: True
6464
# Set true to load weights from pytorch
6565
from_pt: True
6666
split_head_dim: True
67-
attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_custom
67+
attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_custom, ulysses_ring
6868
use_base2_exp: True
6969
use_experimental_scheduler: True
70+
# For attention=ulysses_ring: number of hidden Ulysses shards; the number of ring shards is the `context` shard count divided by this value.
71+
ulysses_shards: -1
7072
flash_min_seq_length: 4096
7173
dropout: 0.0
7274

src/maxdiffusion/configs/base_wan_1_3b.yml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,9 +60,11 @@ jit_initializers: True
6060
# Set true to load weights from pytorch
6161
from_pt: True
6262
split_head_dim: True
63-
attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring, ulysses
63+
attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_custom, ulysses_ring
6464
use_base2_exp: True
6565
use_experimental_scheduler: True
66+
# For attention=ulysses_ring: number of hidden Ulysses shards; the number of ring shards is the `context` shard count divided by this value.
67+
ulysses_shards: -1
6668
flash_min_seq_length: 0
6769

6870
# If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens.

src/maxdiffusion/configs/base_wan_27b.yml

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,11 @@ jit_initializers: True
6464
# Set true to load weights from pytorch
6565
from_pt: True
6666
split_head_dim: True
67-
attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_custom
67+
attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_custom, ulysses_ring
6868
use_base2_exp: True
6969
use_experimental_scheduler: True
70+
# For attention=ulysses_ring: number of hidden Ulysses shards; the number of ring shards is the `context` shard count divided by this value.
71+
ulysses_shards: -1
7072
flash_min_seq_length: 4096
7173
dropout: 0.0
7274

@@ -81,14 +83,14 @@ mask_padding_tokens: True
8183
attention_sharding_uniform: True
8284

8385
flash_block_sizes: {
84-
"block_q" : 512,
85-
"block_kv_compute" : 512,
86-
"block_kv" : 512,
87-
"block_q_dkv" : 512,
88-
"block_kv_dkv" : 512,
89-
"block_kv_dkv_compute" : 512,
90-
"block_q_dq" : 512,
91-
"block_kv_dq" : 512,
86+
"block_q" : 2048,
87+
"block_kv_compute" : 1024,
88+
"block_kv" : 2048,
89+
"block_q_dkv" : 2048,
90+
"block_kv_dkv" : 2048,
91+
"block_kv_dkv_compute" : 1024,
92+
"block_q_dq" : 2048,
93+
"block_kv_dq" : 2048,
9294
"use_fused_bwd_kernel": False,
9395
}
9496
# Use on v6e

src/maxdiffusion/configs/base_wan_animate.yml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,9 +62,11 @@ jit_initializers: True
6262
# Set true to load weights from pytorch
6363
from_pt: True
6464
split_head_dim: True
65-
attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_custom
65+
attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_custom, ulysses_ring
6666
use_base2_exp: True
6767
use_experimental_scheduler: True
68+
# For attention=ulysses_ring: number of hidden Ulysses shards; the number of ring shards is the `context` shard count divided by this value.
69+
ulysses_shards: -1
6870
flash_min_seq_length: 4096
6971
# If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens.
7072
# Otherwise we do not pass in segment ids; on VPU-bound hardware like Trillium this is faster.

src/maxdiffusion/configs/base_wan_i2v_14b.yml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,11 @@ jit_initializers: True
6464
# Set true to load weights from pytorch
6565
from_pt: True
6666
split_head_dim: True
67-
attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_custom
67+
attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_custom, ulysses_ring
6868
use_base2_exp: True
6969
use_experimental_scheduler: True
70+
# For attention=ulysses_ring: number of hidden Ulysses shards; the number of ring shards is the `context` shard count divided by this value.
71+
ulysses_shards: -1
7072
flash_min_seq_length: 4096
7173
dropout: 0.0
7274

src/maxdiffusion/configs/base_wan_i2v_27b.yml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,11 @@ jit_initializers: True
6464
# Set true to load weights from pytorch
6565
from_pt: True
6666
split_head_dim: True
67-
attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_custom
67+
attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_custom, ulysses_ring
6868
use_base2_exp: True
6969
use_experimental_scheduler: True
70+
# For attention=ulysses_ring: number of hidden Ulysses shards; the number of ring shards is the `context` shard count divided by this value.
71+
ulysses_shards: -1
7072
flash_min_seq_length: 4096
7173
dropout: 0.0
7274

src/maxdiffusion/max_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -617,7 +617,7 @@ def get_flash_block_sizes(config):
617617
"""Create custom flash attention BlockSizes."""
618618
flash_block_sizes = None
619619
if len(config.flash_block_sizes.keys()) > 0:
620-
attention_is_tokamax = "tokamax" in config.attention
620+
attention_is_tokamax = "tokamax" in config.attention or config.attention == "ulysses_ring"
621621
user_block_sizes: Dict[str, int] = config.flash_block_sizes
622622
if attention_is_tokamax:
623623
max_logging.log(

0 commit comments

Comments
 (0)