Skip to content

Commit 48a900e

Browse files
committed
Add explicit Ulysses ring attention sharding
1 parent 19d4e4d commit 48a900e

15 files changed

Lines changed: 600 additions & 156 deletions

src/maxdiffusion/common_types.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,3 +95,15 @@
9595
[CROSS_ATTN_Q_LENGTH, CONTEXT],
9696
[CROSS_ATTN_KV_LENGTH, CONTEXT],
9797
]
98+
99+
### Common axis rules for 2D Ulysses + ring attention ###
100+
# Public configs shard sequence on `context`; attention code privately reshapes
101+
# that axis into hidden ring and Ulysses axes for the hybrid kernel.
102+
ULYSSES_RING_ATTENTION_AXIS_RULES = [
103+
[SELF_ATTN_HEAD, None],
104+
[SELF_ATTN_Q_LENGTH, CONTEXT],
105+
[SELF_ATTN_KV_LENGTH, CONTEXT],
106+
[CROSS_ATTN_HEAD, None],
107+
[CROSS_ATTN_Q_LENGTH, CONTEXT],
108+
[CROSS_ATTN_KV_LENGTH, None],
109+
]

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 28 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ metrics_file: "" # for testing, local file that stores scalar metrics. If empty,
2020
write_metrics: True
2121

2222
timing_metrics_file: "" # for testing, local file that stores function timing metrics such as state creation, compilation. If empty, no metrics are written.
23-
write_timing_metrics: True
23+
write_timing_metrics: True
2424

2525
gcs_metrics: False
2626
# If true save config to GCS in {base_output_directory}/{run_name}/
@@ -64,9 +64,11 @@ jit_initializers: True
6464
# Set true to load weights from pytorch
6565
from_pt: True
6666
split_head_dim: True
67-
attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_custom
67+
attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_custom, ulysses_ring
6868
use_base2_exp: True
6969
use_experimental_scheduler: True
70+
# For attention=ulysses_ring, hidden Ulysses shard count; ring shards are context / this.
71+
ulysses_shards: -1
7072
flash_min_seq_length: 4096
7173
dropout: 0.0
7274

@@ -81,38 +83,38 @@ mask_padding_tokens: True
8183
attention_sharding_uniform: True
8284

8385
flash_block_sizes: {
84-
"block_q" : 512,
85-
"block_kv_compute" : 512,
86-
"block_kv" : 512,
87-
"block_q_dkv" : 512,
88-
"block_kv_dkv" : 512,
89-
"block_kv_dkv_compute" : 512,
90-
"block_q_dq" : 512,
91-
"block_kv_dq" : 512,
86+
"block_q": 512,
87+
"block_kv_compute": 512,
88+
"block_kv": 512,
89+
"block_q_dkv": 512,
90+
"block_kv_dkv": 512,
91+
"block_kv_dkv_compute": 512,
92+
"block_q_dq": 512,
93+
"block_kv_dq": 512,
9294
"use_fused_bwd_kernel": False,
9395
}
9496
# Use on v6e
9597
# flash_block_sizes: {
96-
# "block_q" : 3024,
97-
# "block_kv_compute" : 1024,
98-
# "block_kv" : 2048,
99-
# "block_q_dkv" : 3024,
100-
# "block_kv_dkv" : 2048,
101-
# "block_kv_dkv_compute" : 1024,
102-
# "block_q_dq" : 3024,
103-
# "block_kv_dq" : 2048,
98+
# "block_q": 3024,
99+
# "block_kv_compute": 1024,
100+
# "block_kv": 2048,
101+
# "block_q_dkv": 3024,
102+
# "block_kv_dkv": 2048,
103+
# "block_kv_dkv_compute": 1024,
104+
# "block_q_dq": 3024,
105+
# "block_kv_dq": 2048,
104106
# "use_fused_bwd_kernel": False,
105107
# }
106108
# Use on v5p
107109
# flash_block_sizes: {
108-
# "block_q" : 3024,
109-
# "block_kv_compute" : 1024,
110-
# "block_kv" : 2048,
111-
# "block_q_dkv" : 1024,
112-
# "block_kv_dkv" : 3072,
113-
# "block_kv_dkv_compute" : 256,
114-
# "block_q_dq" : 1024,
115-
# "block_kv_dq" : 3072
110+
# "block_q": 3024,
111+
# "block_kv_compute": 1024,
112+
# "block_kv": 2048,
113+
# "block_q_dkv": 1024,
114+
# "block_kv_dkv": 3072,
115+
# "block_kv_dkv_compute": 256,
116+
# "block_q_dq": 1024,
117+
# "block_kv_dq": 3072
116118
# }
117119
# GroupNorm groups
118120
norm_num_groups: 32

src/maxdiffusion/configs/base_wan_1_3b.yml

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -60,9 +60,11 @@ jit_initializers: True
6060
# Set true to load weights from pytorch
6161
from_pt: True
6262
split_head_dim: True
63-
attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring, ulysses
63+
attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_custom, ulysses_ring
6464
use_base2_exp: True
6565
use_experimental_scheduler: True
66+
# For attention=ulysses_ring, hidden Ulysses shard count; ring shards are context / this.
67+
ulysses_shards: -1
6668
flash_min_seq_length: 0
6769

