Skip to content

Commit 0b99f74

Browse files
Merge pull request #411 from AI-Hypercomputer:wan_vae_opt
PiperOrigin-RevId: 921650325
2 parents 4691a2c + 1710495 commit 0b99f74

21 files changed

Lines changed: 408 additions & 193 deletions

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,23 @@ text_encoder_dtype: 'float32'
4747
# Whether to compile the text_encoder with torch.compile
4848
compile_text_encoder: False
4949

50+
# Maximum sequence length for the text encoder
51+
max_sequence_length: 512
52+
53+
vae_weights_dtype: 'float32'
54+
vae_dtype: 'float32'
55+
scheduler_dtype: 'float32'
56+
5057
# Replicates vae across devices instead of using the model's sharding annotations for sharding.
5158
replicate_vae: False
52-
vae_spatial: -1 # default to total_device * 2 // (dp)
59+
60+
# Chunk size for VAE decode scan. Increase to improve decode time at the cost of memory.
61+
vae_decode_chunk: 1
62+
63+
# Chunk size for VAE encode scan. (num_input_frames - 1) must be divisible by this value.
64+
# Increase to improve encode time at the cost of memory.
65+
vae_encode_chunk: 4
66+
vae_spatial: -1
5367

5468
# matmul and conv precision from https://jax.readthedocs.io/en/latest/jax.lax.html#jax.lax.Precision
5569
# Options are "DEFAULT", "HIGH", "HIGHEST"

src/maxdiffusion/configs/base_wan_1_3b.yml

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,24 @@ text_encoder_dtype: 'float32'
4747
# Whether to compile the text_encoder with torch.compile
4848
compile_text_encoder: False
4949

50+
# Maximum sequence length for the text encoder
51+
max_sequence_length: 512
52+
53+
vae_weights_dtype: 'float32'
54+
vae_dtype: 'float32'
55+
scheduler_dtype: 'float32'
56+
5057
# Replicates vae across devices instead of using the model's sharding annotations for sharding.
5158
replicate_vae: False
5259

60+
# Chunk size for VAE decode scan. Increase to improve decode time at the cost of memory.
61+
vae_decode_chunk: 1
62+
63+
# Chunk size for VAE encode scan. (num_input_frames - 1) must be divisible by this value.
64+
# Increase to improve encode time at the cost of memory.
65+
vae_encode_chunk: 4
66+
vae_spatial: -1
67+
5368
# matmul and conv precision from https://jax.readthedocs.io/en/latest/jax.lax.html#jax.lax.Precision
5469
# Options are "DEFAULT", "HIGH", "HIGHEST"
5570
# fp32 activations and fp32 weights with HIGHEST will provide the best precision

src/maxdiffusion/configs/base_wan_27b.yml

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,23 @@ text_encoder_dtype: 'float32'
4747
# Whether to compile the text_encoder with torch.compile
4848
compile_text_encoder: False
4949

50+
# Maximum sequence length for the text encoder
51+
max_sequence_length: 512
52+
53+
vae_weights_dtype: 'float32'
54+
vae_dtype: 'float32'
55+
scheduler_dtype: 'float32'
56+
5057
# Replicates vae across devices instead of using the model's sharding annotations for sharding.
5158
replicate_vae: False
52-
vae_spatial: -1 # default to total_device * 2 // (dp)
59+
60+
# Chunk size for VAE decode scan. Increase to improve decode time at the cost of memory.
61+
vae_decode_chunk: 1
62+
63+
# Chunk size for VAE encode scan. (num_input_frames - 1) must be divisible by this value.
64+
# Increase to improve encode time at the cost of memory.
65+
vae_encode_chunk: 4
66+
vae_spatial: -1
5367

5468
# matmul and conv precision from https://jax.readthedocs.io/en/latest/jax.lax.html#jax.lax.Precision
5569
# Options are "DEFAULT", "HIGH", "HIGHEST"

src/maxdiffusion/configs/base_wan_animate.yml

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,22 @@ text_encoder_dtype: 'float32'
4747
# Whether to compile the text_encoder with torch.compile
4848
compile_text_encoder: False
4949

50+
# Maximum sequence length for the text encoder
51+
max_sequence_length: 512
52+
53+
vae_weights_dtype: 'float32'
54+
vae_dtype: 'float32'
55+
scheduler_dtype: 'float32'
56+
5057
# Replicates vae across devices instead of using the model's sharding annotations for sharding.
5158
replicate_vae: False
59+
60+
# Chunk size for VAE decode scan. Increase to improve decode time at the cost of memory.
61+
vae_decode_chunk: 1
62+
63+
# Chunk size for VAE encode scan. (num_input_frames - 1) must be divisible by this value.
64+
# Increase to improve encode time at the cost of memory.
65+
vae_encode_chunk: 4
5266
# Number of devices to shard VAE spatial activations across. -1 uses all devices.
5367
vae_spatial: -1
5468

