Skip to content

Commit 867ae29

Browse files
committed
feat: add optional batched text encoder and diffusion loop
1 parent bb3b0c6 commit 867ae29

7 files changed

Lines changed: 126 additions & 15 deletions

File tree

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -348,6 +348,13 @@ flow_shift: 3.0
348348
# Skips the unconditional forward pass on ~35% of steps via residual compensation.
349349
# See: FasterCache (Lv et al. 2024), WAN 2.1 paper §4.4.2
350350
use_cfg_cache: False
351+
352+
# Batch positive and negative prompts in text encoder to save compute.
353+
use_batched_text_encoder: False
354+
355+
# Use jax.lax.scan for the diffusion loop (non-cache path only).
356+
# Note: Enabling this will disable per-step profiling.
357+
scan_diffusion_loop: False
351358
use_magcache: False
352359
magcache_thresh: 0.12
353360
magcache_K: 2

src/maxdiffusion/configs/base_wan_1_3b.yml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,13 @@ flow_shift: 3.0
302302
# Diffusion CFG cache (FasterCache-style, WAN 2.1 T2V only)
303303
use_cfg_cache: False
304304

305+
# Batch positive and negative prompts in text encoder to save compute.
306+
use_batched_text_encoder: False
307+
308+
# Use jax.lax.scan for the diffusion loop (non-cache path only).
309+
# Note: Enabling this will disable per-step profiling.
310+
scan_diffusion_loop: False
311+
305312
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
306313
guidance_rescale: 0.0
307314
num_inference_steps: 30

src/maxdiffusion/configs/base_wan_27b.yml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,13 @@ boundary_ratio: 0.875
323323

324324
# Diffusion CFG cache (FasterCache-style)
325325
use_cfg_cache: False
326+
327+
# Batch positive and negative prompts in text encoder to save compute.
328+
use_batched_text_encoder: False
329+
330+
# Use jax.lax.scan for the diffusion loop (non-cache path only).
331+
# Note: Enabling this will disable per-step profiling.
332+
scan_diffusion_loop: False
326333
# SenCache: Sensitivity-Aware Caching (arXiv:2602.24208) — skip forward pass
327334
# when predicted output change (based on accumulated latent/timestep drift) is small
328335
use_sen_cache: False

src/maxdiffusion/configs/base_wan_i2v_14b.yml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,13 @@ flow_shift: 5.0
306306

307307
# Diffusion CFG cache (FasterCache-style)
308308
use_cfg_cache: False
309+
310+
# Batch positive and negative prompts in text encoder to save compute.
311+
use_batched_text_encoder: False
312+
313+
# Use jax.lax.scan for the diffusion loop (non-cache path only).
314+
# Note: Enabling this will disable per-step profiling.
315+
scan_diffusion_loop: False
309316
# SenCache: Sensitivity-Aware Caching (arXiv:2602.24208)
310317
use_sen_cache: False
311318
use_magcache: False

src/maxdiffusion/configs/base_wan_i2v_27b.yml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,13 @@ boundary_ratio: 0.875
318318

