Skip to content

Commit 64e2adf

Browse files
authored
docs: improve docstring scheduling_edm_dpmsolver_multistep.py (#13122)
Improve the docstrings in scheduling_edm_dpmsolver_multistep.py
1 parent c3a4cd1 commit 64e2adf

File tree

1 file changed

+30
-21
lines changed

1 file changed

+30
-21
lines changed

src/diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py

Lines changed: 30 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver and https://github.com/NVlabs/edm
1616

1717
import math
18-
from typing import List, Optional, Tuple, Union
18+
from typing import List, Literal, Optional, Tuple, Union
1919

2020
import numpy as np
2121
import torch
@@ -51,13 +51,15 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
5151
schedule was incorporated in this model: https://huggingface.co/stabilityai/cosxl.
5252
num_train_timesteps (`int`, defaults to 1000):
5353
The number of diffusion steps to train the model.
54-
solver_order (`int`, defaults to 2):
55-
The DPMSolver order which can be `1` or `2` or `3`. It is recommended to use `solver_order=2` for guided
56-
sampling, and `solver_order=3` for unconditional sampling.
5754
prediction_type (`str`, defaults to `epsilon`, *optional*):
5855
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
5956
`sample` (directly predicts the noisy sample) or `v_prediction` (see section 2.4 of [Imagen
6057
Video](https://huggingface.co/papers/2210.02303) paper).
58+
rho (`float`, *optional*, defaults to 7.0):
59+
The rho parameter in the Karras sigma schedule. This was set to 7.0 in the EDM paper [1].
60+
solver_order (`int`, defaults to 2):
61+
The DPMSolver order which can be `1` or `2` or `3`. It is recommended to use `solver_order=2` for guided
62+
sampling, and `solver_order=3` for unconditional sampling.
6163
thresholding (`bool`, defaults to `False`):
6264
Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
6365
as Stable Diffusion.
@@ -94,19 +96,19 @@ def __init__(
9496
sigma_min: float = 0.002,
9597
sigma_max: float = 80.0,
9698
sigma_data: float = 0.5,
97-
sigma_schedule: str = "karras",
99+
sigma_schedule: Literal["karras", "exponential"] = "karras",
98100
num_train_timesteps: int = 1000,
99-
prediction_type: str = "epsilon",
101+
prediction_type: Literal["epsilon", "sample", "v_prediction"] = "epsilon",
100102
rho: float = 7.0,
101103
solver_order: int = 2,
102104
thresholding: bool = False,
103105
dynamic_thresholding_ratio: float = 0.995,
104106
sample_max_value: float = 1.0,
105-
algorithm_type: str = "dpmsolver++",
106-
solver_type: str = "midpoint",
107+
algorithm_type: Literal["dpmsolver++", "sde-dpmsolver++"] = "dpmsolver++",
108+
solver_type: Literal["midpoint", "heun"] = "midpoint",
107109
lower_order_final: bool = True,
108110
euler_at_final: bool = False,
109-
final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min"
111+
final_sigmas_type: Optional[Literal["zero", "sigma_min"]] = "zero", # "zero", "sigma_min"
110112
):
111113
# settings for DPM-Solver
112114
if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++"]:
@@ -145,19 +147,19 @@ def __init__(
145147
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
146148

147149
@property
148-
def init_noise_sigma(self):
150+
def init_noise_sigma(self) -> float:
149151
# standard deviation of the initial noise distribution
150152
return (self.config.sigma_max**2 + 1) ** 0.5
151153

152154
@property
153-
def step_index(self):
155+
def step_index(self) -> int:
154156
"""
155157
The index counter for current timestep. It will increase 1 after each scheduler step.
156158
"""
157159
return self._step_index
158160

159161
@property
160-
def begin_index(self):
162+
def begin_index(self) -> int:
161163
"""
162164
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
163165
"""
@@ -274,7 +276,11 @@ def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.T
274276
self.is_scale_input_called = True
275277
return sample
276278

277-
def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None):
279+
def set_timesteps(
280+
self,
281+
num_inference_steps: int = None,
282+
device: Optional[Union[str, torch.device]] = None,
283+
):
278284
"""
279285
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
280286
@@ -460,13 +466,12 @@ def _sigma_to_t(self, sigma, log_sigmas):
460466
def _sigma_to_alpha_sigma_t(self, sigma):
461467
alpha_t = torch.tensor(1) # Inputs are pre-scaled before going into unet, so alpha_t = 1
462468
sigma_t = sigma
463-
464469
return alpha_t, sigma_t
465470

466471
def convert_model_output(
467472
self,
468473
model_output: torch.Tensor,
469-
sample: torch.Tensor = None,
474+
sample: torch.Tensor,
470475
) -> torch.Tensor:
471476
"""
472477
Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is
@@ -497,7 +502,7 @@ def convert_model_output(
497502
def dpm_solver_first_order_update(
498503
self,
499504
model_output: torch.Tensor,
500-
sample: torch.Tensor = None,
505+
sample: torch.Tensor,
501506
noise: Optional[torch.Tensor] = None,
502507
) -> torch.Tensor:
503508
"""
@@ -508,6 +513,8 @@ def dpm_solver_first_order_update(
508513
The direct output from the learned diffusion model.
509514
sample (`torch.Tensor`):
510515
A current instance of a sample created by the diffusion process.
516+
noise (`torch.Tensor`, *optional*):
517+
The noise tensor to add to the original samples.
511518
512519
Returns:
513520
`torch.Tensor`:
@@ -538,7 +545,7 @@ def dpm_solver_first_order_update(
538545
def multistep_dpm_solver_second_order_update(
539546
self,
540547
model_output_list: List[torch.Tensor],
541-
sample: torch.Tensor = None,
548+
sample: torch.Tensor,
542549
noise: Optional[torch.Tensor] = None,
543550
) -> torch.Tensor:
544551
"""
@@ -549,6 +556,8 @@ def multistep_dpm_solver_second_order_update(
549556
The direct outputs from learned diffusion model at current and latter timesteps.
550557
sample (`torch.Tensor`):
551558
A current instance of a sample created by the diffusion process.
559+
noise (`torch.Tensor`, *optional*):
560+
The noise tensor to add to the original samples.
552561
553562
Returns:
554563
`torch.Tensor`:
@@ -609,7 +618,7 @@ def multistep_dpm_solver_second_order_update(
609618
def multistep_dpm_solver_third_order_update(
610619
self,
611620
model_output_list: List[torch.Tensor],
612-
sample: torch.Tensor = None,
621+
sample: torch.Tensor,
613622
) -> torch.Tensor:
614623
"""
615624
One step for the third-order multistep DPMSolver.
@@ -698,7 +707,7 @@ def index_for_timestep(
698707
return step_index
699708

700709
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
701-
def _init_step_index(self, timestep):
710+
def _init_step_index(self, timestep: Union[int, torch.Tensor]) -> None:
702711
"""
703712
Initialize the step_index counter for the scheduler.
704713
@@ -719,7 +728,7 @@ def step(
719728
model_output: torch.Tensor,
720729
timestep: Union[int, torch.Tensor],
721730
sample: torch.Tensor,
722-
generator=None,
731+
generator: Optional[torch.Generator] = None,
723732
return_dict: bool = True,
724733
) -> Union[SchedulerOutput, Tuple]:
725734
"""
@@ -860,5 +869,5 @@ def _get_conditioning_c_in(self, sigma: Union[float, torch.Tensor]) -> Union[flo
860869
c_in = 1 / ((sigma**2 + self.config.sigma_data**2) ** 0.5)
861870
return c_in
862871

863-
def __len__(self):
872+
def __len__(self) -> int:
864873
return self.config.num_train_timesteps

0 commit comments

Comments
 (0)