Skip to content

Commit 88b1c2a

Browse files
committed
feat: add optional batched text encoder and diffusion loop
1 parent bb3b0c6 commit 88b1c2a

10 files changed

Lines changed: 252 additions & 15 deletions

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -348,6 +348,13 @@ flow_shift: 3.0
348348
# Skips the unconditional forward pass on ~35% of steps via residual compensation.
349349
# See: FasterCache (Lv et al. 2024), WAN 2.1 paper §4.4.2
350350
use_cfg_cache: False
351+
352+
# Batch positive and negative prompts in the text encoder to save compute.
353+
use_batched_text_encoder: False
354+
355+
# Use jax.lax.scan for the diffusion loop (non-cache path only).
356+
# Note: Enabling this will disable per-step profiling.
357+
scan_diffusion_loop: False
351358
use_magcache: False
352359
magcache_thresh: 0.12
353360
magcache_K: 2

src/maxdiffusion/configs/base_wan_1_3b.yml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,13 @@ flow_shift: 3.0
302302
# Diffusion CFG cache (FasterCache-style, WAN 2.1 T2V only)
303303
use_cfg_cache: False
304304

305+
# Batch positive and negative prompts in the text encoder to save compute.
306+
use_batched_text_encoder: False
307+
308+
# Use jax.lax.scan for the diffusion loop (non-cache path only).
309+
# Note: Enabling this will disable per-step profiling.
310+
scan_diffusion_loop: False
311+
305312
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
306313
guidance_rescale: 0.0
307314
num_inference_steps: 30

src/maxdiffusion/configs/base_wan_27b.yml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,13 @@ boundary_ratio: 0.875
323323

324324
# Diffusion CFG cache (FasterCache-style)
325325
use_cfg_cache: False
326+
327+
# Batch positive and negative prompts in the text encoder to save compute.
328+
use_batched_text_encoder: False
329+
330+
# Use jax.lax.scan for the diffusion loop (non-cache path only).
331+
# Note: Enabling this will disable per-step profiling.
332+
scan_diffusion_loop: False
326333
# SenCache: Sensitivity-Aware Caching (arXiv:2602.24208) — skip forward pass
327334
# when predicted output change (based on accumulated latent/timestep drift) is small
328335
use_sen_cache: False

src/maxdiffusion/configs/base_wan_i2v_14b.yml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,13 @@ flow_shift: 5.0
306306

307307
# Diffusion CFG cache (FasterCache-style)
308308
use_cfg_cache: False
309+
310+
# Batch positive and negative prompts in the text encoder to save compute.
311+
use_batched_text_encoder: False
312+
313+
# Use jax.lax.scan for the diffusion loop (non-cache path only).
314+
# Note: Enabling this will disable per-step profiling.
315+
scan_diffusion_loop: False
309316
# SenCache: Sensitivity-Aware Caching (arXiv:2602.24208)
310317
use_sen_cache: False
311318
use_magcache: False

src/maxdiffusion/configs/base_wan_i2v_27b.yml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,13 @@ boundary_ratio: 0.875
318318

