@@ -226,6 +226,7 @@ def __init__(
226226 time_shift_type : Literal ["exponential" ] = "exponential" ,
227227 sigma_min : Optional [float ] = None ,
228228 sigma_max : Optional [float ] = None ,
229+ shift_terminal : Optional [float ] = None ,
229230 ) -> None :
230231 if self .config .use_beta_sigmas and not is_scipy_available ():
231232 raise ImportError ("Make sure to install scipy if you want to use beta sigmas." )
@@ -245,6 +246,8 @@ def __init__(
245246 self .betas = betas_for_alpha_bar (num_train_timesteps )
246247 else :
247248 raise NotImplementedError (f"{ beta_schedule } is not implemented for { self .__class__ } " )
249+ if shift_terminal is not None and not use_flow_sigmas :
250+ raise ValueError ("`shift_terminal` is only supported when `use_flow_sigmas=True`." )
248251
249252 if rescale_betas_zero_snr :
250253 self .betas = rescale_zero_terminal_snr (self .betas )
@@ -313,8 +316,12 @@ def set_begin_index(self, begin_index: int = 0) -> None:
313316 self ._begin_index = begin_index
314317
315318 def set_timesteps (
316- self , num_inference_steps : int , device : Optional [Union [str , torch .device ]] = None , mu : Optional [float ] = None
317- ) -> None :
319+ self ,
320+ num_inference_steps : Optional [int ] = None ,
321+ device : Union [str , torch .device ] = None ,
322+ sigmas : Optional [List [float ]] = None ,
323+ mu : Optional [float ] = None ,
324+ ):
318325 """
319326 Sets the discrete timesteps used for the diffusion chain (to be run before inference).
320327
@@ -323,13 +330,24 @@ def set_timesteps(
323330 The number of diffusion steps used when generating samples with a pre-trained model.
324331 device (`str` or `torch.device`, *optional*):
325332 The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
333+ sigmas (`List[float]`, *optional*):
334+ Custom values for sigmas to be used for each diffusion step. If `None`, the sigmas are computed
335+ automatically.
326336 mu (`float`, *optional*):
327337 Optional mu parameter for dynamic shifting when using exponential time shift type.
328338 """
339+ if self .config .use_dynamic_shifting and mu is None :
340+ raise ValueError ("`mu` must be passed when `use_dynamic_shifting` is set to be `True`" )
341+
342+ if sigmas is not None :
343+ if not self .config .use_flow_sigmas :
344+ raise ValueError (
345+ "Passing `sigmas` is only supported when `use_flow_sigmas=True`. "
346+ "Please set `use_flow_sigmas=True` during scheduler initialization."
347+ )
348+ num_inference_steps = len (sigmas )
349+
329350 # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891
330- if mu is not None :
331- assert self .config .use_dynamic_shifting and self .config .time_shift_type == "exponential"
332- self .config .flow_shift = np .exp (mu )
333351 if self .config .timestep_spacing == "linspace" :
334352 timesteps = (
335353 np .linspace (0 , self .config .num_train_timesteps - 1 , num_inference_steps + 1 )
@@ -354,8 +372,9 @@ def set_timesteps(
354372 f"{ self .config .timestep_spacing } is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
355373 )
356374
357- sigmas = np .array (((1 - self .alphas_cumprod ) / self .alphas_cumprod ) ** 0.5 )
358375 if self .config .use_karras_sigmas :
376+ if sigmas is None :
377+ sigmas = np .array (((1 - self .alphas_cumprod ) / self .alphas_cumprod ) ** 0.5 )
359378 log_sigmas = np .log (sigmas )
360379 sigmas = np .flip (sigmas ).copy ()
361380 sigmas = self ._convert_to_karras (in_sigmas = sigmas , num_inference_steps = num_inference_steps )
@@ -375,6 +394,8 @@ def set_timesteps(
375394 )
376395 sigmas = np .concatenate ([sigmas , [sigma_last ]]).astype (np .float32 )
377396 elif self .config .use_exponential_sigmas :
397+ if sigmas is None :
398+ sigmas = np .array (((1 - self .alphas_cumprod ) / self .alphas_cumprod ) ** 0.5 )
378399 log_sigmas = np .log (sigmas )
379400 sigmas = np .flip (sigmas ).copy ()
380401 sigmas = self ._convert_to_exponential (in_sigmas = sigmas , num_inference_steps = num_inference_steps )
@@ -389,6 +410,8 @@ def set_timesteps(
389410 )
390411 sigmas = np .concatenate ([sigmas , [sigma_last ]]).astype (np .float32 )
391412 elif self .config .use_beta_sigmas :
413+ if sigmas is None :
414+ sigmas = np .array (((1 - self .alphas_cumprod ) / self .alphas_cumprod ) ** 0.5 )
392415 log_sigmas = np .log (sigmas )
393416 sigmas = np .flip (sigmas ).copy ()
394417 sigmas = self ._convert_to_beta (in_sigmas = sigmas , num_inference_steps = num_inference_steps )
@@ -403,9 +426,18 @@ def set_timesteps(
403426 )
404427 sigmas = np .concatenate ([sigmas , [sigma_last ]]).astype (np .float32 )
405428 elif self .config .use_flow_sigmas :
406- alphas = np .linspace (1 , 1 / self .config .num_train_timesteps , num_inference_steps + 1 )
407- sigmas = 1.0 - alphas
408- sigmas = np .flip (self .config .flow_shift * sigmas / (1 + (self .config .flow_shift - 1 ) * sigmas ))[:- 1 ].copy ()
429+ if sigmas is None :
430+ sigmas = np .linspace (1 , 1 / self .config .num_train_timesteps , num_inference_steps + 1 )[:- 1 ]
431+ if self .config .use_dynamic_shifting :
432+ sigmas = self .time_shift (mu , 1.0 , sigmas )
433+ else :
434+ sigmas = self .config .flow_shift * sigmas / (1 + (self .config .flow_shift - 1 ) * sigmas )
435+ if self .config .shift_terminal :
436+ sigmas = self .stretch_shift_to_terminal (sigmas )
437+ eps = 1e-6
438+ if np .fabs (sigmas [0 ] - 1 ) < eps :
439+ # to avoid inf torch.log(alpha_si) in multistep_uni_p_bh_update during first/second update
440+ sigmas [0 ] -= eps
409441 timesteps = (sigmas * self .config .num_train_timesteps ).copy ()
410442 if self .config .final_sigmas_type == "sigma_min" :
411443 sigma_last = sigmas [- 1 ]
@@ -417,6 +449,8 @@ def set_timesteps(
417449 )
418450 sigmas = np .concatenate ([sigmas , [sigma_last ]]).astype (np .float32 )
419451 else :
452+ if sigmas is None :
453+ sigmas = np .array (((1 - self .alphas_cumprod ) / self .alphas_cumprod ) ** 0.5 )
420454 sigmas = np .interp (timesteps , np .arange (0 , len (sigmas )), sigmas )
421455 if self .config .final_sigmas_type == "sigma_min" :
422456 sigma_last = ((1 - self .alphas_cumprod [0 ]) / self .alphas_cumprod [0 ]) ** 0.5
@@ -446,6 +480,43 @@ def set_timesteps(
446480 self ._begin_index = None
447481 self .sigmas = self .sigmas .to ("cpu" ) # to avoid too much CPU/GPU communication
448482
# Adapted from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.time_shift,
# with an explicit error on unknown shift types instead of silently returning None.
def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
    """
    Apply the configured time-shift transform to a sigma/timestep schedule.

    Args:
        mu (`float`):
            Shift parameter used for dynamic shifting.
        sigma (`float`):
            Exponent applied to the odds ratio `1/t - 1` inside the shift.
        t (`torch.Tensor`):
            Schedule values in (0, 1] to be shifted.

    Returns:
        The shifted schedule, same shape as `t`.

    Raises:
        ValueError: If `self.config.time_shift_type` is neither "exponential" nor "linear".
    """
    shift_type = self.config.time_shift_type
    if shift_type == "exponential":
        return self._time_shift_exponential(mu, sigma, t)
    if shift_type == "linear":
        return self._time_shift_linear(mu, sigma, t)
    # Previously an unrecognized value fell through and implicitly returned None,
    # deferring the failure to a confusing downstream crash. Fail loudly here.
    raise ValueError(
        f"Unsupported `time_shift_type`: {shift_type!r}. Expected 'exponential' or 'linear'."
    )
489+
# Mirrors FlowMatchEulerDiscreteScheduler.stretch_shift_to_terminal from
# diffusers.schedulers.scheduling_flow_match_euler_discrete.
def stretch_shift_to_terminal(self, t: torch.Tensor) -> torch.Tensor:
    r"""
    Rescale a timestep schedule so its final entry lands exactly on the configured
    `shift_terminal` value.

    Reference:
    https://github.com/Lightricks/LTX-Video/blob/a01a171f8fe3d99dce2728d60a73fecf4d4238ae/ltx_video/schedulers/rf.py#L51

    Args:
        t (`torch.Tensor`):
            Timesteps to be stretched and shifted.

    Returns:
        `torch.Tensor`:
            Adjusted timesteps whose last value equals `self.config.shift_terminal`.
    """
    remaining = 1 - t
    # Choose the scale so that the last "remaining" amount maps onto (1 - shift_terminal),
    # i.e. the stretched schedule terminates exactly at shift_terminal.
    scale_factor = remaining[-1] / (1 - self.config.shift_terminal)
    return 1 - remaining / scale_factor
511+
# Mirrors FlowMatchEulerDiscreteScheduler._time_shift_exponential from
# diffusers.schedulers.scheduling_flow_match_euler_discrete.
def _time_shift_exponential(self, mu, sigma, t):
    """Exponential time shift: exp(mu) / (exp(mu) + (1/t - 1) ** sigma). `self` is unused."""
    scale = math.exp(mu)
    return scale / (scale + (1 / t - 1) ** sigma)
515+
# Mirrors FlowMatchEulerDiscreteScheduler._time_shift_linear from
# diffusers.schedulers.scheduling_flow_match_euler_discrete.
def _time_shift_linear(self, mu, sigma, t):
    """Linear time shift: mu / (mu + (1/t - 1) ** sigma). `self` is unused."""
    odds = (1 / t - 1) ** sigma
    return mu / (mu + odds)
519+
449520 # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
450521 def _threshold_sample (self , sample : torch .Tensor ) -> torch .Tensor :
451522 """
0 commit comments