6870
# If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens.
@@ -77,14 +79,14 @@ attention_sharding_uniform: True
7779
dropout: 0.0
7880

7981
flash_block_sizes: {
80-
"block_q" : 512,
81-
"block_kv_compute" : 512,
82-
"block_kv" : 512,
83-
"block_q_dkv" : 512,
84-
"block_kv_dkv" : 512,
85-
"block_kv_dkv_compute" : 512,
86-
"block_q_dq" : 512,
87-
"block_kv_dq" : 512,
82+
"block_q": 512,
83+
"block_kv_compute": 512,
84+
"block_kv": 512,
85+
"block_q_dkv": 512,
86+
"block_kv_dkv": 512,
87+
"block_kv_dkv_compute": 512,
88+
"block_q_dq": 512,
89+
"block_kv_dq": 512,
8890
"use_fused_bwd_kernel": False,
8991
}
9092
# GroupNorm groups

src/maxdiffusion/configs/base_wan_27b.yml

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ metrics_file: "" # for testing, local file that stores scalar metrics. If empty,
2020
write_metrics: True
2121

2222
timing_metrics_file: "" # for testing, local file that stores function timing metrics such as state creation, compilation. If empty, no metrics are written.
23-
write_timing_metrics: True
23+
write_timing_metrics: True
2424

2525
gcs_metrics: False
2626
# If true save config to GCS in {base_output_directory}/{run_name}/
@@ -64,9 +64,11 @@ jit_initializers: True
6464
# Set true to load weights from pytorch
6565
from_pt: True
6666
split_head_dim: True
67-
attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_custom
67+
attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_custom, ulysses_ring
6868
use_base2_exp: True
6969
use_experimental_scheduler: True
70+
# For attention=ulysses_ring, hidden Ulysses shard count; ring shards are context / this.
71+
ulysses_shards: -1
7072
flash_min_seq_length: 4096
7173
dropout: 0.0
7274

@@ -81,26 +83,26 @@ mask_padding_tokens: True
8183
attention_sharding_uniform: True
8284

8385
flash_block_sizes: {
84-
"block_q" : 512,
85-
"block_kv_compute" : 512,
86-
"block_kv" : 512,
87-
"block_q_dkv" : 512,
88-
"block_kv_dkv" : 512,
89-
"block_kv_dkv_compute" : 512,
90-
"block_q_dq" : 512,
91-
"block_kv_dq" : 512,
86+
"block_q": 2048,
87+
"block_kv_compute": 1024,
88+
"block_kv": 2048,
89+
"block_q_dkv": 2048,
90+
"block_kv_dkv": 2048,
91+
"block_kv_dkv_compute": 1024,
92+
"block_q_dq": 2048,
93+
"block_kv_dq": 2048,
9294
"use_fused_bwd_kernel": False,
9395
}
9496
# Use on v6e
9597
# flash_block_sizes: {
96-
# "block_q" : 3024,
97-
# "block_kv_compute" : 1024,
98-
# "block_kv" : 2048,
99-
# "block_q_dkv" : 3024,
100-
# "block_kv_dkv" : 2048,
101-
# "block_kv_dkv_compute" : 2048,
102-
# "block_q_dq" : 3024,
103-
# "block_kv_dq" : 2048
98+
# "block_q": 3024,
99+
# "block_kv_compute": 1024,
100+
# "block_kv": 2048,
101+
# "block_q_dkv": 3024,
102+
# "block_kv_dkv": 2048,
103+
# "block_kv_dkv_compute": 2048,
104+
# "block_q_dq": 3024,
105+
# "block_kv_dq": 2048
104106
# "use_fused_bwd_kernel": False,
105107
# }
106108
# GroupNorm groups

src/maxdiffusion/configs/base_wan_animate.yml

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ metrics_file: "" # for testing, local file that stores scalar metrics. If empty,
2020
write_metrics: True
2121

2222
timing_metrics_file: "" # for testing, local file that stores function timing metrics such as state creation, compilation. If empty, no metrics are written.
23-
write_timing_metrics: True
23+
write_timing_metrics: True
2424

2525
gcs_metrics: False
2626
# If true save config to GCS in {base_output_directory}/{run_name}/
@@ -62,45 +62,47 @@ jit_initializers: True
6262
# Set true to load weights from pytorch
6363
from_pt: True
6464
split_head_dim: True
65-
attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_custom
65+
attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_custom, ulysses_ring
6666
use_base2_exp: True
6767
use_experimental_scheduler: True
68+
# For attention=ulysses_ring, hidden Ulysses shard count; ring shards are context / this.
69+
ulysses_shards: -1
6870
flash_min_seq_length: 4096
6971
# If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens.
7072
# Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster.
7173
# However, when padding tokens are significant, this will lead to worse quality and should be set to True.
72-
mask_padding_tokens: True
74+
mask_padding_tokens: True
7375
# Maxdiffusion has 2 types of attention sharding strategies:
7476
# 1. attention_sharding_uniform = True : same sequence sharding rules applied for q in both (self and cross attention)
7577
# 2. attention_sharding_uniform = False : Heads are sharded uniformly across devices for self attention while sequence is sharded
7678
# in cross attention q.
77-
attention_sharding_uniform: True
79+
attention_sharding_uniform: True
7880
dropout: 0.0
7981

