Skip to content

Commit 03af690

Browse files
authored
docs: improve docstring scheduling_dpmsolver_multistep_inverse.py (#13083)
Improve docstring scheduling dpmsolver multistep inverse
1 parent 90818e8 commit 03af690

8 files changed

+157
-48
lines changed

src/diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -545,7 +545,9 @@ def multistep_dpm_solver_second_order_update(
545545

546546
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
547547
def index_for_timestep(
548-
self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
548+
self,
549+
timestep: Union[int, torch.Tensor],
550+
schedule_timesteps: Optional[torch.Tensor] = None,
549551
) -> int:
550552
"""
551553
Find the index for a given timestep in the schedule.

src/diffusers/schedulers/scheduling_deis_multistep.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -867,7 +867,9 @@ def ind_fn(t, b, c, d):
867867

868868
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
869869
def index_for_timestep(
870-
self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
870+
self,
871+
timestep: Union[int, torch.Tensor],
872+
schedule_timesteps: Optional[torch.Tensor] = None,
871873
) -> int:
872874
"""
873875
Find the index for a given timestep in the schedule.

src/diffusers/schedulers/scheduling_dpmsolver_multistep.py

Lines changed: 53 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -245,21 +245,42 @@ def __init__(
245245
):
246246
if self.config.use_beta_sigmas and not is_scipy_available():
247247
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
248-
if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
248+
if (
249+
sum(
250+
[
251+
self.config.use_beta_sigmas,
252+
self.config.use_exponential_sigmas,
253+
self.config.use_karras_sigmas,
254+
]
255+
)
256+
> 1
257+
):
249258
raise ValueError(
250259
"Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
251260
)
252261
if algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
253262
deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead"
254-
deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0", deprecation_message)
263+
deprecate(
264+
"algorithm_types dpmsolver and sde-dpmsolver",
265+
"1.0.0",
266+
deprecation_message,
267+
)
255268

256269
if trained_betas is not None:
257270
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
258271
elif beta_schedule == "linear":
259272
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
260273
elif beta_schedule == "scaled_linear":
261274
# this schedule is very specific to the latent diffusion model.
262-
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
275+
self.betas = (
276+
torch.linspace(
277+
beta_start**0.5,
278+
beta_end**0.5,
279+
num_train_timesteps,
280+
dtype=torch.float32,
281+
)
282+
** 2
283+
)
263284
elif beta_schedule == "squaredcos_cap_v2":
264285
# Glide cosine schedule
265286
self.betas = betas_for_alpha_bar(num_train_timesteps)
@@ -287,7 +308,12 @@ def __init__(
287308
self.init_noise_sigma = 1.0
288309

289310
# settings for DPM-Solver
290-
if algorithm_type not in ["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"]:
311+
if algorithm_type not in [
312+
"dpmsolver",
313+
"dpmsolver++",
314+
"sde-dpmsolver",
315+
"sde-dpmsolver++",
316+
]:
291317
if algorithm_type == "deis":
292318
self.register_to_config(algorithm_type="dpmsolver++")
293319
else:
@@ -724,7 +750,7 @@ def convert_model_output(
724750
self,
725751
model_output: torch.Tensor,
726752
*args,
727-
sample: torch.Tensor = None,
753+
sample: Optional[torch.Tensor] = None,
728754
**kwargs,
729755
) -> torch.Tensor:
730756
"""
@@ -738,7 +764,7 @@ def convert_model_output(
738764
Args:
739765
model_output (`torch.Tensor`):
740766
The direct output from the learned diffusion model.
741-
sample (`torch.Tensor`):
767+
sample (`torch.Tensor`, *optional*):
742768
A current instance of a sample created by the diffusion process.
743769
744770
Returns:
@@ -822,7 +848,7 @@ def dpm_solver_first_order_update(
822848
self,
823849
model_output: torch.Tensor,
824850
*args,
825-
sample: torch.Tensor = None,
851+
sample: Optional[torch.Tensor] = None,
826852
noise: Optional[torch.Tensor] = None,
827853
**kwargs,
828854
) -> torch.Tensor:
@@ -832,8 +858,10 @@ def dpm_solver_first_order_update(
832858
Args:
833859
model_output (`torch.Tensor`):
834860
The direct output from the learned diffusion model.
835-
sample (`torch.Tensor`):
861+
sample (`torch.Tensor`, *optional*):
836862
A current instance of a sample created by the diffusion process.
863+
noise (`torch.Tensor`, *optional*):
864+
The noise tensor.
837865
838866
Returns:
839867
`torch.Tensor`:
@@ -860,7 +888,10 @@ def dpm_solver_first_order_update(
860888
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
861889
)
862890

863-
sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
891+
sigma_t, sigma_s = (
892+
self.sigmas[self.step_index + 1],
893+
self.sigmas[self.step_index],
894+
)
864895
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
865896
alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
866897
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
@@ -891,7 +922,7 @@ def multistep_dpm_solver_second_order_update(
891922
self,
892923
model_output_list: List[torch.Tensor],
893924
*args,
894-
sample: torch.Tensor = None,
925+
sample: Optional[torch.Tensor] = None,
895926
noise: Optional[torch.Tensor] = None,
896927
**kwargs,
897928
) -> torch.Tensor:
@@ -901,7 +932,7 @@ def multistep_dpm_solver_second_order_update(
901932
Args:
902933
model_output_list (`List[torch.Tensor]`):
903934
The direct outputs from learned diffusion model at current and latter timesteps.
904-
sample (`torch.Tensor`):
935+
sample (`torch.Tensor`, *optional*):
905936
A current instance of a sample created by the diffusion process.
906937
907938
Returns:
@@ -1014,7 +1045,7 @@ def multistep_dpm_solver_third_order_update(
10141045
self,
10151046
model_output_list: List[torch.Tensor],
10161047
*args,
1017-
sample: torch.Tensor = None,
1048+
sample: Optional[torch.Tensor] = None,
10181049
noise: Optional[torch.Tensor] = None,
10191050
**kwargs,
10201051
) -> torch.Tensor:
@@ -1024,8 +1055,10 @@ def multistep_dpm_solver_third_order_update(
10241055
Args:
10251056
model_output_list (`List[torch.Tensor]`):
10261057
The direct outputs from learned diffusion model at current and latter timesteps.
1027-
sample (`torch.Tensor`):
1058+
sample (`torch.Tensor`, *optional*):
10281059
A current instance of a sample created by diffusion process.
1060+
noise (`torch.Tensor`, *optional*):
1061+
The noise tensor.
10291062
10301063
Returns:
10311064
`torch.Tensor`:
@@ -1106,7 +1139,9 @@ def multistep_dpm_solver_third_order_update(
11061139
return x_t
11071140

11081141
def index_for_timestep(
1109-
self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
1142+
self,
1143+
timestep: Union[int, torch.Tensor],
1144+
schedule_timesteps: Optional[torch.Tensor] = None,
11101145
) -> int:
11111146
"""
11121147
Find the index for a given timestep in the schedule.
@@ -1216,7 +1251,10 @@ def step(
12161251
sample = sample.to(torch.float32)
12171252
if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"] and variance_noise is None:
12181253
noise = randn_tensor(
1219-
model_output.shape, generator=generator, device=model_output.device, dtype=torch.float32
1254+
model_output.shape,
1255+
generator=generator,
1256+
device=model_output.device,
1257+
dtype=torch.float32,
12201258
)
12211259
elif self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]:
12221260
noise = variance_noise.to(device=model_output.device, dtype=torch.float32)

0 commit comments

Comments
 (0)