@@ -167,14 +167,25 @@ def __call__(
167167 output_type : Optional [str ] = "np" ,
168168 rng : Optional [jax .Array ] = None ,
169169 use_cfg_cache : bool = False ,
170+ use_sen_cache : bool = False ,
170171 ):
172+ if use_cfg_cache and use_sen_cache :
173+ raise ValueError ("use_cfg_cache and use_sen_cache are mutually exclusive. Enable only one." )
174+
171175 if use_cfg_cache and (guidance_scale_low <= 1.0 or guidance_scale_high <= 1.0 ):
172176 raise ValueError (
173177 f"use_cfg_cache=True requires both guidance_scale_low > 1.0 and guidance_scale_high > 1.0 "
174178 f"(got { guidance_scale_low } , { guidance_scale_high } ). "
175179 "CFG cache accelerates classifier-free guidance, which must be enabled for both transformer phases."
176180 )
177181
182+ if use_sen_cache and (guidance_scale_low <= 1.0 or guidance_scale_high <= 1.0 ):
183+ raise ValueError (
184+ f"use_sen_cache=True requires both guidance_scale_low > 1.0 and guidance_scale_high > 1.0 "
185+ f"(got { guidance_scale_low } , { guidance_scale_high } ). "
186+ "SenCache requires classifier-free guidance to be enabled for both transformer phases."
187+ )
188+
178189 height = height or self .config .height
179190 width = width or self .config .width
180191 num_frames = num_frames or self .config .num_frames
@@ -264,6 +275,7 @@ def _process_image_input(img_input, height, width, num_videos_per_prompt):
264275 scheduler = self .scheduler ,
265276 image_embeds = image_embeds ,
266277 use_cfg_cache = use_cfg_cache ,
278+ use_sen_cache = use_sen_cache ,
267279 height = height ,
268280 )
269281
@@ -308,11 +320,128 @@ def run_inference_2_2_i2v(
308320 scheduler : FlaxUniPCMultistepScheduler ,
309321 scheduler_state ,
310322 use_cfg_cache : bool = False ,
323+ use_sen_cache : bool = False ,
311324 height : int = 480 ,
312325):
313326 do_classifier_free_guidance = guidance_scale_low > 1.0 or guidance_scale_high > 1.0
314327 bsz = latents .shape [0 ]
315328
329+ # ── SenCache path (arXiv:2602.24208) ──
330+ if use_sen_cache and do_classifier_free_guidance :
331+ timesteps_np = np .array (scheduler_state .timesteps , dtype = np .int32 )
332+ step_uses_high = [bool (timesteps_np [s ] >= boundary ) for s in range (num_inference_steps )]
333+
334+ # SenCache hyperparameters
335+ sen_epsilon = 0.1
336+ max_reuse = 3
337+ warmup_steps = 1
338+ nocache_start_ratio = 0.3
339+ nocache_end_ratio = 0.1
340+ alpha_x , alpha_t = 1.0 , 1.0
341+
342+ nocache_start = int (num_inference_steps * nocache_start_ratio )
343+ nocache_end_begin = int (num_inference_steps * (1.0 - nocache_end_ratio ))
344+ num_train_timesteps = float (scheduler .config .num_train_timesteps )
345+
346+ prompt_embeds_combined = jnp .concatenate ([prompt_embeds , negative_prompt_embeds ], axis = 0 )
347+ if image_embeds is not None :
348+ image_embeds_combined = jnp .concatenate ([image_embeds , image_embeds ], axis = 0 )
349+ else :
350+ image_embeds_combined = None
351+ condition_doubled = jnp .concatenate ([condition ] * 2 )
352+
353+ # SenCache state
354+ ref_noise_pred = None
355+ ref_latent = None
356+ ref_timestep = 0.0
357+ accum_dx = 0.0
358+ accum_dt = 0.0
359+ reuse_count = 0
360+ cache_count = 0
361+
362+ for step in range (num_inference_steps ):
363+ t = jnp .array (scheduler_state .timesteps , dtype = jnp .int32 )[step ]
364+ t_float = float (timesteps_np [step ]) / num_train_timesteps
365+
366+ if step_uses_high [step ]:
367+ graphdef , state , rest = high_noise_graphdef , high_noise_state , high_noise_rest
368+ guidance_scale = guidance_scale_high
369+ else :
370+ graphdef , state , rest = low_noise_graphdef , low_noise_state , low_noise_rest
371+ guidance_scale = guidance_scale_low
372+
373+ is_boundary = step > 0 and step_uses_high [step ] != step_uses_high [step - 1 ]
374+ force_compute = (
375+ step < warmup_steps or step < nocache_start or step >= nocache_end_begin or is_boundary or ref_noise_pred is None
376+ )
377+
378+ if force_compute :
379+ latents_doubled = jnp .concatenate ([latents , latents ], axis = 0 )
380+ latent_model_input = jnp .concatenate ([latents_doubled , condition_doubled ], axis = - 1 )
381+ latent_model_input = jnp .transpose (latent_model_input , (0 , 4 , 1 , 2 , 3 ))
382+ timestep = jnp .broadcast_to (t , bsz * 2 )
383+ noise_pred , _ , _ = transformer_forward_pass_full_cfg (
384+ graphdef ,
385+ state ,
386+ rest ,
387+ latent_model_input ,
388+ timestep ,
389+ prompt_embeds_combined ,
390+ guidance_scale = guidance_scale ,
391+ encoder_hidden_states_image = image_embeds_combined ,
392+ )
393+ noise_pred = jnp .transpose (noise_pred , (0 , 2 , 3 , 4 , 1 ))
394+ ref_noise_pred = noise_pred
395+ ref_latent = latents
396+ ref_timestep = t_float
397+ accum_dx = 0.0
398+ accum_dt = 0.0
399+ reuse_count = 0
400+ latents , scheduler_state = scheduler .step (scheduler_state , noise_pred , t , latents ).to_tuple ()
401+ continue
402+
403+ dx_norm = float (jnp .sqrt (jnp .mean ((latents - ref_latent ) ** 2 )))
404+ dt = abs (t_float - ref_timestep )
405+ accum_dx += dx_norm
406+ accum_dt += dt
407+
408+ score = alpha_x * accum_dx + alpha_t * accum_dt
409+
410+ if score <= sen_epsilon and reuse_count < max_reuse :
411+ noise_pred = ref_noise_pred
412+ reuse_count += 1
413+ cache_count += 1
414+ else :
415+ latents_doubled = jnp .concatenate ([latents , latents ], axis = 0 )
416+ latent_model_input = jnp .concatenate ([latents_doubled , condition_doubled ], axis = - 1 )
417+ latent_model_input = jnp .transpose (latent_model_input , (0 , 4 , 1 , 2 , 3 ))
418+ timestep = jnp .broadcast_to (t , bsz * 2 )
419+ noise_pred , _ , _ = transformer_forward_pass_full_cfg (
420+ graphdef ,
421+ state ,
422+ rest ,
423+ latent_model_input ,
424+ timestep ,
425+ prompt_embeds_combined ,
426+ guidance_scale = guidance_scale ,
427+ encoder_hidden_states_image = image_embeds_combined ,
428+ )
429+ noise_pred = jnp .transpose (noise_pred , (0 , 2 , 3 , 4 , 1 ))
430+ ref_noise_pred = noise_pred
431+ ref_latent = latents
432+ ref_timestep = t_float
433+ accum_dx = 0.0
434+ accum_dt = 0.0
435+ reuse_count = 0
436+
437+ latents , scheduler_state = scheduler .step (scheduler_state , noise_pred , t , latents ).to_tuple ()
438+
439+ print (
440+ f"[SenCache] Cached { cache_count } /{ num_inference_steps } steps "
441+ f"({ 100 * cache_count / num_inference_steps :.1f} % cache ratio)"
442+ )
443+ return latents
444+
316445 # ── CFG cache path ──
317446 if use_cfg_cache and do_classifier_free_guidance :
318447 timesteps_np = np .array (scheduler_state .timesteps , dtype = np .int32 )