Skip to content

Commit 2e247d0

Browse files
committed
Fix VAE decoding error and use base 2 and experimental scheduler
1 parent ae22683 commit 2e247d0

15 files changed

Lines changed: 140 additions & 28 deletions

README.md

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -628,16 +628,16 @@ To generate images, run the following command:
628628
We added ring attention support for Wan models. Below are the stats for one `720p` (81 frames) video generation (with CFG DP):
629629
| Accelerator | Model | Attention Type | Inference Steps | Sharding | e2e Generation Time |
630630
| -- | -- | -- | -- | -- | -- |
631-
| v7x-8 | WAN 2.1 | Tokamax Flash | 50 | dp2-fsdp1-context4-tp1 | 264.2 |
632-
| v7x-8 | WAN 2.1 | Tokamax Ring | 50 | dp2-fsdp1-context4-tp1 | **252.4** |
633-
| v7x-8 | WAN 2.2 | Tokamax Flash | 40 | dp2-fsdp1-context4-tp1 | 212.7 |
634-
| v7x-8 | WAN 2.2 | Tokamax Ring | 40 | dp2-fsdp1-context4-tp1 | **201.7** |
631+
| v7x-8 | WAN 2.1 | Tokamax Flash | 50 | dp2-fsdp1-context4-tp1 | **249.3** |
632+
| v7x-8 | WAN 2.1 | Tokamax Ring | 50 | dp2-fsdp1-context4-tp1 | 252.4 |
633+
| v7x-8 | WAN 2.2 | Tokamax Flash | 40 | dp2-fsdp1-context4-tp1 | **194.4** |
634+
| v7x-8 | WAN 2.2 | Tokamax Ring | 40 | dp2-fsdp1-context4-tp1 | 201.7 |
635635

636636
| Accelerator | Model | Attention Type | Inference Steps | Sharding | e2e Generation Time |
637637
| -- | -- | -- | -- | -- | -- |
638-
| v7x-16 | WAN 2.1 | Tokamax Flash | 50 | dp2-fsdp1-context8-tp1 | 146.6 |
639-
| v7x-16 | WAN 2.1 | Tokamax Ring | 50 | dp2-fsdp1-context8-tp1 | **137.2** |
640-
| v7x-16 | WAN 2.2 | Tokamax Flash | 40 | dp2-fsdp1-context8-tp1 | **117.8** |
638+
| v7x-16 | WAN 2.1 | Tokamax Flash | 50 | dp2-fsdp1-context8-tp1 | **127.1** |
639+
| v7x-16 | WAN 2.1 | Tokamax Ring | 50 | dp2-fsdp1-context8-tp1 | 137.2 |
640+
| v7x-16 | WAN 2.2 | Tokamax Flash | 40 | dp2-fsdp1-context8-tp1 | **106.0** |
641641
| v7x-16 | WAN 2.2 | Tokamax Ring | 40 | dp2-fsdp1-context8-tp1 | 137.5 |
642642

643643
(* There are some known stability issues for ring attention on 16 TPUs; please use `tokamax_flash` attention instead.)

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ jit_initializers: True
6262
from_pt: True
6363
split_head_dim: True
6464
attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses
65+
use_base2_exp: True
66+
use_experimental_scheduler: True
6567
flash_min_seq_length: 0
6668

6769
# If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens.

src/maxdiffusion/configs/base_wan_1_3b.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@ jit_initializers: True
6161
from_pt: True
6262
split_head_dim: True
6363
attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring, ulysses
64+
use_base2_exp: True
65+
use_experimental_scheduler: True
6466
flash_min_seq_length: 0
6567

6668
# If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens.

src/maxdiffusion/configs/base_wan_27b.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ jit_initializers: True
6262
from_pt: True
6363
split_head_dim: True
6464
attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring, ulysses
65+
use_base2_exp: True
66+
use_experimental_scheduler: True
6567
flash_min_seq_length: 4096
6668
# If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens.
6769
# Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster.

src/maxdiffusion/configs/base_wan_i2v_14b.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@ jit_initializers: True
6161
from_pt: True
6262
split_head_dim: True
6363
attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring, ulysses
64+
use_base2_exp: True
65+
use_experimental_scheduler: True
6466
flash_min_seq_length: 4096
6567
dropout: 0.0
6668

src/maxdiffusion/configs/base_wan_i2v_27b.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@ jit_initializers: True
6161
from_pt: True
6262
split_head_dim: True
6363
attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring, ulysses
64+
use_base2_exp: True
65+
use_experimental_scheduler: True
6466
flash_min_seq_length: 4096
6567
dropout: 0.0
6668

