Skip to content

Commit 06a0f98

Browse files
authored
docs: improve docstring scheduling_flow_match_euler_discrete.py (#13127)
Improve docstring scheduling flow match euler discrete
1 parent d324839 commit 06a0f98

File tree

2 files changed

+78
-13
lines changed

2 files changed

+78
-13
lines changed

src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py

Lines changed: 63 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
import math
1616
from dataclasses import dataclass
17-
from typing import List, Optional, Tuple, Union
17+
from typing import List, Literal, Optional, Tuple, Union
1818

1919
import numpy as np
2020
import torch
@@ -102,12 +102,21 @@ def __init__(
102102
use_karras_sigmas: Optional[bool] = False,
103103
use_exponential_sigmas: Optional[bool] = False,
104104
use_beta_sigmas: Optional[bool] = False,
105-
time_shift_type: str = "exponential",
105+
time_shift_type: Literal["exponential", "linear"] = "exponential",
106106
stochastic_sampling: bool = False,
107107
):
108108
if self.config.use_beta_sigmas and not is_scipy_available():
109109
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
110-
if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
110+
if (
111+
sum(
112+
[
113+
self.config.use_beta_sigmas,
114+
self.config.use_exponential_sigmas,
115+
self.config.use_karras_sigmas,
116+
]
117+
)
118+
> 1
119+
):
111120
raise ValueError(
112121
"Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
113122
)
@@ -166,6 +175,13 @@ def set_begin_index(self, begin_index: int = 0):
166175
self._begin_index = begin_index
167176

168177
def set_shift(self, shift: float):
178+
"""
179+
Sets the shift value for the scheduler.
180+
181+
Args:
182+
shift (`float`):
183+
The shift value to be set.
184+
"""
169185
self._shift = shift
170186

171187
def scale_noise(
@@ -218,10 +234,25 @@ def scale_noise(
218234

219235
return sample
220236

221-
def _sigma_to_t(self, sigma):
237+
def _sigma_to_t(self, sigma) -> float:
222238
return sigma * self.config.num_train_timesteps
223239

224-
def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
240+
def time_shift(self, mu: float, sigma: float, t: torch.Tensor) -> torch.Tensor:
241+
"""
242+
Apply time shifting to the sigmas.
243+
244+
Args:
245+
mu (`float`):
246+
The mu parameter for the time shift.
247+
sigma (`float`):
248+
The sigma parameter for the time shift.
249+
t (`torch.Tensor`):
250+
The input timesteps.
251+
252+
Returns:
253+
`torch.Tensor`:
254+
The time-shifted timesteps.
255+
"""
225256
if self.config.time_shift_type == "exponential":
226257
return self._time_shift_exponential(mu, sigma, t)
227258
elif self.config.time_shift_type == "linear":
@@ -302,7 +333,9 @@ def set_timesteps(
302333
if sigmas is None:
303334
if timesteps is None:
304335
timesteps = np.linspace(
305-
self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps
336+
self._sigma_to_t(self.sigma_max),
337+
self._sigma_to_t(self.sigma_min),
338+
num_inference_steps,
306339
)
307340
sigmas = timesteps / self.config.num_train_timesteps
308341
else:
@@ -350,7 +383,24 @@ def set_timesteps(
350383
self._step_index = None
351384
self._begin_index = None
352385

353-
def index_for_timestep(self, timestep, schedule_timesteps=None):
386+
def index_for_timestep(
387+
self,
388+
timestep: Union[float, torch.FloatTensor],
389+
schedule_timesteps: Optional[torch.FloatTensor] = None,
390+
) -> int:
391+
"""
392+
Get the index for the given timestep.
393+
394+
Args:
395+
timestep (`float` or `torch.FloatTensor`):
396+
The timestep to find the index for.
397+
schedule_timesteps (`torch.FloatTensor`, *optional*):
398+
The schedule timesteps to validate against. If `None`, the scheduler's timesteps are used.
399+
400+
Returns:
401+
`int`:
402+
The index of the timestep.
403+
"""
354404
if schedule_timesteps is None:
355405
schedule_timesteps = self.timesteps
356406

@@ -364,7 +414,7 @@ def index_for_timestep(self, timestep, schedule_timesteps=None):
364414

365415
return indices[pos].item()
366416

367-
def _init_step_index(self, timestep):
417+
def _init_step_index(self, timestep: Union[float, torch.FloatTensor]) -> None:
368418
if self.begin_index is None:
369419
if isinstance(timestep, torch.Tensor):
370420
timestep = timestep.to(self.timesteps.device)
@@ -405,7 +455,7 @@ def step(
405455
A random number generator.
406456
per_token_timesteps (`torch.Tensor`, *optional*):
407457
The timesteps for each token in the sample.
408-
return_dict (`bool`):
458+
return_dict (`bool`, defaults to `True`):
409459
Whether or not to return a
410460
[`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] or tuple.
411461
@@ -474,7 +524,7 @@ def step(
474524
return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample)
475525

476526
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
477-
def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
527+
def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
478528
"""
479529
Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative
480530
Models](https://huggingface.co/papers/2206.00364).
@@ -595,11 +645,11 @@ def _convert_to_beta(
595645
)
596646
return sigmas
597647

598-
def _time_shift_exponential(self, mu, sigma, t):
648+
def _time_shift_exponential(self, mu: float, sigma: float, t: torch.Tensor) -> torch.Tensor:
599649
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
600650

601-
def _time_shift_linear(self, mu, sigma, t):
651+
def _time_shift_linear(self, mu: float, sigma: float, t: torch.Tensor) -> torch.Tensor:
602652
return mu / (mu + (1 / t - 1) ** sigma)
603653

604-
def __len__(self):
654+
def __len__(self) -> int:
605655
return self.config.num_train_timesteps

src/diffusers/schedulers/scheduling_unipc_multistep.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -482,6 +482,21 @@ def set_timesteps(
482482

483483
# Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.time_shift
484484
def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
485+
"""
486+
Apply time shifting to the sigmas.
487+
488+
Args:
489+
mu (`float`):
490+
The mu parameter for the time shift.
491+
sigma (`float`):
492+
The sigma parameter for the time shift.
493+
t (`torch.Tensor`):
494+
The input timesteps.
495+
496+
Returns:
497+
`torch.Tensor`:
498+
The time-shifted timesteps.
499+
"""
485500
if self.config.time_shift_type == "exponential":
486501
return self._time_shift_exponential(mu, sigma, t)
487502
elif self.config.time_shift_type == "linear":

0 commit comments

Comments
 (0)