Commit 4541243

Use segment ids for WAN ulysses+ring attention masking
Replace the NumpyMask padding mask in ulysses_ring with segment ids, matching the other attention kernels. Add a shared _build_padding_segment_ids helper used by flash, ulysses, and ulysses+ring paths.
1 parent 4691a2c commit 4541243
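The commit message describes masking padding with segment ids instead of an explicit mask. A minimal sketch of that idea, in plain Python, is below. The body of `_build_padding_segment_ids` is not shown in this view, so `padding_segment_ids` and `allowed` are hypothetical stand-ins, and the convention of 1 for real tokens and 0 for padding is an assumption.

```python
def padding_segment_ids(valid_len: int, padded_len: int) -> list[int]:
    """Hypothetical helper: 1 for real tokens, 0 for trailing padding."""
    if not 0 <= valid_len <= padded_len:
        raise ValueError("valid_len must lie in [0, padded_len]")
    return [1 if i < valid_len else 0 for i in range(padded_len)]


def allowed(q_ids: list[int], kv_ids: list[int]) -> list[list[bool]]:
    """Position (i, j) may attend only when both segment ids match."""
    return [[q == kv for kv in kv_ids] for q in q_ids]


q_ids = padding_segment_ids(valid_len=3, padded_len=4)   # [1, 1, 1, 0]
kv_ids = padding_segment_ids(valid_len=2, padded_len=4)  # [1, 1, 0, 0]
mask = allowed(q_ids, kv_ids)  # real queries never see padded keys
```

Two short id vectors carry the same information as the full query-by-key mask, which is presumably why the other kernels already use them.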

51 files changed

Lines changed: 788 additions & 564 deletions

dependencies/requirements/base_requirements/requirements.txt

Lines changed: 0 additions & 1 deletion
@@ -2,7 +2,6 @@
 absl-py
 accelerate
 aqtp
-av
 chex
 datasets
 einops

dependencies/requirements/generated_requirements/requirements.txt

Lines changed: 0 additions & 1 deletion
@@ -15,7 +15,6 @@ astroid>=4.0.4
 astunparse>=1.6.3
 attrs>=25.4.0
 auditwheel>=6.6.0
-av>=17.0.1
 black>=25.12.0
 build>=1.4.0
 certifi>=2026.1.4

maxdiffusion_dependencies.Dockerfile

Lines changed: 1 addition & 1 deletion
@@ -17,7 +17,7 @@ ENV DEBIAN_FRONTEND=noninteractive
 RUN python -m pip install --upgrade pip uv --no-warn-script-location
 
 # Install system dependencies
-RUN apt-get update && apt-get install -y apt-utils git curl gnupg procps iproute2 ethtool g++ && rm -rf /var/lib/apt/lists/*
+RUN apt-get update && apt-get install -y apt-utils git curl gnupg procps iproute2 ethtool && rm -rf /var/lib/apt/lists/*
 
 # Add the Google Cloud SDK package repository
 RUN curl -fsSL https://packages.cloud.google.com/apt/doc/apt-key.gpg | gpg --dearmor -o /usr/share/keyrings/cloud.google.gpg && \

src/maxdiffusion/common_types.py

Lines changed: 12 additions & 0 deletions
@@ -95,3 +95,15 @@
     [CROSS_ATTN_Q_LENGTH, CONTEXT],
     [CROSS_ATTN_KV_LENGTH, CONTEXT],
 ]
+
+### Common axis rules for 2D Ulysses + ring attention ###
+# Public configs shard sequence on `context`; attention code privately reshapes
+# that axis into hidden ring and Ulysses axes for the hybrid kernel.
+ULYSSES_RING_ATTENTION_AXIS_RULES = [
+    [SELF_ATTN_HEAD, None],
+    [SELF_ATTN_Q_LENGTH, CONTEXT],
+    [SELF_ATTN_KV_LENGTH, CONTEXT],
+    [CROSS_ATTN_HEAD, None],
+    [CROSS_ATTN_Q_LENGTH, CONTEXT],
+    [CROSS_ATTN_KV_LENGTH, None],
+]
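To illustrate how a rule list like this is typically consumed, the sketch below resolves logical axis names to mesh axes with a plain lookup. The string values assigned to the constants are placeholders (the real values live in common_types.py and are not shown here), and `resolve` is a hypothetical helper, not a function from this repository.

```python
# Placeholder values; the real constants are defined in common_types.py.
SELF_ATTN_HEAD = "self_attn_head"
SELF_ATTN_Q_LENGTH = "self_attn_q_length"
SELF_ATTN_KV_LENGTH = "self_attn_kv_length"
CROSS_ATTN_HEAD = "cross_attn_head"
CROSS_ATTN_Q_LENGTH = "cross_attn_q_length"
CROSS_ATTN_KV_LENGTH = "cross_attn_kv_length"
CONTEXT = "context"

ULYSSES_RING_ATTENTION_AXIS_RULES = [
    [SELF_ATTN_HEAD, None],
    [SELF_ATTN_Q_LENGTH, CONTEXT],
    [SELF_ATTN_KV_LENGTH, CONTEXT],
    [CROSS_ATTN_HEAD, None],
    [CROSS_ATTN_Q_LENGTH, CONTEXT],
    [CROSS_ATTN_KV_LENGTH, None],
]


def resolve(logical_axes, rules):
    """Map each logical axis name to its mesh axis (None means replicated)."""
    table = {name: mesh_axis for name, mesh_axis in rules}
    return tuple(table.get(name) for name in logical_axes)


# Self-attention queries: heads replicated, sequence sharded on `context`.
self_q = resolve((SELF_ATTN_HEAD, SELF_ATTN_Q_LENGTH), ULYSSES_RING_ATTENTION_AXIS_RULES)
# Cross-attention keys/values: both axes replicated under these rules.
cross_kv = resolve((CROSS_ATTN_HEAD, CROSS_ATTN_KV_LENGTH), ULYSSES_RING_ATTENTION_AXIS_RULES)
```

Compared with the rules just above it in the file, the visible difference is that `CROSS_ATTN_KV_LENGTH` maps to `None` rather than `CONTEXT`.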

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 3 additions & 6 deletions
@@ -41,11 +41,6 @@ revision: ''
 weights_dtype: 'bfloat16'
 # This sets the layer's dtype in the model. Ex: nn.Dense(dtype=activations_dtype)
 activations_dtype: 'bfloat16'
-# The dtype for text_encoder model during load/compile
-text_encoder_dtype: 'float32'
-
-# Whether to compile the text_encoder with torch.compile
-compile_text_encoder: False
 
 # Replicates vae across devices instead of using the model's sharding annotations for sharding.
 replicate_vae: False
@@ -69,9 +64,11 @@ jit_initializers: True
 # Set true to load weights from pytorch
 from_pt: True
 split_head_dim: True
-attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_custom
+attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_custom, ulysses_ring
 use_base2_exp: True
 use_experimental_scheduler: True
+# For attention=ulysses_ring, hidden Ulysses shard count; ring shards are context / this.
+ulysses_shards: -1
 flash_min_seq_length: 4096
 dropout: 0.0
 
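The `ulysses_shards` comment states that the ring shard count is the `context` axis size divided by the Ulysses shard count. A small sketch of that arithmetic follows. `split_context_axis` is a hypothetical helper, and since the diff does not show how the default of -1 is resolved, the sketch rejects non-positive values rather than guessing.

```python
def split_context_axis(context_size: int, ulysses_shards: int) -> tuple[int, int]:
    """Return (ring_shards, ulysses_shards) for a given `context` axis size."""
    if ulysses_shards <= 0:
        # How -1 is interpreted is not shown in this diff; not modeled here.
        raise ValueError("this sketch needs an explicit positive ulysses_shards")
    if context_size % ulysses_shards != 0:
        raise ValueError("ulysses_shards must divide the context axis size")
    return context_size // ulysses_shards, ulysses_shards


ring, ulysses = split_context_axis(context_size=8, ulysses_shards=4)
```

With 8 devices on `context` and 4 Ulysses shards, this gives 2 ring shards.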

src/maxdiffusion/configs/base_wan_1_3b.yml

Lines changed: 3 additions & 6 deletions
@@ -41,11 +41,6 @@ revision: ''
 weights_dtype: 'bfloat16'
 # This sets the layer's dtype in the model. Ex: nn.Dense(dtype=activations_dtype)
 activations_dtype: 'bfloat16'
-# The dtype for text_encoder model during load/compile
-text_encoder_dtype: 'float32'
-
-# Whether to compile the text_encoder with torch.compile
-compile_text_encoder: False
 
 # Replicates vae across devices instead of using the model's sharding annotations for sharding.
 replicate_vae: False
@@ -65,9 +60,11 @@ jit_initializers: True
 # Set true to load weights from pytorch
 from_pt: True
 split_head_dim: True
-attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring, ulysses
+attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_custom, ulysses_ring
 use_base2_exp: True
 use_experimental_scheduler: True
+# For attention=ulysses_ring, hidden Ulysses shard count; ring shards are context / this.
+ulysses_shards: -1
 flash_min_seq_length: 0
 
 # If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens.

src/maxdiffusion/configs/base_wan_27b.yml

Lines changed: 3 additions & 6 deletions
@@ -41,11 +41,6 @@ revision: ''
 weights_dtype: 'bfloat16'
 # This sets the layer's dtype in the model. Ex: nn.Dense(dtype=activations_dtype)
 activations_dtype: 'bfloat16'
-# The dtype for text_encoder model during load/compile
-text_encoder_dtype: 'float32'
-
-# Whether to compile the text_encoder with torch.compile
-compile_text_encoder: False
 
 # Replicates vae across devices instead of using the model's sharding annotations for sharding.
 replicate_vae: False
@@ -69,9 +64,11 @@ jit_initializers: True
 # Set true to load weights from pytorch
 from_pt: True
 split_head_dim: True
-attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_custom
+attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_custom, ulysses_ring
 use_base2_exp: True
 use_experimental_scheduler: True
+# For attention=ulysses_ring, hidden Ulysses shard count; ring shards are context / this.
+ulysses_shards: -1
 flash_min_seq_length: 4096
 dropout: 0.0
 

src/maxdiffusion/configs/base_wan_animate.yml

Lines changed: 3 additions & 6 deletions
@@ -41,11 +41,6 @@ revision: ''
 weights_dtype: 'bfloat16'
 # This sets the layer's dtype in the model. Ex: nn.Dense(dtype=activations_dtype)
 activations_dtype: 'bfloat16'
-# The dtype for text_encoder model during load/compile
-text_encoder_dtype: 'float32'
-
-# Whether to compile the text_encoder with torch.compile
-compile_text_encoder: False
 
 # Replicates vae across devices instead of using the model's sharding annotations for sharding.
 replicate_vae: False
@@ -67,9 +62,11 @@ jit_initializers: True
 # Set true to load weights from pytorch
 from_pt: True
 split_head_dim: True
-attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_custom
+attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_custom, ulysses_ring
 use_base2_exp: True
 use_experimental_scheduler: True
+# For attention=ulysses_ring, hidden Ulysses shard count; ring shards are context / this.
+ulysses_shards: -1
 flash_min_seq_length: 4096
 # If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens.
 # Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster.

src/maxdiffusion/configs/base_wan_i2v_14b.yml

Lines changed: 3 additions & 6 deletions
@@ -41,11 +41,6 @@ revision: ''
 weights_dtype: 'bfloat16'
 # This sets the layer's dtype in the model. Ex: nn.Dense(dtype=activations_dtype)
 activations_dtype: 'bfloat16'
-# The dtype for text_encoder model during load/compile
-text_encoder_dtype: 'float32'
-
-# Whether to compile the text_encoder with torch.compile
-compile_text_encoder: False
 
 # Replicates vae across devices instead of using the model's sharding annotations for sharding.
 replicate_vae: False
@@ -69,9 +64,11 @@ jit_initializers: True
 # Set true to load weights from pytorch
 from_pt: True
 split_head_dim: True
-attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_custom
+attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_custom, ulysses_ring
 use_base2_exp: True
 use_experimental_scheduler: True
+# For attention=ulysses_ring, hidden Ulysses shard count; ring shards are context / this.
+ulysses_shards: -1
 flash_min_seq_length: 4096
 dropout: 0.0
 

src/maxdiffusion/configs/base_wan_i2v_27b.yml

Lines changed: 3 additions & 6 deletions
@@ -41,11 +41,6 @@ revision: ''
 weights_dtype: 'bfloat16'
 # This sets the layer's dtype in the model. Ex: nn.Dense(dtype=activations_dtype)
 activations_dtype: 'bfloat16'
-# The dtype for text_encoder model during load/compile
-text_encoder_dtype: 'float32'
-
-# Whether to compile the text_encoder with torch.compile
-compile_text_encoder: False
 
 # Replicates vae across devices instead of using the model's sharding annotations for sharding.
 replicate_vae: False
@@ -69,9 +64,11 @@ jit_initializers: True
 # Set true to load weights from pytorch
 from_pt: True
 split_head_dim: True
-attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_custom
+attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_custom, ulysses_ring
 use_base2_exp: True
 use_experimental_scheduler: True
+# For attention=ulysses_ring, hidden Ulysses shard count; ring shards are context / this.
+ulysses_shards: -1
 flash_min_seq_length: 4096
 dropout: 0.0
 
