1414
1515import math
1616from dataclasses import dataclass
17- from typing import List , Optional , Tuple , Union
17+ from typing import List , Literal , Optional , Tuple , Union
1818
1919import numpy as np
2020import torch
@@ -102,12 +102,21 @@ def __init__(
102102 use_karras_sigmas : Optional [bool ] = False ,
103103 use_exponential_sigmas : Optional [bool ] = False ,
104104 use_beta_sigmas : Optional [bool ] = False ,
105- time_shift_type : str = "exponential" ,
105+ time_shift_type : Literal [ "exponential" , "linear" ] = "exponential" ,
106106 stochastic_sampling : bool = False ,
107107 ):
108108 if self .config .use_beta_sigmas and not is_scipy_available ():
109109 raise ImportError ("Make sure to install scipy if you want to use beta sigmas." )
110- if sum ([self .config .use_beta_sigmas , self .config .use_exponential_sigmas , self .config .use_karras_sigmas ]) > 1 :
110+ if (
111+ sum (
112+ [
113+ self .config .use_beta_sigmas ,
114+ self .config .use_exponential_sigmas ,
115+ self .config .use_karras_sigmas ,
116+ ]
117+ )
118+ > 1
119+ ):
111120 raise ValueError (
112121 "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
113122 )
@@ -166,6 +175,13 @@ def set_begin_index(self, begin_index: int = 0):
166175 self ._begin_index = begin_index
167176
def set_shift(self, shift: float):
    """
    Update the scheduler's active shift value.

    Args:
        shift (`float`):
            The new shift value to store on the scheduler.
    """
    self._shift = shift
170186
171187 def scale_noise (
@@ -218,10 +234,25 @@ def scale_noise(
218234
219235 return sample
220236
221- def _sigma_to_t (self , sigma ):
237+ def _sigma_to_t (self , sigma ) -> float :
222238 return sigma * self .config .num_train_timesteps
223239
224- def time_shift (self , mu : float , sigma : float , t : torch .Tensor ):
240+ def time_shift (self , mu : float , sigma : float , t : torch .Tensor ) -> torch .Tensor :
241+ """
242+ Apply time shifting to the sigmas.
243+
244+ Args:
245+ mu (`float`):
246+ The mu parameter for the time shift.
247+ sigma (`float`):
248+ The sigma parameter for the time shift.
249+ t (`torch.Tensor`):
250+ The input timesteps.
251+
252+ Returns:
253+ `torch.Tensor`:
254+ The time-shifted timesteps.
255+ """
225256 if self .config .time_shift_type == "exponential" :
226257 return self ._time_shift_exponential (mu , sigma , t )
227258 elif self .config .time_shift_type == "linear" :
@@ -302,7 +333,9 @@ def set_timesteps(
302333 if sigmas is None :
303334 if timesteps is None :
304335 timesteps = np .linspace (
305- self ._sigma_to_t (self .sigma_max ), self ._sigma_to_t (self .sigma_min ), num_inference_steps
336+ self ._sigma_to_t (self .sigma_max ),
337+ self ._sigma_to_t (self .sigma_min ),
338+ num_inference_steps ,
306339 )
307340 sigmas = timesteps / self .config .num_train_timesteps
308341 else :
@@ -350,7 +383,24 @@ def set_timesteps(
350383 self ._step_index = None
351384 self ._begin_index = None
352385
353- def index_for_timestep (self , timestep , schedule_timesteps = None ):
386+ def index_for_timestep (
387+ self ,
388+ timestep : Union [float , torch .FloatTensor ],
389+ schedule_timesteps : Optional [torch .FloatTensor ] = None ,
390+ ) -> int :
391+ """
392+ Get the index for the given timestep.
393+
394+ Args:
395+ timestep (`float` or `torch.FloatTensor`):
396+ The timestep to find the index for.
397+ schedule_timesteps (`torch.FloatTensor`, *optional*):
398+ The schedule timesteps to validate against. If `None`, the scheduler's timesteps are used.
399+
400+ Returns:
401+ `int`:
402+ The index of the timestep.
403+ """
354404 if schedule_timesteps is None :
355405 schedule_timesteps = self .timesteps
356406
@@ -364,7 +414,7 @@ def index_for_timestep(self, timestep, schedule_timesteps=None):
364414
365415 return indices [pos ].item ()
366416
367- def _init_step_index (self , timestep ) :
417+ def _init_step_index (self , timestep : Union [ float , torch . FloatTensor ]) -> None :
368418 if self .begin_index is None :
369419 if isinstance (timestep , torch .Tensor ):
370420 timestep = timestep .to (self .timesteps .device )
@@ -405,7 +455,7 @@ def step(
405455 A random number generator.
406456 per_token_timesteps (`torch.Tensor`, *optional*):
407457 The timesteps for each token in the sample.
408- return_dict (`bool`):
458+ return_dict (`bool`, defaults to `True` ):
409459 Whether or not to return a
410460 [`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] or tuple.
411461
@@ -474,7 +524,7 @@ def step(
474524 return FlowMatchEulerDiscreteSchedulerOutput (prev_sample = prev_sample )
475525
476526 # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
477- def _convert_to_karras (self , in_sigmas : torch .Tensor , num_inference_steps ) -> torch .Tensor :
527+ def _convert_to_karras (self , in_sigmas : torch .Tensor , num_inference_steps : int ) -> torch .Tensor :
478528 """
479529 Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative
480530 Models](https://huggingface.co/papers/2206.00364).
@@ -595,11 +645,11 @@ def _convert_to_beta(
595645 )
596646 return sigmas
597647
598- def _time_shift_exponential (self , mu , sigma , t ) :
648+ def _time_shift_exponential (self , mu : float , sigma : float , t : torch . Tensor ) -> torch . Tensor :
599649 return math .exp (mu ) / (math .exp (mu ) + (1 / t - 1 ) ** sigma )
600650
601- def _time_shift_linear (self , mu , sigma , t ) :
651+ def _time_shift_linear (self , mu : float , sigma : float , t : torch . Tensor ) -> torch . Tensor :
602652 return mu / (mu + (1 / t - 1 ) ** sigma )
603653
604- def __len__ (self ):
654+ def __len__ (self ) -> int :
605655 return self .config .num_train_timesteps
0 commit comments