@@ -105,7 +105,6 @@ def rescale_zero_terminal_snr(alphas_cumprod):
105105 """
106106 Rescales betas to have zero terminal SNR based on https://huggingface.co/papers/2305.08891 (Algorithm 1)
107107
108-
109108 Args:
110109 betas (`torch.Tensor`):
111110 the betas that the scheduler is being initialized with.
@@ -175,11 +174,14 @@ class CogVideoXDPMScheduler(SchedulerMixin, ConfigMixin):
175174 The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
176175 timestep_spacing (`str`, defaults to `"leading"`):
177176 The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
178- Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
177+ Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. Choose from
178+ `leading`, `linspace` or `trailing`.
179179 rescale_betas_zero_snr (`bool`, defaults to `False`):
180180 Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
181181 dark samples instead of limiting it to samples with medium brightness. Loosely related to
182182 [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
183+ snr_shift_scale (`float`, defaults to 3.0):
184+ Shift scale for SNR.
183185 """
184186
185187 _compatibles = [e .name for e in KarrasDiffusionSchedulers ]
@@ -191,15 +193,15 @@ def __init__(
191193 num_train_timesteps : int = 1000 ,
192194 beta_start : float = 0.00085 ,
193195 beta_end : float = 0.0120 ,
194- beta_schedule : str = "scaled_linear" ,
196+ beta_schedule : Literal [ "linear" , "scaled_linear" , "squaredcos_cap_v2" ] = "scaled_linear" ,
195197 trained_betas : Optional [Union [np .ndarray , List [float ]]] = None ,
196198 clip_sample : bool = True ,
197199 set_alpha_to_one : bool = True ,
198200 steps_offset : int = 0 ,
199- prediction_type : str = "epsilon" ,
201+ prediction_type : Literal [ "epsilon" , "sample" , "v_prediction" ] = "epsilon" ,
200202 clip_sample_range : float = 1.0 ,
201203 sample_max_value : float = 1.0 ,
202- timestep_spacing : str = "leading" ,
204+ timestep_spacing : Literal [ "leading" , "linspace" , "trailing" ] = "leading" ,
203205 rescale_betas_zero_snr : bool = False ,
204206 snr_shift_scale : float = 3.0 ,
205207 ):
@@ -209,7 +211,15 @@ def __init__(
209211 self .betas = torch .linspace (beta_start , beta_end , num_train_timesteps , dtype = torch .float32 )
210212 elif beta_schedule == "scaled_linear" :
211213 # this schedule is very specific to the latent diffusion model.
212- self .betas = torch .linspace (beta_start ** 0.5 , beta_end ** 0.5 , num_train_timesteps , dtype = torch .float64 ) ** 2
214+ self .betas = (
215+ torch .linspace (
216+ beta_start ** 0.5 ,
217+ beta_end ** 0.5 ,
218+ num_train_timesteps ,
219+ dtype = torch .float64 ,
220+ )
221+ ** 2
222+ )
213223 elif beta_schedule == "squaredcos_cap_v2" :
214224 # Glide cosine schedule
215225 self .betas = betas_for_alpha_bar (num_train_timesteps )
@@ -266,13 +276,20 @@ def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None
266276 """
267277 return sample
268278
269- def set_timesteps (self , num_inference_steps : int , device : Union [str , torch .device ] = None ):
279+ def set_timesteps (
280+ self ,
281+ num_inference_steps : int ,
282+ device : Optional [Union [str , torch .device ]] = None ,
283+ ):
270284 """
271285 Sets the discrete timesteps used for the diffusion chain (to be run before inference).
272286
273287 Args:
274288 num_inference_steps (`int`):
275289 The number of diffusion steps used when generating samples with a pre-trained model.
290+ device (`str` or `torch.device`, *optional*):
 291+ The device to which the timesteps should be moved. If `None` (the default), the timesteps are not
292+ moved.
276293 """
277294
278295 if num_inference_steps > self .config .num_train_timesteps :
@@ -311,7 +328,27 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
311328
312329 self .timesteps = torch .from_numpy (timesteps ).to (device )
313330
314- def get_variables (self , alpha_prod_t , alpha_prod_t_prev , alpha_prod_t_back = None ):
331+ def get_variables (
332+ self ,
333+ alpha_prod_t : torch .Tensor ,
334+ alpha_prod_t_prev : torch .Tensor ,
335+ alpha_prod_t_back : Optional [torch .Tensor ] = None ,
336+ ) -> Tuple [torch .Tensor , Optional [torch .Tensor ], torch .Tensor , torch .Tensor ]:
337+ """
338+ Compute the variables used for DPM-Solver++ (2M) referencing the original implementation.
339+
340+ Args:
341+ alpha_prod_t (`torch.Tensor`):
342+ The cumulative product of alphas at the current timestep.
343+ alpha_prod_t_prev (`torch.Tensor`):
344+ The cumulative product of alphas at the previous timestep.
345+ alpha_prod_t_back (`torch.Tensor`, *optional*):
346+ The cumulative product of alphas at the timestep before the previous timestep.
347+
348+ Returns:
349+ `tuple`:
 350+ A tuple `(h, r, lamb, lamb_next)`; `r` is `None` when `alpha_prod_t_back` is not provided.
351+ """
315352 lamb = ((alpha_prod_t / (1 - alpha_prod_t )) ** 0.5 ).log ()
316353 lamb_next = ((alpha_prod_t_prev / (1 - alpha_prod_t_prev )) ** 0.5 ).log ()
317354 h = lamb_next - lamb
@@ -324,7 +361,36 @@ def get_variables(self, alpha_prod_t, alpha_prod_t_prev, alpha_prod_t_back=None)
324361 else :
325362 return h , None , lamb , lamb_next
326363
327- def get_mult (self , h , r , alpha_prod_t , alpha_prod_t_prev , alpha_prod_t_back ):
364+ def get_mult (
365+ self ,
366+ h : torch .Tensor ,
367+ r : Optional [torch .Tensor ],
368+ alpha_prod_t : torch .Tensor ,
369+ alpha_prod_t_prev : torch .Tensor ,
370+ alpha_prod_t_back : Optional [torch .Tensor ] = None ,
371+ ) -> Union [
372+ Tuple [torch .Tensor , torch .Tensor ],
373+ Tuple [torch .Tensor , torch .Tensor , torch .Tensor , torch .Tensor ],
374+ ]:
375+ """
376+ Compute the multipliers for the previous sample and the predicted original sample.
377+
378+ Args:
379+ h (`torch.Tensor`):
380+ The log-SNR difference.
 381+ r (`torch.Tensor`, *optional*):
382+ The ratio of log-SNR differences.
383+ alpha_prod_t (`torch.Tensor`):
384+ The cumulative product of alphas at the current timestep.
385+ alpha_prod_t_prev (`torch.Tensor`):
386+ The cumulative product of alphas at the previous timestep.
387+ alpha_prod_t_back (`torch.Tensor`, *optional*):
388+ The cumulative product of alphas at the timestep before the previous timestep.
389+
390+ Returns:
391+ `tuple`:
 392+ A tuple of two multipliers, or four multipliers when `alpha_prod_t_back` is provided.
393+ """
328394 mult1 = ((1 - alpha_prod_t_prev ) / (1 - alpha_prod_t )) ** 0.5 * (- h ).exp ()
329395 mult2 = (- 2 * h ).expm1 () * alpha_prod_t_prev ** 0.5
330396
@@ -338,13 +404,13 @@ def get_mult(self, h, r, alpha_prod_t, alpha_prod_t_prev, alpha_prod_t_back):
338404 def step (
339405 self ,
340406 model_output : torch .Tensor ,
341- old_pred_original_sample : torch .Tensor ,
407+ old_pred_original_sample : Optional [ torch .Tensor ] ,
342408 timestep : int ,
343409 timestep_back : int ,
344410 sample : torch .Tensor ,
345411 eta : float = 0.0 ,
346412 use_clipped_model_output : bool = False ,
347- generator = None ,
413+ generator : Optional [ torch . Generator ] = None ,
348414 variance_noise : Optional [torch .Tensor ] = None ,
349415 return_dict : bool = False ,
350416 ) -> Union [DDIMSchedulerOutput , Tuple ]:
@@ -355,8 +421,12 @@ def step(
355421 Args:
356422 model_output (`torch.Tensor`):
357423 The direct output from learned diffusion model.
358- timestep (`float`):
 424+ old_pred_original_sample (`torch.Tensor`, *optional*):
425+ The predicted original sample from the previous timestep.
426+ timestep (`int`):
359427 The current discrete timestep in the diffusion chain.
428+ timestep_back (`int`):
429+ The timestep to look back to.
360430 sample (`torch.Tensor`):
361431 A current instance of a sample created by the diffusion process.
362432 eta (`float`):
@@ -436,7 +506,12 @@ def step(
436506 return prev_sample , pred_original_sample
437507 else :
438508 denoised_d = mult [2 ] * pred_original_sample - mult [3 ] * old_pred_original_sample
439- noise = randn_tensor (sample .shape , generator = generator , device = sample .device , dtype = sample .dtype )
509+ noise = randn_tensor (
510+ sample .shape ,
511+ generator = generator ,
512+ device = sample .device ,
513+ dtype = sample .dtype ,
514+ )
440515 x_advanced = mult [0 ] * sample - mult [1 ] * denoised_d + mult_noise * noise
441516
442517 prev_sample = x_advanced
@@ -524,5 +599,5 @@ def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: tor
524599 velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
525600 return velocity
526601
527- def __len__ (self ):
602+ def __len__ (self ) -> int :
528603 return self .config .num_train_timesteps
0 commit comments