@@ -486,6 +486,15 @@ def __call__(
486486 num_warmup_steps = max (len (timesteps ) - num_inference_steps * self .scheduler .order , 0 )
487487 self ._num_timesteps = len (timesteps )
488488
489+ # We set the scheduler's begin index here to avoid a device-to-host (DtoH) sync, which is especially helpful during compilation.
490+ # Check out more details here: https://github.com/huggingface/diffusers/pull/11696
491+ self .scheduler .set_begin_index (0 )
492+
493+ if self .do_classifier_free_guidance and self ._cfg_truncation is not None and float (self ._cfg_truncation ) <= 1 :
494+ _precomputed_t_norms = ((1000 - timesteps .float ()) / 1000 ).tolist ()
495+ else :
496+ _precomputed_t_norms = None
497+
489498 # 6. Denoising loop
490499 with self .progress_bar (total = num_inference_steps ) as progress_bar :
491500 for i , t in enumerate (timesteps ):
@@ -495,17 +504,9 @@ def __call__(
495504 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
496505 timestep = t .expand (latents .shape [0 ])
497506 timestep = (1000 - timestep ) / 1000
498- # Normalized time for time-aware config (0 at start, 1 at end)
499- t_norm = timestep [0 ].item ()
500-
501- # Handle cfg truncation
502507 current_guidance_scale = self .guidance_scale
503- if (
504- self .do_classifier_free_guidance
505- and self ._cfg_truncation is not None
506- and float (self ._cfg_truncation ) <= 1
507- ):
508- if t_norm > self ._cfg_truncation :
508+ if _precomputed_t_norms is not None :
509+ if _precomputed_t_norms [i ] > self ._cfg_truncation :
509510 current_guidance_scale = 0.0
510511
511512 # Run CFG only if configured AND scale is non-zero
0 commit comments