8082
# Tuned for 720p (720x1280), 81 frames, CP=8 on Trillium (32MB VMEM):
8183
# block_q=2048, block_kv=4096, block_kv_compute=1024
8284
# ~31% faster than default (512,512,512): 389s vs 508s at 40 steps
8385
flash_block_sizes: {
84-
"block_q" : 2048,
85-
"block_kv_compute" : 1024,
86-
"block_kv" : 4096,
87-
"block_q_dkv" : 512,
88-
"block_kv_dkv" : 512,
89-
"block_kv_dkv_compute" : 512,
90-
"block_q_dq" : 512,
91-
"block_kv_dq" : 512,
86+
"block_q": 2048,
87+
"block_kv_compute": 1024,
88+
"block_kv": 4096,
89+
"block_q_dkv": 512,
90+
"block_kv_dkv": 512,
91+
"block_kv_dkv_compute": 512,
92+
"block_q_dq": 512,
93+
"block_kv_dq": 512,
9294
"use_fused_bwd_kernel": False,
9395
}
9496
# Default smaller-shape block sizes:
9597
# flash_block_sizes: {
96-
# "block_q" : 512,
97-
# "block_kv_compute" : 512,
98-
# "block_kv" : 512,
99-
# "block_q_dkv" : 512,
100-
# "block_kv_dkv" : 512,
101-
# "block_kv_dkv_compute" : 512,
102-
# "block_q_dq" : 512,
103-
# "block_kv_dq" : 512,
98+
# "block_q": 512,
99+
# "block_kv_compute": 512,
100+
# "block_kv": 512,
101+
# "block_q_dkv": 512,
102+
# "block_kv_dkv": 512,
103+
# "block_kv_dkv_compute": 512,
104+
# "block_q_dq": 512,
105+
# "block_kv_dq": 512,
104106
# "use_fused_bwd_kernel": False,
105107
# }
106108
# GroupNorm groups

src/maxdiffusion/configs/base_wan_i2v_14b.yml

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ metrics_file: "" # for testing, local file that stores scalar metrics. If empty,
2020
write_metrics: True
2121

2222
timing_metrics_file: "" # for testing, local file that stores function timing metrics such as state creation, compilation. If empty, no metrics are written.
23-
write_timing_metrics: True
23+
write_timing_metrics: True
2424

2525
gcs_metrics: False
2626
# If true save config to GCS in {base_output_directory}/{run_name}/
@@ -64,9 +64,11 @@ jit_initializers: True
6464
# Set true to load weights from pytorch
6565
from_pt: True
6666
split_head_dim: True
67-
attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_custom
67+
attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_custom, ulysses_ring
6868
use_base2_exp: True
6969
use_experimental_scheduler: True
70+
# For attention=ulysses_ring, hidden Ulysses shard count; ring shards are context / this.
71+
ulysses_shards: -1
7072
flash_min_seq_length: 4096
7173
dropout: 0.0
7274

@@ -81,24 +83,24 @@ mask_padding_tokens: True
8183
attention_sharding_uniform: True
8284

8385
flash_block_sizes: {
84-
"block_q" : 2048,
85-
"block_kv_compute" : 512,
86-
"block_kv" : 2048,
87-
"block_q_dkv" : 2048,
88-
"block_kv_dkv" : 2048,
89-
"block_kv_dkv_compute" : 512,
90-
"use_fused_bwd_kernel" : True
86+
"block_q": 2048,
87+
"block_kv_compute": 512,
88+
"block_kv": 2048,
89+
"block_q_dkv": 2048,
90+
"block_kv_dkv": 2048,
91+
"block_kv_dkv_compute": 512,
92+
"use_fused_bwd_kernel": True
9193
}
9294
# Use on v6e
9395
# flash_block_sizes: {
94-
# "block_q" : 3024,
95-
# "block_kv_compute" : 1024,
96-
# "block_kv" : 2048,
97-
# "block_q_dkv" : 3024,
98-
# "block_kv_dkv" : 2048,
99-
# "block_kv_dkv_compute" : 2048,
100-
# "block_q_dq" : 3024,
101-
# "block_kv_dq" : 2048,
96+
# "block_q": 3024,
97+
# "block_kv_compute": 1024,
98+
# "block_kv": 2048,
99+
# "block_q_dkv": 3024,
100+
# "block_kv_dkv": 2048,
101+
# "block_kv_dkv_compute": 2048,
102+
# "block_q_dq": 3024,
103+
# "block_kv_dq": 2048,
102104
# "use_fused_bwd_kernel": False,
103105
# }
104106
# GroupNorm groups

0 commit comments

Comments
 (0)