319319
# Diffusion CFG cache (FasterCache-style)
320320
use_cfg_cache: False
321+
322+
# Batch positive and negative prompts in text encoder to save compute.
323+
use_batched_text_encoder: False
324+
325+
# Use jax.lax.scan for the diffusion loop (non-cache path only).
326+
# Note: Enabling this will disable per-step profiling.
327+
scan_diffusion_loop: False
321328
# SenCache: Sensitivity-Aware Caching (arXiv:2602.24208)
322329
use_sen_cache: False
323330

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 36 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -501,24 +501,45 @@ def encode_prompt(
501501
negative_prompt_embeds: jax.Array = None,
502502
):
503503
prompt = [prompt] if isinstance(prompt, str) else prompt
504-
if prompt_embeds is None:
505-
prompt_embeds = self._get_t5_prompt_embeds(
506-
prompt=prompt,
507-
num_videos_per_prompt=num_videos_per_prompt,
508-
max_sequence_length=max_sequence_length,
509-
)
510-
prompt_embeds = jnp.array(prompt_embeds.detach().float().numpy(), dtype=jnp.float32)
511-
512-
if negative_prompt_embeds is None:
513-
batch_size = len(prompt_embeds)
514-
negative_prompt = negative_prompt or ""
515-
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
516-
negative_prompt_embeds = self._get_t5_prompt_embeds(
517-
prompt=negative_prompt,
504+
batch_size = len(prompt)
505+
506+
if negative_prompt is None:
507+
negative_prompt = [""] * batch_size
508+
elif isinstance(negative_prompt, str):
509+
negative_prompt = [negative_prompt] * batch_size
510+
511+
use_batched_text_encoder = getattr(self.config, "use_batched_text_encoder", False)
512+
if use_batched_text_encoder and prompt_embeds is None and negative_prompt_embeds is None:
513+
# Batch both together
514+
combined_prompts = prompt + negative_prompt
515+
combined_embeds = self._get_t5_prompt_embeds(
516+
prompt=combined_prompts,
518517
num_videos_per_prompt=num_videos_per_prompt,
519518
max_sequence_length=max_sequence_length,
520519
)
521-
negative_prompt_embeds = jnp.array(negative_prompt_embeds.detach().float().numpy(), dtype=jnp.float32)
520+
combined_embeds = jnp.array(combined_embeds.detach().float().numpy(), dtype=jnp.float32)
521+
522+
# Split back
523+
prompt_embeds = combined_embeds[: batch_size * num_videos_per_prompt]
524+
negative_prompt_embeds = combined_embeds[batch_size * num_videos_per_prompt :]
525+
526+
else:
527+
# Fallback to separate encoding if one of them is already provided
528+
if prompt_embeds is None:
529+
prompt_embeds = self._get_t5_prompt_embeds(
530+
prompt=prompt,
531+
num_videos_per_prompt=num_videos_per_prompt,
532+
max_sequence_length=max_sequence_length,
533+
)
534+
prompt_embeds = jnp.array(prompt_embeds.detach().float().numpy(), dtype=jnp.float32)
535+
536+
if negative_prompt_embeds is None:
537+
negative_prompt_embeds = self._get_t5_prompt_embeds(
538+
prompt=negative_prompt,
539+
num_videos_per_prompt=num_videos_per_prompt,
540+
max_sequence_length=max_sequence_length,
541+
)
542+
negative_prompt_embeds = jnp.array(negative_prompt_embeds.detach().float().numpy(), dtype=jnp.float32)
522543

523544
return prompt_embeds, negative_prompt_embeds
524545

src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -471,6 +471,61 @@ def run_inference_2_2(
471471
profiler_steps = config.profiler_steps if config else 0
472472
last_profiling_step = np.clip(first_profiling_step + profiler_steps - 1, first_profiling_step, num_inference_steps - 1)
473473

474+
scan_diffusion_loop = getattr(config, "scan_diffusion_loop", False)
475+
476+
if scan_diffusion_loop:
477+
timesteps = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)
478+
479+
scheduler_state = scheduler_state.replace(last_sample=jnp.zeros_like(latents), step_index=jnp.array(0, dtype=jnp.int32))
480+
481+
def scan_body(carry, t):
482+
current_latents, current_scheduler_state = carry
483+
484+
if do_classifier_free_guidance:
485+
model_latents = jnp.concatenate([current_latents] * 2)
486+
else:
487+
model_latents = current_latents
488+
489+
timestep = jnp.broadcast_to(t, model_latents.shape[0])
490+
use_high_noise = jnp.greater_equal(t, boundary)
491+
492+
def high_branch(_):
493+
return transformer_forward_pass(
494+
high_noise_graphdef,
495+
high_noise_state,
496+
high_noise_rest,
497+
model_latents,
498+
timestep,
499+
prompt_embeds_combined,
500+
do_classifier_free_guidance,
501+
guidance_scale_high,
502+
)
503+
504+
def low_branch(_):
505+
return transformer_forward_pass(
506+
low_noise_graphdef,
507+
low_noise_state,
508+
low_noise_rest,
509+
model_latents,
510+
timestep,
511+
prompt_embeds_combined,
512+
do_classifier_free_guidance,
513+
guidance_scale_low,
514+
)
515+
516+
noise_pred, latents_out = jax.lax.cond(use_high_noise, high_branch, low_branch, operand=None)
517+
518+
new_latents, new_scheduler_state = scheduler.step(current_scheduler_state, noise_pred, t, latents_out).to_tuple()
519+
520+
return (new_latents, new_scheduler_state), None
521+
522+
initial_carry = (latents, scheduler_state)
523+
524+
final_carry, _ = jax.lax.scan(scan_body, initial_carry, timesteps)
525+
526+
final_latents, _ = final_carry
527+
return final_latents
528+
474529
profiler = None
475530
for step in range(num_inference_steps):
476531
if config and max_utils.profiler_enabled(config) and step == first_profiling_step:

0 commit comments

Comments
 (0)