Skip to content

Commit e1b6e7c

Browse files
Merge pull request #368 from syhuang22:sen_cache_I2V
PiperOrigin-RevId: 892415099
2 parents 1d5d773 + 0d96f28 commit e1b6e7c

5 files changed

Lines changed: 411 additions & 1 deletion

File tree

src/maxdiffusion/configs/base_wan_i2v_14b.yml

Lines changed: 3 additions & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -286,8 +286,10 @@ num_frames: 81
286286
guidance_scale: 5.0
287287
flow_shift: 5.0
288288

289-
# Diffusion CFG cache (FasterCache-style, WAN 2.1 T2V only)
289+
# Diffusion CFG cache (FasterCache-style)
290290
use_cfg_cache: False
291+
# SenCache: Sensitivity-Aware Caching (arXiv:2602.24208)
292+
use_sen_cache: False
291293
use_magcache: False
292294
magcache_thresh: 0.12
293295
magcache_K: 2

src/maxdiffusion/configs/base_wan_i2v_27b.yml

Lines changed: 2 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -300,6 +300,8 @@ boundary_ratio: 0.875
300300

301301
# Diffusion CFG cache (FasterCache-style)
302302
use_cfg_cache: False
303+
# SenCache: Sensitivity-Aware Caching (arXiv:2602.24208)
304+
use_sen_cache: False
303305

304306
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
305307
guidance_rescale: 0.0

src/maxdiffusion/generate_wan.py

Lines changed: 1 addition & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -117,6 +117,7 @@ def call_pipeline(config, pipeline, prompt, negative_prompt):
117117
guidance_scale_low=config.guidance_scale_low,
118118
guidance_scale_high=config.guidance_scale_high,
119119
use_cfg_cache=config.use_cfg_cache,
120+
use_sen_cache=config.use_sen_cache,
120121
)
121122
else:
122123
raise ValueError(f"Unsupported model_name for I2V in config: {model_key}")

src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py

Lines changed: 129 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -167,14 +167,25 @@ def __call__(
167167
output_type: Optional[str] = "np",
168168
rng: Optional[jax.Array] = None,
169169
use_cfg_cache: bool = False,
170+
use_sen_cache: bool = False,
170171
):
172+
if use_cfg_cache and use_sen_cache:
173+
raise ValueError("use_cfg_cache and use_sen_cache are mutually exclusive. Enable only one.")
174+
171175
if use_cfg_cache and (guidance_scale_low <= 1.0 or guidance_scale_high <= 1.0):
172176
raise ValueError(
173177
f"use_cfg_cache=True requires both guidance_scale_low > 1.0 and guidance_scale_high > 1.0 "
174178
f"(got {guidance_scale_low}, {guidance_scale_high}). "
175179
"CFG cache accelerates classifier-free guidance, which must be enabled for both transformer phases."
176180
)
177181

182+
if use_sen_cache and (guidance_scale_low <= 1.0 or guidance_scale_high <= 1.0):
183+
raise ValueError(
184+
f"use_sen_cache=True requires both guidance_scale_low > 1.0 and guidance_scale_high > 1.0 "
185+
f"(got {guidance_scale_low}, {guidance_scale_high}). "
186+
"SenCache requires classifier-free guidance to be enabled for both transformer phases."
187+
)
188+
178189
height = height or self.config.height
179190
width = width or self.config.width
180191
num_frames = num_frames or self.config.num_frames
@@ -264,6 +275,7 @@ def _process_image_input(img_input, height, width, num_videos_per_prompt):
264275
scheduler=self.scheduler,
265276
image_embeds=image_embeds,
266277
use_cfg_cache=use_cfg_cache,
278+
use_sen_cache=use_sen_cache,
267279
height=height,
268280
)
269281

@@ -308,11 +320,128 @@ def run_inference_2_2_i2v(
308320
scheduler: FlaxUniPCMultistepScheduler,
309321
scheduler_state,
310322
use_cfg_cache: bool = False,
323+
use_sen_cache: bool = False,
311324
height: int = 480,
312325
):
313326
do_classifier_free_guidance = guidance_scale_low > 1.0 or guidance_scale_high > 1.0
314327
bsz = latents.shape[0]
315328