src/maxdiffusion/generate_wan.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -302,14 +302,7 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
302302
f"{'=' * 50}"
303303
)
304304

305-
s0 = time.perf_counter()
306-
if max_utils.profiler_enabled(config):
307-
with max_utils.Profiler(config):
308-
videos = call_pipeline(config, pipeline, prompt, negative_prompt)
309-
generation_time_with_profiler = time.perf_counter() - s0
310-
max_logging.log(f"generation_time_with_profiler: {generation_time_with_profiler}")
311-
if writer and jax.process_index() == 0:
312-
writer.add_scalar("inference/generation_time_with_profiler", generation_time_with_profiler, global_step=0)
305+
videos = call_pipeline(config, pipeline, prompt, negative_prompt)
313306

314307
return saved_video_path
315308

src/maxdiffusion/models/attention_flax.py

Lines changed: 42 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,7 @@ def convert_to_tokamax_splash_config(
272272
attn_logits_soft_cap: float | None = None,
273273
fuse_reciprocal: bool = True,
274274
use_base2_exp: bool = False,
275+
use_experimental_scheduler: bool = False,
275276
max_logit_const: float | None = None,
276277
interpret: bool = False,
277278
dq_reduction_steps: int | None = None,
@@ -294,6 +295,7 @@ def convert_to_tokamax_splash_config(
294295
attn_logits_soft_cap=attn_logits_soft_cap,
295296
fuse_reciprocal=fuse_reciprocal,
296297
use_base2_exp=use_base2_exp,
298+
use_experimental_scheduler=use_experimental_scheduler,
297299
max_logit_const=max_logit_const,
298300
interpret=interpret,
299301
dq_reduction_steps=dq_reduction_steps,
@@ -314,6 +316,8 @@ def _tpu_flash_attention(
314316
mask_padding_tokens: bool = True,
315317
residual_checkpoint_name: str | None = None,
316318
attention_mask: jax.Array = None,
319+
use_base2_exp: bool = False,
320+
use_experimental_scheduler: bool = False,
317321
) -> jax.Array:
318322
"""TPU Flash Attention"""
319323

@@ -399,7 +403,12 @@ def wrap_flash_attention(query, key, value):
399403
splash_kernel = tokamax_splash_attention_kernel.make_splash_mha(
400404
mask=mask,
401405
q_seq_shards=1, # the sizes of the axis is sharding over seq_len
402-
config=convert_to_tokamax_splash_config(block_sizes, residual_checkpoint_name=residual_checkpoint_name),
406+
config=convert_to_tokamax_splash_config(
407+
block_sizes,
408+
residual_checkpoint_name=residual_checkpoint_name,
409+
use_base2_exp=use_base2_exp,
410+
use_experimental_scheduler=use_experimental_scheduler,
411+
),
403412
save_residuals=False,
404413
)
405414
elif attention_kernel == "tokamax_ring":
@@ -409,7 +418,12 @@ def wrap_flash_attention(query, key, value):
409418
splash_kernel = tokamax_ring_attention_kernel.make_ring_attention(
410419
mask=mask,
411420
is_mqa=False,
412-
config=convert_to_tokamax_splash_config(block_sizes, residual_checkpoint_name=residual_checkpoint_name),
421+
config=convert_to_tokamax_splash_config(
422+
block_sizes,
423+
residual_checkpoint_name=residual_checkpoint_name,
424+
use_base2_exp=use_base2_exp,
425+
use_experimental_scheduler=use_experimental_scheduler,
426+
),
413427
save_residuals=False,
414428
ring_axis="context",
415429
rotate_segment_ids=False, # We don't rotate segment ids in tokamax ring attention because our segment ids are only for padding; each kv shard has the same segment ids
@@ -473,13 +487,13 @@ def ring_scan_body(carry, _):
473487
raise ValueError("ring attention requires context > 1")
474488
return attention_output[:, :, :query_seq_len, :kv_size].astype(query.dtype)
475489

476-
devices_in_data_context = mesh.shape["data"] * mesh.shape["context"]
490+
devices_in_batch_sharding = mesh.shape["data"] * (mesh.shape["fsdp"] if "fsdp" in mesh.shape else 1)
477491
# This warning might show up when doing model eval for example, when calculating model flops
478492
# and that is expected.
479-
if not (query.shape[0] / devices_in_data_context).is_integer():
493+
if not (query.shape[0] / devices_in_batch_sharding).is_integer():
480494
max_logging.log(
481-
"Warning, batch dimension should be shardable among the devices in data and context"
482-
f" axis, batch dimension: {query.shape[0]}, devices_in_data_context: {devices_in_data_context}"
495+
"Warning, batch dimension should be shardable among the devices in data and fsdp"
496+
f" axis, batch dimension: {query.shape[0]}, devices_in_batch_sharding: {devices_in_batch_sharding}"
483497
)
484498
x = wrap_flash_attention(query, key, value)
485499
# Trim back to original sequence length after context-axis padding.
@@ -614,11 +628,11 @@ def wrap_ulysses_attention(query, key, value):
614628
attention_output = jax.lax.all_to_all(attention_output, axis_name=axis_name, split_axis=2, concat_axis=1, tiled=True)
615629
return attention_output
616630

617-
devices_in_data_context = mesh.shape["data"] * num_shards
618-
if not (query.shape[0] / devices_in_data_context).is_integer():
631+
devices_in_batch_sharding = mesh.shape["data"] * (mesh.shape["fsdp"] if "fsdp" in mesh.shape else 1)
632+
if not (query.shape[0] / devices_in_batch_sharding).is_integer():
619633
max_logging.log(
620-
"Warning, batch dimension should be shardable among the devices in data and context"
621-
f" axis, batch dimension: {query.shape[0]}, devices_in_data_context: {devices_in_data_context}"
634+
"Warning, batch dimension should be shardable among the devices in data and fsdp"
635+
f" axis, batch dimension: {query.shape[0]}, devices_in_batch_sharding: {devices_in_batch_sharding}"
622636
)
623637
x = wrap_ulysses_attention(query, key, value)
624638
x = x[:, :, :orig_q_seq_len, :]
@@ -741,6 +755,8 @@ def _apply_attention(
741755
mask_padding_tokens: bool = True,
742756
residual_checkpoint_name: str | None = None,
743757
attention_mask: Array = None,
758+
use_base2_exp: bool = False,
759+
use_experimental_scheduler: bool = False,
744760
):
745761
"""Routes to different attention kernels."""
746762
_check_attention_inputs(query, key, value)
@@ -789,6 +805,8 @@ def _apply_attention(
789805
mask_padding_tokens=mask_padding_tokens,
790806
residual_checkpoint_name=residual_checkpoint_name,
791807
attention_mask=attention_mask,
808+
use_base2_exp=use_base2_exp,
809+
use_experimental_scheduler=use_experimental_scheduler,
792810
)
793811
elif "ring" in attention_kernel:
794812
return _tpu_flash_attention(
@@ -983,8 +1001,12 @@ def __init__(
9831001
quant: Quant = None,
9841002
mask_padding_tokens: bool = True,
9851003
residual_checkpoint_name: str | None = None,
1004+
use_base2_exp: bool = False,
1005+
use_experimental_scheduler: bool = False,
9861006
):
9871007
self.dpa_layer = None
1008+
self.use_base2_exp = use_base2_exp
1009+
self.use_experimental_scheduler = use_experimental_scheduler
9881010
if attention_kernel == "cudnn_flash_te":
9891011
from transformer_engine.jax.flax.transformer import DotProductAttention # pytype: disable=import-error
9901012

@@ -1045,6 +1067,8 @@ def apply_attention(self, query: Array, key: Array, value: Array, attention_mask
10451067
mask_padding_tokens=self.mask_padding_tokens,
10461068
residual_checkpoint_name=self.residual_checkpoint_name,
10471069
attention_mask=attention_mask,
1070+
use_base2_exp=self.use_base2_exp if hasattr(self, "use_base2_exp") else False,
1071+
use_experimental_scheduler=self.use_experimental_scheduler if hasattr(self, "use_experimental_scheduler") else False,
10481072
)
10491073

10501074

@@ -1063,6 +1087,8 @@ class AttentionOp(nn.Module):
10631087
flash_block_sizes: BlockSizes = None
10641088
dtype: DType = jnp.float32
10651089
quant: Quant = None
1090+
use_base2_exp: bool = False
1091+
use_experimental_scheduler: bool = False
10661092

10671093
def setup(self):
10681094
self.dpa_layer = None
@@ -1108,6 +1134,8 @@ def apply_attention(self, query: Array, key: Array, value: Array, attention_mask
11081134
flash_block_sizes=self.flash_block_sizes,
11091135
dpa_layer=self.dpa_layer,
11101136
attention_mask=attention_mask,
1137+
use_base2_exp=self.use_base2_exp,
1138+
use_experimental_scheduler=self.use_experimental_scheduler,
11111139
)
11121140

11131141

@@ -1144,6 +1172,8 @@ def __init__(
11441172
enable_jax_named_scopes: bool = False,
11451173
added_kv_proj_dim: Optional[int] = None, # New for I2V
11461174
image_seq_len: Optional[int] = None, # New for I2V
1175+
use_base2_exp: bool = False,
1176+
use_experimental_scheduler: bool = False,
11471177
):
11481178
if attention_kernel in {"flash", "cudnn_flash_te"} and mesh is None:
11491179
raise ValueError(f"The flash attention kernel requires a value for mesh, but mesh is {self.mesh}")
@@ -1186,6 +1216,8 @@ def __init__(
11861216
quant=quant,
11871217
mask_padding_tokens=mask_padding_tokens,
11881218
residual_checkpoint_name=residual_checkpoint_name,
1219+
use_base2_exp=use_base2_exp,
1220+
use_experimental_scheduler=use_experimental_scheduler,
11891221
)
11901222
# None axes corresponds to the stacked weights across all blocks
11911223
# because of the use of nnx.vmap and nnx.scan.

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1206,7 +1206,7 @@ def _decode(
12061206
fm1, fm2, fm3, fm4 = out_chunk_1[:, 0, ...], out_chunk_1[:, 1, ...], out_chunk_1[:, 2, ...], out_chunk_1[:, 3, ...]
12071207
axis = 1 if fm1.shape[0] > 1 else 0
12081208
fm1, fm2, fm3, fm4 = [jnp.expand_dims(f, axis=axis) for f in [fm1, fm2, fm3, fm4]]
1209-
out_1 = jnp.concatenate([fm1, fm3, fm2, fm4], axis=1)
1209+
out_1 = jnp.concatenate([fm1, fm2, fm3, fm4], axis=1)
12101210

12111211
out_list = [out_0, out_1]
12121212

@@ -1226,7 +1226,7 @@ def scan_fn(carry, chunk_in):
12261226
fm1, fm2, fm3, fm4 = out_chunk[:, 0, ...], out_chunk[:, 1, ...], out_chunk[:, 2, ...], out_chunk[:, 3, ...]
12271227
axis = 1 if fm1.shape[0] > 1 else 0
12281228
fm1, fm2, fm3, fm4 = [jnp.expand_dims(f, axis=axis) for f in [fm1, fm2, fm3, fm4]]
1229-
new_chunk = jnp.concatenate([fm1, fm3, fm2, fm4], axis=1)
1229+
new_chunk = jnp.concatenate([fm1, fm2, fm3, fm4], axis=1)
12301230

12311231
return next_feat_map, new_chunk
12321232

src/maxdiffusion/models/wan/transformers/transformer_wan.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,8 @@ def __init__(
291291
dropout: float = 0.0,
292292
mask_padding_tokens: bool = True,
293293
enable_jax_named_scopes: bool = False,
294+
use_base2_exp: bool = False,
295+
use_experimental_scheduler: bool = False,
294296
):
295297
self.enable_jax_named_scopes = enable_jax_named_scopes
296298

@@ -315,6 +317,8 @@ def __init__(
315317
mask_padding_tokens=mask_padding_tokens,
316318
residual_checkpoint_name="self_attn",
317319
enable_jax_named_scopes=enable_jax_named_scopes,
320+
use_base2_exp=use_base2_exp,
321+
use_experimental_scheduler=use_experimental_scheduler,
318322
)
319323

320324
# 1. Cross-attention
@@ -339,6 +343,8 @@ def __init__(
339343
mask_padding_tokens=mask_padding_tokens,
340344
residual_checkpoint_name="cross_attn",
341345
enable_jax_named_scopes=enable_jax_named_scopes,
346+
use_base2_exp=use_base2_exp,
347+
use_experimental_scheduler=use_experimental_scheduler,
342348
)
343349
assert cross_attn_norm is True
344350
self.norm2 = FP32LayerNorm(rngs=rngs, dim=dim, eps=eps, elementwise_affine=True)
@@ -486,6 +492,8 @@ def __init__(
486492
mask_padding_tokens: bool = True,
487493
scan_layers: bool = True,
488494
enable_jax_named_scopes: bool = False,
495+
use_base2_exp: bool = False,
496+
use_experimental_scheduler: bool = False,
489497
):
490498
inner_dim = num_attention_heads * attention_head_dim
491499
out_channels = out_channels or in_channels
@@ -547,6 +555,8 @@ def init_block(rngs):
547555
enable_jax_named_scopes=enable_jax_named_scopes,
548556
added_kv_proj_dim=added_kv_proj_dim,
549557
image_seq_len=image_seq_len,
558+
use_base2_exp=use_base2_exp,
559+
use_experimental_scheduler=use_experimental_scheduler,
550560
)
551561

552562
self.gradient_checkpoint = GradientCheckpointType.from_str(remat_policy)

0 commit comments

Comments
 (0)