319319
# Diffusion CFG cache (FasterCache-style)
320320
use_cfg_cache: False
321+
322+
# Batch positive and negative prompts in the text encoder to save compute.
323+
use_batched_text_encoder: False
324+
325+
# Use jax.lax.scan for the diffusion loop (non-cache path only).
326+
# Note: Enabling this will disable per-step profiling.
327+
scan_diffusion_loop: False
321328
# SenCache: Sensitivity-Aware Caching (arXiv:2602.24208)
322329
use_sen_cache: False
323330

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 36 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -501,24 +501,45 @@ def encode_prompt(
501501
negative_prompt_embeds: jax.Array = None,
502502
):
503503
prompt = [prompt] if isinstance(prompt, str) else prompt
504-
if prompt_embeds is None:
505-
prompt_embeds = self._get_t5_prompt_embeds(
506-
prompt=prompt,
507-
num_videos_per_prompt=num_videos_per_prompt,
508-
max_sequence_length=max_sequence_length,
509-
)
510-
prompt_embeds = jnp.array(prompt_embeds.detach().float().numpy(), dtype=jnp.float32)
511-
512-
if negative_prompt_embeds is None:
513-
batch_size = len(prompt_embeds)
514-
negative_prompt = negative_prompt or ""
515-
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
516-
negative_prompt_embeds = self._get_t5_prompt_embeds(
517-
prompt=negative_prompt,
504+
batch_size = len(prompt)
505+
506+
if negative_prompt is None:
507+
negative_prompt = [""] * batch_size
508+
elif isinstance(negative_prompt, str):
509+
negative_prompt = [negative_prompt] * batch_size
510+
511+
use_batched_text_encoder = getattr(self.config, "use_batched_text_encoder", False)
512+
if use_batched_text_encoder and prompt_embeds is None and negative_prompt_embeds is None:
513+
# Batch both together
514+
combined_prompts = prompt + negative_prompt
515+
combined_embeds = self._get_t5_prompt_embeds(
516+
prompt=combined_prompts,
518517
num_videos_per_prompt=num_videos_per_prompt,
519518
max_sequence_length=max_sequence_length,
520519
)
521-
negative_prompt_embeds = jnp.array(negative_prompt_embeds.detach().float().numpy(), dtype=jnp.float32)
520+
combined_embeds = jnp.array(combined_embeds.detach().float().numpy(), dtype=jnp.float32)
521+
522+
# Split back
523+
prompt_embeds = combined_embeds[: batch_size * num_videos_per_prompt]
524+
negative_prompt_embeds = combined_embeds[batch_size * num_videos_per_prompt :]
525+
526+
else:
527+
# Fallback to separate encoding if one of them is already provided
528+
if prompt_embeds is None:
529+
prompt_embeds = self._get_t5_prompt_embeds(
530+
prompt=prompt,
531+
num_videos_per_prompt=num_videos_per_prompt,
532+
max_sequence_length=max_sequence_length,
533+
)
534+
prompt_embeds = jnp.array(prompt_embeds.detach().float().numpy(), dtype=jnp.float32)
535+
536+
if negative_prompt_embeds is None:
537+
negative_prompt_embeds = self._get_t5_prompt_embeds(
538+
prompt=negative_prompt,
539+
num_videos_per_prompt=num_videos_per_prompt,
540+
max_sequence_length=max_sequence_length,
541+
)
542+
negative_prompt_embeds = jnp.array(negative_prompt_embeds.detach().float().numpy(), dtype=jnp.float32)
522543

523544
return prompt_embeds, negative_prompt_embeds
524545

src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,52 @@ def run_inference_2_1(
261261
profiler_steps = config.profiler_steps if config else 0
262262
last_profiling_step = np.clip(first_profiling_step + profiler_steps - 1, first_profiling_step, num_inference_steps - 1)
263263

264+
scan_diffusion_loop = getattr(config, "scan_diffusion_loop", False) if config else False
265+
266+
if scan_diffusion_loop and not use_magcache and not use_cfg_cache:
267+
timesteps = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)
268+
269+
scheduler_state = scheduler_state.replace(last_sample=jnp.zeros_like(latents), step_index=jnp.array(0, dtype=jnp.int32))
270+
271+
def scan_body(carry, t):
272+
current_latents, current_scheduler_state = carry
273+
274+
if do_cfg:
275+
latents_doubled = jnp.concatenate([current_latents] * 2)
276+
timestep = jnp.broadcast_to(t, bsz * 2)
277+
noise_pred, _, _ = transformer_forward_pass_full_cfg(
278+
graphdef,
279+
sharded_state,
280+
rest_of_state,
281+
latents_doubled,
282+
timestep,
283+
prompt_embeds_combined,
284+
guidance_scale=guidance_scale,
285+
)
286+
else:
287+
timestep = jnp.broadcast_to(t, bsz)
288+
noise_pred, _ = transformer_forward_pass(
289+
graphdef,
290+
sharded_state,
291+
rest_of_state,
292+
current_latents,
293+
timestep,
294+
prompt_cond_embeds,
295+
do_classifier_free_guidance=False,
296+
guidance_scale=guidance_scale,
297+
)
298+
299+
new_latents, new_scheduler_state = scheduler.step(current_scheduler_state, noise_pred, t, current_latents).to_tuple()
300+
301+
return (new_latents, new_scheduler_state), None
302+
303+
initial_carry = (latents, scheduler_state)
304+
305+
final_carry, _ = jax.lax.scan(scan_body, initial_carry, timesteps)
306+
307+
final_latents, _ = final_carry
308+
return final_latents
309+
264310
profiler = None
265311
for step in range(num_inference_steps):
266312
if config and max_utils.profiler_enabled(config) and step == first_profiling_step:

src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -471,6 +471,61 @@ def run_inference_2_2(
471471
profiler_steps = config.profiler_steps if config else 0
472472
last_profiling_step = np.clip(first_profiling_step + profiler_steps - 1, first_profiling_step, num_inference_steps - 1)
473473

474+
scan_diffusion_loop = getattr(config, "scan_diffusion_loop", False) if config else False
475+
476+
if scan_diffusion_loop:
477+
timesteps = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)
478+
479+
scheduler_state = scheduler_state.replace(last_sample=jnp.zeros_like(latents), step_index=jnp.array(0, dtype=jnp.int32))
480+
481+
def scan_body(carry, t):
482+
current_latents, current_scheduler_state = carry
483+
484+
if do_classifier_free_guidance:
485+
model_latents = jnp.concatenate([current_latents] * 2)
486+
else:
487+
model_latents = current_latents
488+
489+
timestep = jnp.broadcast_to(t, model_latents.shape[0])
490+
use_high_noise = jnp.greater_equal(t, boundary)
491+
492+
def high_branch(_):
493+
return transformer_forward_pass(
494+
high_noise_graphdef,
495+
high_noise_state,
496+
high_noise_rest,
497+
model_latents,
498+
timestep,
499+
prompt_embeds_combined,
500+
do_classifier_free_guidance,
501+
guidance_scale_high,
502+
)
503+
504+
def low_branch(_):
505+
return transformer_forward_pass(
506+
low_noise_graphdef,
507+
low_noise_state,
508+
low_noise_rest,
509+
model_latents,
510+
timestep,
511+
prompt_embeds_combined,
512+
do_classifier_free_guidance,
513+
guidance_scale_low,
514+
)
515+
516+
noise_pred, latents_out = jax.lax.cond(use_high_noise, high_branch, low_branch, operand=None)
517+
518+
new_latents, new_scheduler_state = scheduler.step(current_scheduler_state, noise_pred, t, latents_out).to_tuple()
519+
520+
return (new_latents, new_scheduler_state), None
521+
522+
initial_carry = (latents, scheduler_state)
523+
524+
final_carry, _ = jax.lax.scan(scan_body, initial_carry, timesteps)
525+
526+
final_latents, _ = final_carry
527+
return final_latents
528+
474529
profiler = None
475530
for step in range(num_inference_steps):
476531
if config and max_utils.profiler_enabled(config) and step == first_profiling_step:

src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,54 @@ def run_inference_2_1_i2v(
317317
profiler_steps = config.profiler_steps if config else 0
318318
last_profiling_step = np.clip(first_profiling_step + profiler_steps - 1, first_profiling_step, num_inference_steps - 1)
319319

320+
scan_diffusion_loop = getattr(config, "scan_diffusion_loop", False) if config else False
321+
322+
if scan_diffusion_loop and not use_magcache:
323+
timesteps = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)
324+
325+
scheduler_state = scheduler_state.replace(last_sample=jnp.zeros_like(latents), step_index=jnp.array(0, dtype=jnp.int32))
326+
327+
def scan_body(carry, t):
328+
current_latents, current_scheduler_state = carry
329+
330+
latents_input = current_latents
331+
if do_cfg:
332+
latents_input = jnp.concatenate([current_latents, current_latents], axis=0)
333+
334+
latent_model_input = jnp.concatenate([latents_input, condition_combined], axis=-1)
335+
timestep = jnp.broadcast_to(t, latents_input.shape[0])
336+
latent_model_input = jnp.transpose(latent_model_input, (0, 4, 1, 2, 3))
337+
338+
outputs = transformer_forward_pass(
339+
graphdef,
340+
sharded_state,
341+
rest_of_state,
342+
latent_model_input,
343+
timestep,
344+
prompt_embeds_combined,
345+
do_classifier_free_guidance=do_cfg,
346+
guidance_scale=guidance_scale,
347+
encoder_hidden_states_image=image_embeds_combined,
348+
skip_blocks=None,
349+
cached_residual=None,
350+
return_residual=False,
351+
)
352+
noise_pred, _ = outputs
353+
354+
noise_pred = jnp.transpose(noise_pred, (0, 2, 3, 4, 1))
355+
new_latents, new_scheduler_state = scheduler.step(
356+
current_scheduler_state, noise_pred, t, current_latents, return_dict=False
357+
)
358+
359+
return (new_latents, new_scheduler_state), None
360+
361+
initial_carry = (latents, scheduler_state)
362+
363+
final_carry, _ = jax.lax.scan(scan_body, initial_carry, timesteps)
364+
365+
final_latents, _ = final_carry
366+
return final_latents
367+
320368
profiler = None
321369
for step in range(num_inference_steps):
322370
if config and max_utils.profiler_enabled(config) and step == first_profiling_step:

src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -609,6 +609,38 @@ def low_noise_branch(operands):
609609
profiler_steps = config.profiler_steps if config else 0
610610
last_profiling_step = np.clip(first_profiling_step + profiler_steps - 1, first_profiling_step, num_inference_steps - 1)
611611

612+
scan_diffusion_loop = getattr(config, "scan_diffusion_loop", False) if config else False
613+
614+
if scan_diffusion_loop:
615+
timesteps = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)
616+
617+
scheduler_state = scheduler_state.replace(last_sample=jnp.zeros_like(latents), step_index=jnp.array(0, dtype=jnp.int32))
618+
619+
def scan_body(carry, t):
620+
current_latents, current_scheduler_state = carry
621+
622+
latents_input = current_latents
623+
if do_classifier_free_guidance:
624+
latents_input = jnp.concatenate([current_latents, current_latents], axis=0)
625+
latent_model_input = jnp.concatenate([latents_input, condition], axis=-1)
626+
timestep = jnp.broadcast_to(t, latents_input.shape[0])
627+
628+
use_high_noise = jnp.greater_equal(t, boundary)
629+
noise_pred, _ = jax.lax.cond(
630+
use_high_noise, high_noise_branch, low_noise_branch, (latent_model_input, timestep, prompt_embeds, image_embeds)
631+
)
632+
noise_pred = jnp.transpose(noise_pred, (0, 2, 3, 4, 1))
633+
new_latents, new_scheduler_state = scheduler.step(current_scheduler_state, noise_pred, t, current_latents).to_tuple()
634+
635+
return (new_latents, new_scheduler_state), None
636+
637+
initial_carry = (latents, scheduler_state)
638+
639+
final_carry, _ = jax.lax.scan(scan_body, initial_carry, timesteps)
640+
641+
final_latents, _ = final_carry
642+
return final_latents
643+
612644
profiler = None
613645
for step in range(num_inference_steps):
614646
if config and max_utils.profiler_enabled(config) and step == first_profiling_step:

0 commit comments

Comments
 (0)