Skip to content

Commit c60e65d

Browse files
PR #359: Ring attention integration and other optimizations
Imported from GitHub PR #359 In this PR, we integrated tokamax ring attention kernels for WAN models. Below are the main changes made: 1. Added ring attention kernel and splash attention kernel under . Here is the doc for the modification we made: [Ring Attention Kernel Precision Issue](https://docs.google.com/document/d/11FPxDoT0PfdnEAGPko-6V5oblzWmwyCJUKMMqCq04e4). Modified to support 2. JITted VAE and sharded VAE: added new config (default to 1) to let users decide how to shard VAE. 3. Xprof: modified profiler code to actually use (for example ) instead of profiling the entire generation Fix BUILD file by adding missing :kernels target to fix ModuleNotFoundError. Copybara import of the project: -- 616bf63 by Elisa Tsai <elisatsai@google.com>: Feat: Ring attention kernel and VAE optimization Merging this change closes #359 PiperOrigin-RevId: 902809224
1 parent ad6391a commit c60e65d

34 files changed

Lines changed: 7533 additions & 101 deletions

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ activations_dtype: 'bfloat16'
4444

4545
# Replicates vae across devices instead of using the model's sharding annotations for sharding.
4646
replicate_vae: False
47+
vae_spatial: -1 # default to total_device * 2 // (dp)
4748

4849
# matmul and conv precision from https://jax.readthedocs.io/en/latest/jax.lax.html#jax.lax.Precision
4950
# Options are "DEFAULT", "HIGH", "HIGHEST"
@@ -60,7 +61,7 @@ jit_initializers: True
6061
# Set true to load weights from pytorch
6162
from_pt: True
6263
split_head_dim: True
63-
attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring, ulysses
64+
attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses
6465
flash_min_seq_length: 0
6566

6667
# If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens.
@@ -180,6 +181,19 @@ logical_axis_rules: [
180181
['out_channels', 'tensor'],
181182
['conv_out', 'context'],
182183
]
184+
vae_logical_axis_rules: [
185+
['activation_batch', 'redundant'],
186+
['activation_length', 'vae_spatial'],
187+
['activation_heads', null],
188+
['activation_kv_length', null],
189+
['embed', null],
190+
['heads', null],
191+
['norm', null],
192+
['conv_batch', 'redundant'],
193+
['out_channels', 'vae_spatial'],
194+
['conv_out', 'vae_spatial'],
195+
['conv_in', 'vae_spatial'],
196+
]
183197
data_sharding: [['data', 'fsdp', 'context', 'tensor']]
184198

185199
# One axis for each parallelism type may hold a placeholder (-1)

src/maxdiffusion/configs/base_wan_1_3b.yml

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2023 Google LLC
1+
# Copyright 2023 Google LLC
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -157,6 +157,19 @@ logical_axis_rules: [
157157
['out_channels', 'tensor'],
158158
['conv_out', 'context'],
159159
]
160+
vae_logical_axis_rules: [
161+
['activation_batch', 'redundant'],
162+
['activation_length', 'vae_spatial'],
163+
['activation_heads', null],
164+
['activation_kv_length', null],
165+
['embed', null],
166+
['heads', null],
167+
['norm', null],
168+
['conv_batch', 'redundant'],
169+
['out_channels', 'vae_spatial'],
170+
['conv_out', 'vae_spatial'],
171+
['conv_in', 'vae_spatial'],
172+
]
160173
data_sharding: [['data', 'fsdp', 'context', 'tensor']]
161174

162175
# One axis for each parallelism type may hold a placeholder (-1)

src/maxdiffusion/configs/base_wan_27b.yml

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ activations_dtype: 'bfloat16'
4444

4545
# Replicates vae across devices instead of using the model's sharding annotations for sharding.
4646
replicate_vae: False
47+
vae_spatial: 1
4748

4849
# matmul and conv precision from https://jax.readthedocs.io/en/latest/jax.lax.html#jax.lax.Precision
4950
# Options are "DEFAULT", "HIGH", "HIGHEST"
@@ -168,6 +169,19 @@ logical_axis_rules: [
168169
['out_channels', 'tensor'],
169170
['conv_out', 'context'],
170171
]
172+
vae_logical_axis_rules: [
173+
['activation_batch', 'redundant'],
174+
['activation_length', 'vae_spatial'],
175+
['activation_heads', null],
176+
['activation_kv_length', null],
177+
['embed', null],
178+
['heads', null],
179+
['norm', null],
180+
['conv_batch', 'redundant'],
181+
['out_channels', 'vae_spatial'],
182+
['conv_out', 'vae_spatial'],
183+
['conv_in', 'vae_spatial'],
184+
]
171185
data_sharding: [['data', 'fsdp', 'context', 'tensor']]
172186

173187
# One axis for each parallelism type may hold a placeholder (-1)

src/maxdiffusion/configs/base_wan_i2v_14b.yml

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,19 @@ logical_axis_rules: [
163163
['out_channels', 'tensor'],
164164
['conv_out', 'context'],
165165
]
166+
vae_logical_axis_rules: [
167+
['activation_batch', 'redundant'],
168+
['activation_length', 'vae_spatial'],
169+
['activation_heads', null],
170+
['activation_kv_length', null],
171+
['embed', null],
172+
['heads', null],
173+
['norm', null],
174+
['conv_batch', 'redundant'],
175+
['out_channels', 'vae_spatial'],
176+
['conv_out', 'vae_spatial'],
177+
['conv_in', 'vae_spatial'],
178+
]
166179
data_sharding: [['data', 'fsdp', 'context', 'tensor']]
167180

168181
# One axis for each parallelism type may hold a placeholder (-1)

src/maxdiffusion/configs/base_wan_i2v_27b.yml

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,19 @@ logical_axis_rules: [
164164
['out_channels', 'tensor'],
165165
['conv_out', 'context'],
166166
]
167+
vae_logical_axis_rules: [
168+
['activation_batch', 'redundant'],
169+
['activation_length', 'vae_spatial'],
170+
['activation_heads', null],
171+
['activation_kv_length', null],
172+
['embed', null],
173+
['heads', null],
174+
['norm', null],
175+
['conv_batch', 'redundant'],
176+
['out_channels', 'vae_spatial'],
177+
['conv_out', 'vae_spatial'],
178+
['conv_in', 'vae_spatial'],
179+
]
167180
data_sharding: [['data', 'fsdp', 'context', 'tensor']]
168181

169182
# One axis for each parallelism type may hold a placeholder (-1)

src/maxdiffusion/configuration_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -394,7 +394,7 @@ def load_config(
394394
proxies=proxies,
395395
resume_download=resume_download,
396396
local_files_only=local_files_only,
397-
use_auth_token=use_auth_token,
397+
token=use_auth_token,
398398
user_agent=user_agent,
399399
subfolder=subfolder,
400400
revision=revision,
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# Copyright 2025 DeepMind Technologies Limited. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Splash Attention kernels."""

0 commit comments

Comments
 (0)