@@ -51,9 +51,6 @@ class FlowMatchHeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
5151 Args:
5252 num_train_timesteps (`int`, defaults to 1000):
5353 The number of diffusion steps to train the model.
54- timestep_spacing (`str`, defaults to `"linspace"`):
55- The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
56- Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
5754 shift (`float`, defaults to 1.0):
5855 The shift value for the timestep schedule.
5956 """
@@ -110,7 +107,7 @@ def set_begin_index(self, begin_index: int = 0):
110107 def scale_noise (
111108 self ,
112109 sample : torch .FloatTensor ,
113- timestep : torch .FloatTensor ,
110+ timestep : Union [ float , torch .FloatTensor ] ,
114111 noise : torch .FloatTensor ,
115112 ) -> torch .FloatTensor :
116113 """
@@ -119,7 +116,7 @@ def scale_noise(
119116 Args:
120117 sample (`torch.FloatTensor`):
121118 The input sample.
122- timestep (`torch.FloatTensor`):
119+ timestep (`float` or ` torch.FloatTensor`):
123120 The current timestep in the diffusion chain.
124121 noise (`torch.FloatTensor`):
125122 The noise tensor.
@@ -137,10 +134,14 @@ def scale_noise(
137134
138135 return sample
139136
140- def _sigma_to_t (self , sigma ) :
137+ def _sigma_to_t (self , sigma : float ) -> float :
141138 return sigma * self .config .num_train_timesteps
142139
143- def set_timesteps (self , num_inference_steps : int , device : Union [str , torch .device ] = None ):
140+ def set_timesteps (
141+ self ,
142+ num_inference_steps : int ,
143+ device : Union [str , torch .device ] = None ,
144+ ) -> None :
144145 """
145146 Sets the discrete timesteps used for the diffusion chain (to be run before inference).
146147
@@ -153,7 +154,9 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
153154 self .num_inference_steps = num_inference_steps
154155
155156 timesteps = np .linspace (
156- self ._sigma_to_t (self .sigma_max ), self ._sigma_to_t (self .sigma_min ), num_inference_steps
157+ self ._sigma_to_t (self .sigma_max ),
158+ self ._sigma_to_t (self .sigma_min ),
159+ num_inference_steps ,
157160 )
158161
159162 sigmas = timesteps / self .config .num_train_timesteps
@@ -174,7 +177,24 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
174177 self ._step_index = None
175178 self ._begin_index = None
176179
177- def index_for_timestep (self , timestep , schedule_timesteps = None ):
180+ def index_for_timestep (
181+ self ,
182+ timestep : Union [float , torch .FloatTensor ],
183+ schedule_timesteps : Optional [torch .FloatTensor ] = None ,
184+ ) -> int :
185+ """
186+ Find the index of a given timestep in the timestep schedule.
187+
188+ Args:
189+ timestep (`float` or `torch.FloatTensor`):
190+ The timestep value to find in the schedule.
191+ schedule_timesteps (`torch.FloatTensor`, *optional*):
192+ The timestep schedule to search in. If `None`, uses `self.timesteps`.
193+
194+ Returns:
195+ `int`:
196+ The index of the timestep in the schedule.
197+ """
178198 if schedule_timesteps is None :
179199 schedule_timesteps = self .timesteps
180200
@@ -188,7 +208,7 @@ def index_for_timestep(self, timestep, schedule_timesteps=None):
188208
189209 return indices [pos ].item ()
190210
191- def _init_step_index (self , timestep ) :
211+ def _init_step_index (self , timestep : Union [ float , torch . FloatTensor ]) -> None :
192212 if self .begin_index is None :
193213 if isinstance (timestep , torch .Tensor ):
194214 timestep = timestep .to (self .timesteps .device )
@@ -197,7 +217,10 @@ def _init_step_index(self, timestep):
197217 self ._step_index = self ._begin_index
198218
199219 @property
200- def state_in_first_order (self ):
220+ def state_in_first_order (self ) -> bool :
221+ """
222+ Returns whether the scheduler is in the first-order state.
223+ """
201224 return self .dt is None
202225
203226 def step (
@@ -219,13 +242,19 @@ def step(
219242 Args:
220243 model_output (`torch.FloatTensor`):
221244 The direct output from learned diffusion model.
222- timestep (`float`):
245+ timestep (`float` or `torch.FloatTensor` ):
223246 The current discrete timestep in the diffusion chain.
224247 sample (`torch.FloatTensor`):
225248 A current instance of a sample created by the diffusion process.
226249 s_churn (`float`):
227- s_tmin (`float`):
228- s_tmax (`float`):
250+ Stochasticity parameter that controls the amount of noise added during sampling. Higher values increase
251+ randomness.
252+ s_tmin (`float`):
253+ Minimum timestep threshold for applying stochasticity. Only timesteps above this value will have noise
254+ added.
255+ s_tmax (`float`):
256+ Maximum timestep threshold for applying stochasticity. Only timesteps below this value will have noise
257+ added.
229258 s_noise (`float`, defaults to 1.0):
230259 Scaling factor for noise added to the sample.
231260 generator (`torch.Generator`, *optional*):
@@ -274,7 +303,10 @@ def step(
274303
275304 if gamma > 0 :
276305 noise = randn_tensor (
277- model_output .shape , dtype = model_output .dtype , device = model_output .device , generator = generator
306+ model_output .shape ,
307+ dtype = model_output .dtype ,
308+ device = model_output .device ,
309+ generator = generator ,
278310 )
279311 eps = noise * s_noise
280312 sample = sample + eps * (sigma_hat ** 2 - sigma ** 2 ) ** 0.5
@@ -320,5 +352,5 @@ def step(
320352
321353 return FlowMatchHeunDiscreteSchedulerOutput (prev_sample = prev_sample )
322354
323- def __len__ (self ):
355+ def __len__ (self ) -> int :
324356 return self .config .num_train_timesteps
0 commit comments