1515# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver and https://github.com/NVlabs/edm
1616
1717import math
18- from typing import List , Optional , Tuple , Union
18+ from typing import List , Literal , Optional , Tuple , Union
1919
2020import numpy as np
2121import torch
@@ -51,13 +51,15 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
5151 schedule was incorporated in this model: https://huggingface.co/stabilityai/cosxl.
5252 num_train_timesteps (`int`, defaults to 1000):
5353 The number of diffusion steps to train the model.
54- solver_order (`int`, defaults to 2):
55- The DPMSolver order which can be `1` or `2` or `3`. It is recommended to use `solver_order=2` for guided
56- sampling, and `solver_order=3` for unconditional sampling.
5754 prediction_type (`str`, defaults to `epsilon`, *optional*):
5855 Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
5956 `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
6057 Video](https://huggingface.co/papers/2210.02303) paper).
58+ rho (`float`, *optional*, defaults to 7.0):
59+ The rho parameter in the Karras sigma schedule. This was set to 7.0 in the EDM paper [1].
60+ solver_order (`int`, defaults to 2):
61+ The DPMSolver order which can be `1` or `2` or `3`. It is recommended to use `solver_order=2` for guided
62+ sampling, and `solver_order=3` for unconditional sampling.
6163 thresholding (`bool`, defaults to `False`):
6264 Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
6365 as Stable Diffusion.
@@ -94,19 +96,19 @@ def __init__(
9496 sigma_min : float = 0.002 ,
9597 sigma_max : float = 80.0 ,
9698 sigma_data : float = 0.5 ,
97- sigma_schedule : str = "karras" ,
99+ sigma_schedule : Literal [ "karras" , "exponential" ] = "karras" ,
98100 num_train_timesteps : int = 1000 ,
99- prediction_type : str = "epsilon" ,
101+ prediction_type : Literal [ "epsilon" , "sample" , "v_prediction" ] = "epsilon" ,
100102 rho : float = 7.0 ,
101103 solver_order : int = 2 ,
102104 thresholding : bool = False ,
103105 dynamic_thresholding_ratio : float = 0.995 ,
104106 sample_max_value : float = 1.0 ,
105- algorithm_type : str = "dpmsolver++" ,
106- solver_type : str = "midpoint" ,
107+ algorithm_type : Literal [ "dpmsolver++" , "sde-dpmsolver++" ] = "dpmsolver++" ,
108+ solver_type : Literal [ "midpoint" , "heun" ] = "midpoint" ,
107109 lower_order_final : bool = True ,
108110 euler_at_final : bool = False ,
109- final_sigmas_type : Optional [str ] = "zero" , # "zero", "sigma_min"
111+ final_sigmas_type : Optional [Literal [ "zero" , "sigma_min" ] ] = "zero" , # "zero", "sigma_min"
110112 ):
111113 # settings for DPM-Solver
112114 if algorithm_type not in ["dpmsolver++" , "sde-dpmsolver++" ]:
@@ -145,19 +147,19 @@ def __init__(
145147 self .sigmas = self .sigmas .to ("cpu" ) # to avoid too much CPU/GPU communication
146148
147149 @property
148- def init_noise_sigma (self ):
150+ def init_noise_sigma (self ) -> float :
149151 # standard deviation of the initial noise distribution
150152 return (self .config .sigma_max ** 2 + 1 ) ** 0.5
151153
152154 @property
153- def step_index (self ):
155+ def step_index (self ) -> int :
154156 """
155157 The index counter for current timestep. It will increase 1 after each scheduler step.
156158 """
157159 return self ._step_index
158160
159161 @property
160- def begin_index (self ):
162+ def begin_index (self ) -> int :
161163 """
162164 The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
163165 """
@@ -274,7 +276,11 @@ def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.T
274276 self .is_scale_input_called = True
275277 return sample
276278
277- def set_timesteps (self , num_inference_steps : int = None , device : Union [str , torch .device ] = None ):
279+ def set_timesteps (
280+ self ,
281+ num_inference_steps : int = None ,
282+ device : Optional [Union [str , torch .device ]] = None ,
283+ ):
278284 """
279285 Sets the discrete timesteps used for the diffusion chain (to be run before inference).
280286
@@ -460,13 +466,12 @@ def _sigma_to_t(self, sigma, log_sigmas):
460466 def _sigma_to_alpha_sigma_t (self , sigma ):
461467 alpha_t = torch .tensor (1 ) # Inputs are pre-scaled before going into unet, so alpha_t = 1
462468 sigma_t = sigma
463-
464469 return alpha_t , sigma_t
465470
466471 def convert_model_output (
467472 self ,
468473 model_output : torch .Tensor ,
469- sample : torch .Tensor = None ,
474+ sample : torch .Tensor ,
470475 ) -> torch .Tensor :
471476 """
472477 Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is
@@ -497,7 +502,7 @@ def convert_model_output(
497502 def dpm_solver_first_order_update (
498503 self ,
499504 model_output : torch .Tensor ,
500- sample : torch .Tensor = None ,
505+ sample : torch .Tensor ,
501506 noise : Optional [torch .Tensor ] = None ,
502507 ) -> torch .Tensor :
503508 """
@@ -508,6 +513,8 @@ def dpm_solver_first_order_update(
508513 The direct output from the learned diffusion model.
509514 sample (`torch.Tensor`):
510515 A current instance of a sample created by the diffusion process.
516+ noise (`torch.Tensor`, *optional*):
517+ The noise tensor to add to the original samples.
511518
512519 Returns:
513520 `torch.Tensor`:
@@ -538,7 +545,7 @@ def dpm_solver_first_order_update(
538545 def multistep_dpm_solver_second_order_update (
539546 self ,
540547 model_output_list : List [torch .Tensor ],
541- sample : torch .Tensor = None ,
548+ sample : torch .Tensor ,
542549 noise : Optional [torch .Tensor ] = None ,
543550 ) -> torch .Tensor :
544551 """
@@ -549,6 +556,8 @@ def multistep_dpm_solver_second_order_update(
549556 The direct outputs from learned diffusion model at current and latter timesteps.
550557 sample (`torch.Tensor`):
551558 A current instance of a sample created by the diffusion process.
559+ noise (`torch.Tensor`, *optional*):
560+ The noise tensor to add to the original samples.
552561
553562 Returns:
554563 `torch.Tensor`:
@@ -609,7 +618,7 @@ def multistep_dpm_solver_second_order_update(
609618 def multistep_dpm_solver_third_order_update (
610619 self ,
611620 model_output_list : List [torch .Tensor ],
612- sample : torch .Tensor = None ,
621+ sample : torch .Tensor ,
613622 ) -> torch .Tensor :
614623 """
615624 One step for the third-order multistep DPMSolver.
@@ -698,7 +707,7 @@ def index_for_timestep(
698707 return step_index
699708
700709 # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
701- def _init_step_index (self , timestep ) :
710+ def _init_step_index (self , timestep : Union [ int , torch . Tensor ]) -> None :
702711 """
703712 Initialize the step_index counter for the scheduler.
704713
@@ -719,7 +728,7 @@ def step(
719728 model_output : torch .Tensor ,
720729 timestep : Union [int , torch .Tensor ],
721730 sample : torch .Tensor ,
722- generator = None ,
731+ generator : Optional [ torch . Generator ] = None ,
723732 return_dict : bool = True ,
724733 ) -> Union [SchedulerOutput , Tuple ]:
725734 """
@@ -860,5 +869,5 @@ def _get_conditioning_c_in(self, sigma: Union[float, torch.Tensor]) -> Union[flo
860869 c_in = 1 / ((sigma ** 2 + self .config .sigma_data ** 2 ) ** 0.5 )
861870 return c_in
862871
863- def __len__ (self ):
872+ def __len__ (self ) -> int :
864873 return self .config .num_train_timesteps
0 commit comments