
Commit d237477

feat: add optional batched text encoder and diffusion loop
1 parent bb3b0c6 commit d237477

11 files changed

Lines changed: 414 additions & 107 deletions

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 18 additions & 11 deletions
@@ -53,6 +53,9 @@ vae_spatial: -1 # default to total_device * 2 // (dp)
 precision: "DEFAULT"
 # Use jax.lax.scan for transformer layers
 scan_layers: True
+# Use jax.lax.scan for the diffusion loop (non-cache path only).
+# Note: Enabling this will disable per-step profiling.
+scan_diffusion_loop: False
 
 # if False state is not jitted and instead replicate is called. This is good for debugging on single host
 # It must be True for multi-host.
@@ -61,21 +64,21 @@ jit_initializers: True
 # Set true to load weights from pytorch
 from_pt: True
 split_head_dim: True
-attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses
+attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_custom
 use_base2_exp: True
 use_experimental_scheduler: True
-flash_min_seq_length: 0
+flash_min_seq_length: 4096
+dropout: 0.0
 
 # If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens.
 # Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster.
 # However, when padding tokens are significant, this will lead to worse quality and should be set to True.
-mask_padding_tokens: True
+mask_padding_tokens: True
 # Maxdiffusion has 2 types of attention sharding strategies:
 # 1. attention_sharding_uniform = True : same sequence sharding rules applied for q in both (self and cross attention)
 # 2. attention_sharding_uniform = False : Heads are sharded uniformly across devices for self attention while sequence is sharded
 # in cross attention q.
-attention_sharding_uniform: True
-dropout: 0.0
+attention_sharding_uniform: True
 
 flash_block_sizes: {
   "block_q" : 512,
@@ -202,9 +205,9 @@ data_sharding: [['data', 'fsdp', 'context', 'tensor']]
 # value to auto-shard based on available slices and devices.
 # By default, product of the DCN axes should equal number of slices
 # and product of the ICI axes should equal number of devices per slice.
-dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded
+dcn_data_parallelism: 1
 dcn_fsdp_parallelism: 1
-dcn_context_parallelism: -1
+dcn_context_parallelism: -1 # recommended DCN axis to be auto-sharded
 dcn_tensor_parallelism: 1
 ici_data_parallelism: 1
 ici_fsdp_parallelism: 1
@@ -338,16 +341,20 @@ prompt: "A cat and a dog baking a cake together in a kitchen. The cat is careful
 prompt_2: "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
 negative_prompt: "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
 do_classifier_free_guidance: True
-height: 480
-width: 832
+height: 720
+width: 1280
 num_frames: 81
 guidance_scale: 5.0
-flow_shift: 3.0
+flow_shift: 5.0
 
 # Diffusion CFG cache (FasterCache-style, WAN 2.1 T2V only)
 # Skips the unconditional forward pass on ~35% of steps via residual compensation.
 # See: FasterCache (Lv et al. 2024), WAN 2.1 paper §4.4.2
 use_cfg_cache: False
+
+# Batch positive and negative prompts in text encoder to save compute.
+use_batched_text_encoder: False
+
 use_magcache: False
 magcache_thresh: 0.12
 magcache_K: 2
@@ -356,7 +363,7 @@ mag_ratios_base: [1.0, 1.0, 1.02504, 1.03017, 1.00025, 1.00251, 0.9985, 0.99962,
 
 # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
 guidance_rescale: 0.0
-num_inference_steps: 30
+num_inference_steps: 50
 fps: 16
 save_final_checkpoint: False

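For context on the new scan_diffusion_loop flag: jax.lax.scan traces the loop body once and reuses it for every timestep, so the whole denoise loop compiles as one unit, but individual steps can no longer be profiled separately. Below is a minimal sketch of the idea under a simplified step signature — denoise_step, run_loop, and the shapes are illustrative placeholders, not maxdiffusion's actual API:

import jax
import jax.numpy as jnp

# Placeholder for the real transformer forward pass + scheduler update.
def denoise_step(latents, timestep):
  new_latents = latents * 0.99 + timestep * 1e-4
  return new_latents, None  # (carry, per-step output)

def run_loop(latents, timesteps, scan_diffusion_loop):
  if scan_diffusion_loop:
    # Single traced step reused for all iterations; no per-step profiling.
    latents, _ = jax.lax.scan(denoise_step, latents, timesteps)
  else:
    # Unrolled Python loop; each step remains visible to the profiler.
    for t in timesteps:
      latents, _ = denoise_step(latents, t)
  return latents

latents = jnp.zeros((1, 16, 4, 4))
timesteps = jnp.linspace(1000.0, 0.0, num=50)
video_latents = run_loop(latents, timesteps, scan_diffusion_loop=True)

Both paths compute the same values; the scan variant mainly cuts trace and compile time for long loops, at the cost of the per-step profiling noted in the config comment.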
src/maxdiffusion/configs/base_wan_1_3b.yml

Lines changed: 7 additions & 0 deletions
@@ -302,6 +302,13 @@ flow_shift: 3.0
 # Diffusion CFG cache (FasterCache-style, WAN 2.1 T2V only)
 use_cfg_cache: False
 
+# Batch positive and negative prompts in text encoder to save compute.
+use_batched_text_encoder: False
+
+# Use jax.lax.scan for the diffusion loop (non-cache path only).
+# Note: Enabling this will disable per-step profiling.
+scan_diffusion_loop: False
+
 # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
 guidance_rescale: 0.0
 num_inference_steps: 30

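The use_batched_text_encoder option amounts to stacking the positive and negative prompts along the batch axis so the text encoder runs once instead of twice for classifier-free guidance. A rough sketch under that assumption — encode_prompts_batched and the Flax-style apply call are hypothetical, not the actual maxdiffusion helpers:

import jax.numpy as jnp

def encode_prompts_batched(text_encoder, params, pos_ids, neg_ids):
  # One forward pass over both prompts instead of two separate passes.
  batched_ids = jnp.concatenate([pos_ids, neg_ids], axis=0)
  batched_embeds = text_encoder.apply(params, batched_ids)
  # Split back into the halves consumed by classifier-free guidance.
  prompt_embeds, negative_embeds = jnp.split(batched_embeds, 2, axis=0)
  return prompt_embeds, negative_embeds

The saving comes from amortizing dispatch and filling the device with one larger batch; both token tensors must share the same padded sequence length for the concatenation to be valid.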
src/maxdiffusion/configs/base_wan_27b.yml

Lines changed: 24 additions & 15 deletions
@@ -44,7 +44,7 @@ activations_dtype: 'bfloat16'
 
 # Replicates vae across devices instead of using the model's sharding annotations for sharding.
 replicate_vae: False
-vae_spatial: 1
+vae_spatial: -1 # default to total_device * 2 // (dp)
 
 # matmul and conv precision from https://jax.readthedocs.io/en/latest/jax.lax.html#jax.lax.Precision
 # Options are "DEFAULT", "HIGH", "HIGHEST"
@@ -53,6 +53,9 @@ vae_spatial: 1
 precision: "DEFAULT"
 # Use jax.lax.scan for transformer layers
 scan_layers: True
+# Use jax.lax.scan for the diffusion loop (non-cache path only).
+# Note: Enabling this will disable per-step profiling.
+scan_diffusion_loop: False
 
 # if False state is not jitted and instead replicate is called. This is good for debugging on single host
 # It must be True for multi-host.
@@ -61,20 +64,21 @@ jit_initializers: True
 # Set true to load weights from pytorch
 from_pt: True
 split_head_dim: True
-attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring, ulysses
+attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_custom
 use_base2_exp: True
 use_experimental_scheduler: True
 flash_min_seq_length: 4096
+dropout: 0.0
+
 # If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens.
 # Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster.
 # However, when padding tokens are significant, this will lead to worse quality and should be set to True.
-mask_padding_tokens: True
+mask_padding_tokens: True
 # Maxdiffusion has 2 types of attention sharding strategies:
 # 1. attention_sharding_uniform = True : same sequence sharding rules applied for q in both (self and cross attention)
 # 2. attention_sharding_uniform = False : Heads are sharded uniformly across devices for self attention while sequence is sharded
 # in cross attention q.
-attention_sharding_uniform: True
-dropout: 0.0
+attention_sharding_uniform: True
 
 flash_block_sizes: {
   "block_q" : 512,
@@ -159,7 +163,7 @@ mesh_axes: ['data', 'fsdp', 'context', 'tensor']
 logical_axis_rules: [
   ['batch', ['data', 'fsdp']],
   ['activation_batch', ['data', 'fsdp']],
-  ['activation_self_attn_heads', ['context', 'tensor']],
+  ['activation_self_attn_heads', ['context', 'tensor']],
   ['activation_cross_attn_q_length', ['context', 'tensor']],
   ['activation_length', 'context'],
   ['activation_heads', 'tensor'],
@@ -190,9 +194,9 @@ data_sharding: [['data', 'fsdp', 'context', 'tensor']]
 # value to auto-shard based on available slices and devices.
 # By default, product of the DCN axes should equal number of slices
 # and product of the ICI axes should equal number of devices per slice.
-dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded
+dcn_data_parallelism: 1
 dcn_fsdp_parallelism: 1
-dcn_context_parallelism: -1
+dcn_context_parallelism: -1 # recommended DCN axis to be auto-sharded
 dcn_tensor_parallelism: 1
 ici_data_parallelism: 1
 ici_fsdp_parallelism: 1
@@ -304,17 +308,17 @@ prompt: "A cat and a dog baking a cake together in a kitchen. The cat is careful
 prompt_2: "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
 negative_prompt: "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
 do_classifier_free_guidance: True
-height: 480
-width: 832
+height: 720
+width: 1280
 num_frames: 81
-flow_shift: 3.0
+flow_shift: 5.0
 
 # Reference for below guidance scale and boundary values: https://github.com/Wan-Video/Wan2.2/blob/main/wan/configs/wan_t2v_A14B.py
 # guidance scale factor for low noise transformer
-guidance_scale_low: 3.0
+guidance_scale_low: 3.0
 
 # guidance scale factor for high noise transformer
-guidance_scale_high: 4.0
+guidance_scale_high: 4.0
 
 # The timestep threshold. If `t` is at or above this value,
 # the `high_noise_model` is considered as the required model.
@@ -323,14 +327,19 @@ boundary_ratio: 0.875
 
 # Diffusion CFG cache (FasterCache-style)
 use_cfg_cache: False
+
+# Batch positive and negative prompts in text encoder to save compute.
+use_batched_text_encoder: False
+
+
 # SenCache: Sensitivity-Aware Caching (arXiv:2602.24208) — skip forward pass
 # when predicted output change (based on accumulated latent/timestep drift) is small
 use_sen_cache: False
 
 # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
 guidance_rescale: 0.0
-num_inference_steps: 30
-fps: 24
+num_inference_steps: 40
+fps: 16
 save_final_checkpoint: False
 
 # SDXL Lightning parameters

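On the dcn_*/ici_* changes above: a -1 axis is resolved so that, per the config comments, the DCN axes multiply out to the number of slices and the ICI axes to the number of devices per slice. A small sketch of that resolution rule — resolve_auto_axis is a hypothetical helper, not the function maxdiffusion actually uses:

import math

def resolve_auto_axis(axes, target_product):
  # Product of the explicitly set axes; the single -1 axis absorbs the rest.
  known = math.prod(v for v in axes.values() if v != -1)
  assert target_product % known == 0, "fixed axes must divide the device count"
  return {k: (target_product // known if v == -1 else v) for k, v in axes.items()}

# Example: one slice of 8 devices with context as the auto-sharded ICI axis.
ici = {"data": 1, "fsdp": 1, "context": -1, "tensor": 1}
print(resolve_auto_axis(ici, 8))
# {'data': 1, 'fsdp': 1, 'context': 8, 'tensor': 1}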
src/maxdiffusion/configs/base_wan_i2v_14b.yml

Lines changed: 20 additions & 7 deletions
@@ -44,6 +44,7 @@ activations_dtype: 'bfloat16'
 
 # Replicates vae across devices instead of using the model's sharding annotations for sharding.
 replicate_vae: False
+vae_spatial: -1 # default to total_device * 2 // (dp)
 
 # matmul and conv precision from https://jax.readthedocs.io/en/latest/jax.lax.html#jax.lax.Precision
 # Options are "DEFAULT", "HIGH", "HIGHEST"
@@ -52,6 +53,9 @@ replicate_vae: False
 precision: "DEFAULT"
 # Use jax.lax.scan for transformer layers
 scan_layers: True
+# Use jax.lax.scan for the diffusion loop (non-cache path only).
+# Note: Enabling this will disable per-step profiling.
+scan_diffusion_loop: False
 
 # if False state is not jitted and instead replicate is called. This is good for debugging on single host
 # It must be True for multi-host.
@@ -60,7 +64,7 @@ jit_initializers: True
 # Set true to load weights from pytorch
 from_pt: True
 split_head_dim: True
-attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring, ulysses
+attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_custom
 use_base2_exp: True
 use_experimental_scheduler: True
 flash_min_seq_length: 4096
@@ -69,7 +73,11 @@ dropout: 0.0
 # If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens.
 # Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster.
 # However, when padding tokens are significant, this will lead to worse quality and should be set to True.
-mask_padding_tokens: True
+mask_padding_tokens: True
+# Maxdiffusion has 2 types of attention sharding strategies:
+# 1. attention_sharding_uniform = True : same sequence sharding rules applied for q in both (self and cross attention)
+# 2. attention_sharding_uniform = False : Heads are sharded uniformly across devices for self attention while sequence is sharded
+# in cross attention q.
 attention_sharding_uniform: True
 
 flash_block_sizes: {
@@ -184,13 +192,13 @@ data_sharding: [['data', 'fsdp', 'context', 'tensor']]
 # value to auto-shard based on available slices and devices.
 # By default, product of the DCN axes should equal number of slices
 # and product of the ICI axes should equal number of devices per slice.
-dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded
-dcn_fsdp_parallelism: -1
-dcn_context_parallelism: 1
+dcn_data_parallelism: 1
+dcn_fsdp_parallelism: 1
+dcn_context_parallelism: -1 # recommended DCN axis to be auto-sharded
 dcn_tensor_parallelism: 1
 ici_data_parallelism: 1
-ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded
-ici_context_parallelism: 1
+ici_fsdp_parallelism: 1
+ici_context_parallelism: -1 # recommended ICI axis to be auto-sharded
 ici_tensor_parallelism: 1
 
 allow_split_physical_axes: False
@@ -306,6 +314,11 @@ flow_shift: 5.0
 
 # Diffusion CFG cache (FasterCache-style)
 use_cfg_cache: False
+
+# Batch positive and negative prompts in text encoder to save compute.
+use_batched_text_encoder: False
+
+
 # SenCache: Sensitivity-Aware Caching (arXiv:2602.24208)
 use_sen_cache: False
 use_magcache: False

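A quick worked reading of the vae_spatial default added above, using the formula from its inline comment (total_device and dp stand for the total device count and the data-parallel degree; the variable names in the code may differ):

# vae_spatial: -1 resolves to total_device * 2 // (dp), per the comment.
total_device, dp = 8, 2
vae_spatial = total_device * 2 // dp  # 8 * 2 // 2 = 8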
src/maxdiffusion/configs/base_wan_i2v_27b.yml

Lines changed: 21 additions & 8 deletions
@@ -44,6 +44,7 @@ activations_dtype: 'bfloat16'
 
 # Replicates vae across devices instead of using the model's sharding annotations for sharding.
 replicate_vae: False
+vae_spatial: -1 # default to total_device * 2 // (dp)
 
 # matmul and conv precision from https://jax.readthedocs.io/en/latest/jax.lax.html#jax.lax.Precision
 # Options are "DEFAULT", "HIGH", "HIGHEST"
@@ -52,6 +53,9 @@ replicate_vae: False
 precision: "DEFAULT"
 # Use jax.lax.scan for transformer layers
 scan_layers: True
+# Use jax.lax.scan for the diffusion loop (non-cache path only).
+# Note: Enabling this will disable per-step profiling.
+scan_diffusion_loop: False
 
 # if False state is not jitted and instead replicate is called. This is good for debugging on single host
 # It must be True for multi-host.
@@ -60,7 +64,7 @@ jit_initializers: True
 # Set true to load weights from pytorch
 from_pt: True
 split_head_dim: True
-attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring, ulysses
+attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_custom
 use_base2_exp: True
 use_experimental_scheduler: True
 flash_min_seq_length: 4096
@@ -69,7 +73,11 @@ dropout: 0.0
 # If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens.
 # Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster.
 # However, when padding tokens are significant, this will lead to worse quality and should be set to True.
-mask_padding_tokens: True
+mask_padding_tokens: True
+# Maxdiffusion has 2 types of attention sharding strategies:
+# 1. attention_sharding_uniform = True : same sequence sharding rules applied for q in both (self and cross attention)
+# 2. attention_sharding_uniform = False : Heads are sharded uniformly across devices for self attention while sequence is sharded
+# in cross attention q.
 attention_sharding_uniform: True
 
 flash_block_sizes: {
@@ -185,13 +193,13 @@ data_sharding: [['data', 'fsdp', 'context', 'tensor']]
 # value to auto-shard based on available slices and devices.
 # By default, product of the DCN axes should equal number of slices
 # and product of the ICI axes should equal number of devices per slice.
-dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded
-dcn_fsdp_parallelism: -1
-dcn_context_parallelism: 1
+dcn_data_parallelism: 1
+dcn_fsdp_parallelism: 1
+dcn_context_parallelism: -1 # recommended DCN axis to be auto-sharded
 dcn_tensor_parallelism: 1
 ici_data_parallelism: 1
-ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded
-ici_context_parallelism: 1
+ici_fsdp_parallelism: 1
+ici_context_parallelism: -1 # recommended ICI axis to be auto-sharded
 ici_tensor_parallelism: 1
 
 allow_split_physical_axes: False
@@ -318,12 +326,17 @@ boundary_ratio: 0.875
 
 # Diffusion CFG cache (FasterCache-style)
 use_cfg_cache: False
+
+# Batch positive and negative prompts in text encoder to save compute.
+use_batched_text_encoder: False
+
+
 # SenCache: Sensitivity-Aware Caching (arXiv:2602.24208)
 use_sen_cache: False
 
 # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
 guidance_rescale: 0.0
-num_inference_steps: 50
+num_inference_steps: 40
 fps: 16
 save_final_checkpoint: False

src/maxdiffusion/generate_wan.py

Lines changed: 4 additions & 2 deletions
@@ -304,12 +304,14 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
       f"{'=' * 50}",
       f" Load (checkpoint): {load_time:>7.1f}s",
       f" Compile: {compile_time:>7.1f}s",
-      f" {'─' * 40}",
       f" Inference: {generation_time:>7.1f}s",
   ]
   if trace:
     summary.extend([
-        f" Conditioning: {trace.get('conditioning', 0.0):>7.1f}s",
+        f" {'─' * 40}",
+        f" Text Encoder: {trace.get('text_encoder', 0.0):>7.1f}s",
+        f" Image Encoder: {trace.get('image_encoder', 0.0):>7.1f}s",
+        f" Latent Generation: {trace.get('latent_generation', 0.0):>7.1f}s",
         f" Denoise Total: {trace.get('denoise_total', 0.0):>7.1f}s",
         f" VAE Decode: {trace.get('vae_decode', 0.0):>7.1f}s",
     ])

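The reworked summary reads per-phase timings out of a trace dict keyed by phase name ('text_encoder', 'image_encoder', 'latent_generation', and so on). A minimal sketch of how such a dict can be populated — timed_phase is a hypothetical helper; the actual instrumentation in generate_wan.py may differ:

import time
from contextlib import contextmanager

trace = {}

@contextmanager
def timed_phase(name):
  # Accumulate wall-clock time for the named phase into the trace dict.
  start = time.perf_counter()
  try:
    yield
  finally:
    trace[name] = trace.get(name, 0.0) + time.perf_counter() - start

# Usage mirroring the phases reported in the summary:
with timed_phase("text_encoder"):
  time.sleep(0.01)  # stand-in for prompt encoding
with timed_phase("denoise_total"):
  time.sleep(0.02)  # stand-in for the diffusion loop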