329+
# ── SenCache path (arXiv:2602.24208) ──
330+
if use_sen_cache and do_classifier_free_guidance:
331+
timesteps_np = np.array(scheduler_state.timesteps, dtype=np.int32)
332+
step_uses_high = [bool(timesteps_np[s] >= boundary) for s in range(num_inference_steps)]
333+
334+
# SenCache hyperparameters
335+
sen_epsilon = 0.1
336+
max_reuse = 3
337+
warmup_steps = 1
338+
nocache_start_ratio = 0.3
339+
nocache_end_ratio = 0.1
340+
alpha_x, alpha_t = 1.0, 1.0
341+
342+
nocache_start = int(num_inference_steps * nocache_start_ratio)
343+
nocache_end_begin = int(num_inference_steps * (1.0 - nocache_end_ratio))
344+
num_train_timesteps = float(scheduler.config.num_train_timesteps)
345+
346+
prompt_embeds_combined = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0)
347+
if image_embeds is not None:
348+
image_embeds_combined = jnp.concatenate([image_embeds, image_embeds], axis=0)
349+
else:
350+
image_embeds_combined = None
351+
condition_doubled = jnp.concatenate([condition] * 2)
352+
353+
# SenCache state
354+
ref_noise_pred = None
355+
ref_latent = None
356+
ref_timestep = 0.0
357+
accum_dx = 0.0
358+
accum_dt = 0.0
359+
reuse_count = 0
360+
cache_count = 0
361+
362+
for step in range(num_inference_steps):
363+
t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
364+
t_float = float(timesteps_np[step]) / num_train_timesteps
365+
366+
if step_uses_high[step]:
367+
graphdef, state, rest = high_noise_graphdef, high_noise_state, high_noise_rest
368+
guidance_scale = guidance_scale_high
369+
else:
370+
graphdef, state, rest = low_noise_graphdef, low_noise_state, low_noise_rest
371+
guidance_scale = guidance_scale_low
372+
373+
is_boundary = step > 0 and step_uses_high[step] != step_uses_high[step - 1]
374+
force_compute = (
375+
step < warmup_steps or step < nocache_start or step >= nocache_end_begin or is_boundary or ref_noise_pred is None
376+
)
377+
378+
if force_compute:
379+
latents_doubled = jnp.concatenate([latents, latents], axis=0)
380+
latent_model_input = jnp.concatenate([latents_doubled, condition_doubled], axis=-1)
381+
latent_model_input = jnp.transpose(latent_model_input, (0, 4, 1, 2, 3))
382+
timestep = jnp.broadcast_to(t, bsz * 2)
383+
noise_pred, _, _ = transformer_forward_pass_full_cfg(
384+
graphdef,
385+
state,
386+
rest,
387+
latent_model_input,
388+
timestep,
389+
prompt_embeds_combined,
390+
guidance_scale=guidance_scale,
391+
encoder_hidden_states_image=image_embeds_combined,
392+
)
393+
noise_pred = jnp.transpose(noise_pred, (0, 2, 3, 4, 1))
394+
ref_noise_pred = noise_pred
395+
ref_latent = latents
396+
ref_timestep = t_float
397+
accum_dx = 0.0
398+
accum_dt = 0.0
399+
reuse_count = 0
400+
latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
401+
continue
402+
403+
dx_norm = float(jnp.sqrt(jnp.mean((latents - ref_latent) ** 2)))
404+
dt = abs(t_float - ref_timestep)
405+
accum_dx += dx_norm
406+
accum_dt += dt
407+
408+
score = alpha_x * accum_dx + alpha_t * accum_dt
409+
410+
if score <= sen_epsilon and reuse_count < max_reuse:
411+
noise_pred = ref_noise_pred
412+
reuse_count += 1
413+
cache_count += 1
414+
else:
415+
latents_doubled = jnp.concatenate([latents, latents], axis=0)
416+
latent_model_input = jnp.concatenate([latents_doubled, condition_doubled], axis=-1)
417+
latent_model_input = jnp.transpose(latent_model_input, (0, 4, 1, 2, 3))
418+
timestep = jnp.broadcast_to(t, bsz * 2)
419+
noise_pred, _, _ = transformer_forward_pass_full_cfg(
420+
graphdef,
421+
state,
422+
rest,
423+
latent_model_input,
424+
timestep,
425+
prompt_embeds_combined,
426+
guidance_scale=guidance_scale,
427+
encoder_hidden_states_image=image_embeds_combined,
428+
)
429+
noise_pred = jnp.transpose(noise_pred, (0, 2, 3, 4, 1))
430+
ref_noise_pred = noise_pred
431+
ref_latent = latents
432+
ref_timestep = t_float
433+
accum_dx = 0.0
434+
accum_dt = 0.0
435+
reuse_count = 0
436+
437+
latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
438+
439+
print(
440+
f"[SenCache] Cached {cache_count}/{num_inference_steps} steps "
441+
f"({100*cache_count/num_inference_steps:.1f}% cache ratio)"
442+
)
443+
return latents
444+
316445
# ── CFG cache path ──
317446
if use_cfg_cache and do_classifier_free_guidance:
318447
timesteps_np = np.array(scheduler_state.timesteps, dtype=np.int32)

0 commit comments

Comments (0)