@@ -245,21 +245,42 @@ def __init__(
245245 ):
246246 if self .config .use_beta_sigmas and not is_scipy_available ():
247247 raise ImportError ("Make sure to install scipy if you want to use beta sigmas." )
248- if sum ([self .config .use_beta_sigmas , self .config .use_exponential_sigmas , self .config .use_karras_sigmas ]) > 1 :
248+ if (
249+ sum (
250+ [
251+ self .config .use_beta_sigmas ,
252+ self .config .use_exponential_sigmas ,
253+ self .config .use_karras_sigmas ,
254+ ]
255+ )
256+ > 1
257+ ):
249258 raise ValueError (
250259 "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
251260 )
252261 if algorithm_type in ["dpmsolver" , "sde-dpmsolver" ]:
253262 deprecation_message = f"algorithm_type { algorithm_type } is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead"
254- deprecate ("algorithm_types dpmsolver and sde-dpmsolver" , "1.0.0" , deprecation_message )
263+ deprecate (
264+ "algorithm_types dpmsolver and sde-dpmsolver" ,
265+ "1.0.0" ,
266+ deprecation_message ,
267+ )
255268
256269 if trained_betas is not None :
257270 self .betas = torch .tensor (trained_betas , dtype = torch .float32 )
258271 elif beta_schedule == "linear" :
259272 self .betas = torch .linspace (beta_start , beta_end , num_train_timesteps , dtype = torch .float32 )
260273 elif beta_schedule == "scaled_linear" :
261274 # this schedule is very specific to the latent diffusion model.
262- self .betas = torch .linspace (beta_start ** 0.5 , beta_end ** 0.5 , num_train_timesteps , dtype = torch .float32 ) ** 2
275+ self .betas = (
276+ torch .linspace (
277+ beta_start ** 0.5 ,
278+ beta_end ** 0.5 ,
279+ num_train_timesteps ,
280+ dtype = torch .float32 ,
281+ )
282+ ** 2
283+ )
263284 elif beta_schedule == "squaredcos_cap_v2" :
264285 # Glide cosine schedule
265286 self .betas = betas_for_alpha_bar (num_train_timesteps )
@@ -287,7 +308,12 @@ def __init__(
287308 self .init_noise_sigma = 1.0
288309
289310 # settings for DPM-Solver
290- if algorithm_type not in ["dpmsolver" , "dpmsolver++" , "sde-dpmsolver" , "sde-dpmsolver++" ]:
311+ if algorithm_type not in [
312+ "dpmsolver" ,
313+ "dpmsolver++" ,
314+ "sde-dpmsolver" ,
315+ "sde-dpmsolver++" ,
316+ ]:
291317 if algorithm_type == "deis" :
292318 self .register_to_config (algorithm_type = "dpmsolver++" )
293319 else :
@@ -724,7 +750,7 @@ def convert_model_output(
724750 self ,
725751 model_output : torch .Tensor ,
726752 * args ,
727- sample : torch .Tensor = None ,
753+ sample : Optional [ torch .Tensor ] = None ,
728754 ** kwargs ,
729755 ) -> torch .Tensor :
730756 """
@@ -738,7 +764,7 @@ def convert_model_output(
738764 Args:
739765 model_output (`torch.Tensor`):
740766 The direct output from the learned diffusion model.
741- sample (`torch.Tensor`):
767+ sample (`torch.Tensor`, *optional* ):
742768 A current instance of a sample created by the diffusion process.
743769
744770 Returns:
@@ -822,7 +848,7 @@ def dpm_solver_first_order_update(
822848 self ,
823849 model_output : torch .Tensor ,
824850 * args ,
825- sample : torch .Tensor = None ,
851+ sample : Optional [ torch .Tensor ] = None ,
826852 noise : Optional [torch .Tensor ] = None ,
827853 ** kwargs ,
828854 ) -> torch .Tensor :
@@ -832,8 +858,10 @@ def dpm_solver_first_order_update(
832858 Args:
833859 model_output (`torch.Tensor`):
834860 The direct output from the learned diffusion model.
835- sample (`torch.Tensor`):
861+ sample (`torch.Tensor`, *optional* ):
836862 A current instance of a sample created by the diffusion process.
863+ noise (`torch.Tensor`, *optional*):
864+ The noise tensor.
837865
838866 Returns:
839867 `torch.Tensor`:
@@ -860,7 +888,10 @@ def dpm_solver_first_order_update(
860888 "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`" ,
861889 )
862890
863- sigma_t , sigma_s = self .sigmas [self .step_index + 1 ], self .sigmas [self .step_index ]
891+ sigma_t , sigma_s = (
892+ self .sigmas [self .step_index + 1 ],
893+ self .sigmas [self .step_index ],
894+ )
864895 alpha_t , sigma_t = self ._sigma_to_alpha_sigma_t (sigma_t )
865896 alpha_s , sigma_s = self ._sigma_to_alpha_sigma_t (sigma_s )
866897 lambda_t = torch .log (alpha_t ) - torch .log (sigma_t )
@@ -891,7 +922,7 @@ def multistep_dpm_solver_second_order_update(
891922 self ,
892923 model_output_list : List [torch .Tensor ],
893924 * args ,
894- sample : torch .Tensor = None ,
925+ sample : Optional [ torch .Tensor ] = None ,
895926 noise : Optional [torch .Tensor ] = None ,
896927 ** kwargs ,
897928 ) -> torch .Tensor :
@@ -901,7 +932,7 @@ def multistep_dpm_solver_second_order_update(
901932 Args:
902933 model_output_list (`List[torch.Tensor]`):
903934 The direct outputs from learned diffusion model at current and latter timesteps.
904- sample (`torch.Tensor`):
935+ sample (`torch.Tensor`, *optional* ):
905936 A current instance of a sample created by the diffusion process.
906937
907938 Returns:
@@ -1014,7 +1045,7 @@ def multistep_dpm_solver_third_order_update(
10141045 self ,
10151046 model_output_list : List [torch .Tensor ],
10161047 * args ,
1017- sample : torch .Tensor = None ,
1048+ sample : Optional [ torch .Tensor ] = None ,
10181049 noise : Optional [torch .Tensor ] = None ,
10191050 ** kwargs ,
10201051 ) -> torch .Tensor :
@@ -1024,8 +1055,10 @@ def multistep_dpm_solver_third_order_update(
10241055 Args:
10251056 model_output_list (`List[torch.Tensor]`):
10261057 The direct outputs from learned diffusion model at current and latter timesteps.
1027- sample (`torch.Tensor`):
1058+ sample (`torch.Tensor`, *optional* ):
10281059 A current instance of a sample created by diffusion process.
1060+ noise (`torch.Tensor`, *optional*):
1061+ The noise tensor.
10291062
10301063 Returns:
10311064 `torch.Tensor`:
@@ -1106,7 +1139,9 @@ def multistep_dpm_solver_third_order_update(
11061139 return x_t
11071140
11081141 def index_for_timestep (
1109- self , timestep : Union [int , torch .Tensor ], schedule_timesteps : Optional [torch .Tensor ] = None
1142+ self ,
1143+ timestep : Union [int , torch .Tensor ],
1144+ schedule_timesteps : Optional [torch .Tensor ] = None ,
11101145 ) -> int :
11111146 """
11121147 Find the index for a given timestep in the schedule.
@@ -1216,7 +1251,10 @@ def step(
12161251 sample = sample .to (torch .float32 )
12171252 if self .config .algorithm_type in ["sde-dpmsolver" , "sde-dpmsolver++" ] and variance_noise is None :
12181253 noise = randn_tensor (
1219- model_output .shape , generator = generator , device = model_output .device , dtype = torch .float32
1254+ model_output .shape ,
1255+ generator = generator ,
1256+ device = model_output .device ,
1257+ dtype = torch .float32 ,
12201258 )
12211259 elif self .config .algorithm_type in ["sde-dpmsolver" , "sde-dpmsolver++" ]:
12221260 noise = variance_noise .to (device = model_output .device , dtype = torch .float32 )
0 commit comments