Skip to content

Commit 0ab2124

Browse files
authored
docs: improve docstring scheduling_dpm_cogvideox.py (#13044)
1 parent 74a0f0b commit 0ab2124

File tree

1 file changed

+89
-14
lines changed

1 file changed

+89
-14
lines changed

src/diffusers/schedulers/scheduling_dpm_cogvideox.py

Lines changed: 89 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -105,7 +105,6 @@ def rescale_zero_terminal_snr(alphas_cumprod):
105105
"""
106106
Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1)
107107
108-
109108
Args:
110109
betas (`torch.Tensor`):
111110
the betas that the scheduler is being initialized with.
@@ -175,11 +174,14 @@ class CogVideoXDPMScheduler(SchedulerMixin, ConfigMixin):
175174
The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
176175
timestep_spacing (`str`, defaults to `"leading"`):
177176
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
178-
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
177+
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. Choose from
178+
`leading`, `linspace` or `trailing`.
179179
rescale_betas_zero_snr (`bool`, defaults to `False`):
180180
Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
181181
dark samples instead of limiting it to samples with medium brightness. Loosely related to
182182
[`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
183+
snr_shift_scale (`float`, defaults to 3.0):
184+
Shift scale for SNR.
183185
"""
184186

185187
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
@@ -191,15 +193,15 @@ def __init__(
191193
num_train_timesteps: int = 1000,
192194
beta_start: float = 0.00085,
193195
beta_end: float = 0.0120,
194-
beta_schedule: str = "scaled_linear",
196+
beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "scaled_linear",
195197
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
196198
clip_sample: bool = True,
197199
set_alpha_to_one: bool = True,
198200
steps_offset: int = 0,
199-
prediction_type: str = "epsilon",
201+
prediction_type: Literal["epsilon", "sample", "v_prediction"] = "epsilon",
200202
clip_sample_range: float = 1.0,
201203
sample_max_value: float = 1.0,
202-
timestep_spacing: str = "leading",
204+
timestep_spacing: Literal["leading", "linspace", "trailing"] = "leading",
203205
rescale_betas_zero_snr: bool = False,
204206
snr_shift_scale: float = 3.0,
205207
):
@@ -209,7 +211,15 @@ def __init__(
209211
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
210212
elif beta_schedule == "scaled_linear":
211213
# this schedule is very specific to the latent diffusion model.
212-
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float64) ** 2
214+
self.betas = (
215+
torch.linspace(
216+
beta_start**0.5,
217+
beta_end**0.5,
218+
num_train_timesteps,
219+
dtype=torch.float64,
220+
)
221+
** 2
222+
)
213223
elif beta_schedule == "squaredcos_cap_v2":
214224
# Glide cosine schedule
215225
self.betas = betas_for_alpha_bar(num_train_timesteps)
@@ -266,13 +276,20 @@ def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None
266276
"""
267277
return sample
268278

269-
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
279+
def set_timesteps(
280+
self,
281+
num_inference_steps: int,
282+
device: Optional[Union[str, torch.device]] = None,
283+
):
270284
"""
271285
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
272286
273287
Args:
274288
num_inference_steps (`int`):
275289
The number of diffusion steps used when generating samples with a pre-trained model.
290+
device (`str` or `torch.device`, *optional*):
291+
The device to which the timesteps should be moved to. If `None` (the default), the timesteps are not
292+
moved.
276293
"""
277294

278295
if num_inference_steps > self.config.num_train_timesteps:
@@ -311,7 +328,27 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
311328

312329
self.timesteps = torch.from_numpy(timesteps).to(device)
313330

314-
def get_variables(self, alpha_prod_t, alpha_prod_t_prev, alpha_prod_t_back=None):
331+
def get_variables(
332+
self,
333+
alpha_prod_t: torch.Tensor,
334+
alpha_prod_t_prev: torch.Tensor,
335+
alpha_prod_t_back: Optional[torch.Tensor] = None,
336+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor]:
337+
"""
338+
Compute the variables used for DPM-Solver++ (2M) referencing the original implementation.
339+
340+
Args:
341+
alpha_prod_t (`torch.Tensor`):
342+
The cumulative product of alphas at the current timestep.
343+
alpha_prod_t_prev (`torch.Tensor`):
344+
The cumulative product of alphas at the previous timestep.
345+
alpha_prod_t_back (`torch.Tensor`, *optional*):
346+
The cumulative product of alphas at the timestep before the previous timestep.
347+
348+
Returns:
349+
`tuple`:
350+
A tuple containing the variables `h`, `r`, `lamb`, `lamb_next`.
351+
"""
315352
lamb = ((alpha_prod_t / (1 - alpha_prod_t)) ** 0.5).log()
316353
lamb_next = ((alpha_prod_t_prev / (1 - alpha_prod_t_prev)) ** 0.5).log()
317354
h = lamb_next - lamb
@@ -324,7 +361,36 @@ def get_variables(self, alpha_prod_t, alpha_prod_t_prev, alpha_prod_t_back=None)
324361
else:
325362
return h, None, lamb, lamb_next
326363

327-
def get_mult(self, h, r, alpha_prod_t, alpha_prod_t_prev, alpha_prod_t_back):
364+
def get_mult(
365+
self,
366+
h: torch.Tensor,
367+
r: Optional[torch.Tensor],
368+
alpha_prod_t: torch.Tensor,
369+
alpha_prod_t_prev: torch.Tensor,
370+
alpha_prod_t_back: Optional[torch.Tensor] = None,
371+
) -> Union[
372+
Tuple[torch.Tensor, torch.Tensor],
373+
Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
374+
]:
375+
"""
376+
Compute the multipliers for the previous sample and the predicted original sample.
377+
378+
Args:
379+
h (`torch.Tensor`):
380+
The log-SNR difference.
381+
r (`torch.Tensor`):
382+
The ratio of log-SNR differences.
383+
alpha_prod_t (`torch.Tensor`):
384+
The cumulative product of alphas at the current timestep.
385+
alpha_prod_t_prev (`torch.Tensor`):
386+
The cumulative product of alphas at the previous timestep.
387+
alpha_prod_t_back (`torch.Tensor`, *optional*):
388+
The cumulative product of alphas at the timestep before the previous timestep.
389+
390+
Returns:
391+
`tuple`:
392+
A tuple containing the multipliers.
393+
"""
328394
mult1 = ((1 - alpha_prod_t_prev) / (1 - alpha_prod_t)) ** 0.5 * (-h).exp()
329395
mult2 = (-2 * h).expm1() * alpha_prod_t_prev**0.5
330396

@@ -338,13 +404,13 @@ def get_mult(self, h, r, alpha_prod_t, alpha_prod_t_prev, alpha_prod_t_back):
338404
def step(
339405
self,
340406
model_output: torch.Tensor,
341-
old_pred_original_sample: torch.Tensor,
407+
old_pred_original_sample: Optional[torch.Tensor],
342408
timestep: int,
343409
timestep_back: int,
344410
sample: torch.Tensor,
345411
eta: float = 0.0,
346412
use_clipped_model_output: bool = False,
347-
generator=None,
413+
generator: Optional[torch.Generator] = None,
348414
variance_noise: Optional[torch.Tensor] = None,
349415
return_dict: bool = False,
350416
) -> Union[DDIMSchedulerOutput, Tuple]:
@@ -355,8 +421,12 @@ def step(
355421
Args:
356422
model_output (`torch.Tensor`):
357423
The direct output from learned diffusion model.
358-
timestep (`float`):
424+
old_pred_original_sample (`torch.Tensor`):
425+
The predicted original sample from the previous timestep.
426+
timestep (`int`):
359427
The current discrete timestep in the diffusion chain.
428+
timestep_back (`int`):
429+
The timestep to look back to.
360430
sample (`torch.Tensor`):
361431
A current instance of a sample created by the diffusion process.
362432
eta (`float`):
@@ -436,7 +506,12 @@ def step(
436506
return prev_sample, pred_original_sample
437507
else:
438508
denoised_d = mult[2] * pred_original_sample - mult[3] * old_pred_original_sample
439-
noise = randn_tensor(sample.shape, generator=generator, device=sample.device, dtype=sample.dtype)
509+
noise = randn_tensor(
510+
sample.shape,
511+
generator=generator,
512+
device=sample.device,
513+
dtype=sample.dtype,
514+
)
440515
x_advanced = mult[0] * sample - mult[1] * denoised_d + mult_noise * noise
441516

442517
prev_sample = x_advanced
@@ -524,5 +599,5 @@ def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: tor
524599
velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
525600
return velocity
526601

527-
def __len__(self):
602+
def __len__(self) -> int:
528603
return self.config.num_train_timesteps

0 commit comments

Comments (0)