src/maxdiffusion/configs/base_wan_i2v_14b.yml

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,23 @@ text_encoder_dtype: 'float32'
4747
# Whether to compile the text_encoder with torch.compile
4848
compile_text_encoder: False
4949

50+
# Maximum sequence length for the text encoder
51+
max_sequence_length: 512
52+
53+
vae_weights_dtype: 'float32'
54+
vae_dtype: 'float32'
55+
scheduler_dtype: 'float32'
56+
5057
# Replicates vae across devices instead of using the model's sharding annotations for sharding.
5158
replicate_vae: False
52-
vae_spatial: -1 # default to total_device * 2 // (dp)
59+
60+
# Chunk size for VAE decode scan. Increase to improve decode time at the cost of memory.
61+
vae_decode_chunk: 1
62+
63+
# Chunk size for VAE encode scan. (num_input_frames - 1) must be divisible by this value.
64+
# Increase to improve encode time at the cost of memory.
65+
vae_encode_chunk: 4
66+
vae_spatial: -1
5367

5468
# matmul and conv precision from https://jax.readthedocs.io/en/latest/jax.lax.html#jax.lax.Precision
5569
# Options are "DEFAULT", "HIGH", "HIGHEST"

src/maxdiffusion/configs/base_wan_i2v_27b.yml

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,23 @@ text_encoder_dtype: 'float32'
4747
# Whether to compile the text_encoder with torch.compile
4848
compile_text_encoder: False
4949

50+
# Maximum sequence length for the text encoder
51+
max_sequence_length: 512
52+
53+
vae_weights_dtype: 'float32'
54+
vae_dtype: 'float32'
55+
scheduler_dtype: 'float32'
56+
5057
# Replicates vae across devices instead of using the model's sharding annotations for sharding.
5158
replicate_vae: False
52-
vae_spatial: -1 # default to total_device * 2 // (dp)
59+
60+
# Chunk size for VAE decode scan. Increase to improve decode time at the cost of memory.
61+
vae_decode_chunk: 1
62+
63+
# Chunk size for VAE encode scan. (num_input_frames - 1) must be divisible by this value.
64+
# Increase to improve encode time at the cost of memory.
65+
vae_encode_chunk: 4
66+
vae_spatial: -1
5367

5468
# matmul and conv precision from https://jax.readthedocs.io/en/latest/jax.lax.html#jax.lax.Precision
5569
# Options are "DEFAULT", "HIGH", "HIGHEST"

src/maxdiffusion/generate_wan.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -315,11 +315,17 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
315315
f" Inference: {generation_time:>7.1f}s",
316316
]
317317
if trace:
318+
vae_decode_total = trace.get("vae_decode", 0.0)
319+
vae_decode_tpu = trace.get("vae_decode_tpu", 0.0)
320+
vae_decode_post = vae_decode_total - vae_decode_tpu
318321
summary.extend([
319322
f" {'─' * 40}",
320323
f" Conditioning: {trace.get('conditioning', 0.0):>7.1f}s",
324+
f" - VAE Encode: {trace.get('vae_encode', 0.0):>7.1f}s",
321325
f" Denoise Total: {trace.get('denoise_total', 0.0):>7.1f}s",
322-
f" VAE Decode: {trace.get('vae_decode', 0.0):>7.1f}s",
326+
f" VAE Decode: {vae_decode_total:>7.1f}s",
327+
f" - TPU Compute: {vae_decode_tpu:>7.1f}s",
328+
f" - Host Formatting: {vae_decode_post:>7.1f}s",
323329
])
324330
summary.append(f"{'=' * 50}")
325331
max_logging.log("\n".join(summary))

src/maxdiffusion/models/attention_flax.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,7 @@ def _tpu_flash_attention(
325325
) -> jax.Array:
326326
"""TPU Flash Attention"""
327327

328-
num_context_shards = mesh.shape["context"]
328+
num_context_shards = mesh.shape["context"] if "context" in mesh.shape else 1
329329
query, orig_q_seq_len = _reshape_data_for_flash(query, heads, num_context_shards)
330330
key, _ = _reshape_data_for_flash(key, heads, num_context_shards)
331331
value, _ = _reshape_data_for_flash(value, heads, num_context_shards)
@@ -491,7 +491,9 @@ def ring_scan_body(carry, _):
491491
raise ValueError("ring attention requires context > 1")
492492
return attention_output[:, :, :query_seq_len, :kv_size].astype(query.dtype)
493493

494-
devices_in_batch_sharding = mesh.shape["data"] * (mesh.shape["fsdp"] if "fsdp" in mesh.shape else 1)
494+
data_dim = mesh.shape["data"] if "data" in mesh.shape else 1
495+
fsdp_dim = mesh.shape["fsdp"] if "fsdp" in mesh.shape else 1
496+
devices_in_batch_sharding = data_dim * fsdp_dim
495497
# This warning might show up when doing model eval for example, when calculating model flops
496498
# and that is expected.
497499
if not (query.shape[0] / devices_in_batch_sharding).is_integer():

0 commit comments

Comments
 (0)