Skip to content

Commit 5f3ea22

Browse files
authored
docs: improve docstrings in scheduling_flow_match_heun_discrete.py (#13130)
Improve the docstrings in scheduling_flow_match_heun_discrete.py
1 parent 427472e commit 5f3ea22

File tree

1 file changed

+48
-16
lines changed

1 file changed

+48
-16
lines changed

src/diffusers/schedulers/scheduling_flow_match_heun_discrete.py

Lines changed: 48 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,6 @@ class FlowMatchHeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
5151
Args:
5252
num_train_timesteps (`int`, defaults to 1000):
5353
The number of diffusion steps to train the model.
54-
timestep_spacing (`str`, defaults to `"linspace"`):
55-
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
56-
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
5754
shift (`float`, defaults to 1.0):
5855
The shift value for the timestep schedule.
5956
"""
@@ -110,7 +107,7 @@ def set_begin_index(self, begin_index: int = 0):
110107
def scale_noise(
111108
self,
112109
sample: torch.FloatTensor,
113-
timestep: torch.FloatTensor,
110+
timestep: Union[float, torch.FloatTensor],
114111
noise: torch.FloatTensor,
115112
) -> torch.FloatTensor:
116113
"""
@@ -119,7 +116,7 @@ def scale_noise(
119116
Args:
120117
sample (`torch.FloatTensor`):
121118
The input sample.
122-
timestep (`torch.FloatTensor`):
119+
timestep (`float` or `torch.FloatTensor`):
123120
The current timestep in the diffusion chain.
124121
noise (`torch.FloatTensor`):
125122
The noise tensor.
@@ -137,10 +134,14 @@ def scale_noise(
137134

138135
return sample
139136

140-
def _sigma_to_t(self, sigma):
137+
def _sigma_to_t(self, sigma: float) -> float:
141138
return sigma * self.config.num_train_timesteps
142139

143-
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
140+
def set_timesteps(
141+
self,
142+
num_inference_steps: int,
143+
device: Union[str, torch.device] = None,
144+
) -> None:
144145
"""
145146
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
146147
@@ -153,7 +154,9 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
153154
self.num_inference_steps = num_inference_steps
154155

155156
timesteps = np.linspace(
156-
self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps
157+
self._sigma_to_t(self.sigma_max),
158+
self._sigma_to_t(self.sigma_min),
159+
num_inference_steps,
157160
)
158161

159162
sigmas = timesteps / self.config.num_train_timesteps
@@ -174,7 +177,24 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
174177
self._step_index = None
175178
self._begin_index = None
176179

177-
def index_for_timestep(self, timestep, schedule_timesteps=None):
180+
def index_for_timestep(
181+
self,
182+
timestep: Union[float, torch.FloatTensor],
183+
schedule_timesteps: Optional[torch.FloatTensor] = None,
184+
) -> int:
185+
"""
186+
Find the index of a given timestep in the timestep schedule.
187+
188+
Args:
189+
timestep (`float` or `torch.FloatTensor`):
190+
The timestep value to find in the schedule.
191+
schedule_timesteps (`torch.FloatTensor`, *optional*):
192+
The timestep schedule to search in. If `None`, uses `self.timesteps`.
193+
194+
Returns:
195+
`int`:
196+
The index of the timestep in the schedule.
197+
"""
178198
if schedule_timesteps is None:
179199
schedule_timesteps = self.timesteps
180200

@@ -188,7 +208,7 @@ def index_for_timestep(self, timestep, schedule_timesteps=None):
188208

189209
return indices[pos].item()
190210

191-
def _init_step_index(self, timestep):
211+
def _init_step_index(self, timestep: Union[float, torch.FloatTensor]) -> None:
192212
if self.begin_index is None:
193213
if isinstance(timestep, torch.Tensor):
194214
timestep = timestep.to(self.timesteps.device)
@@ -197,7 +217,10 @@ def _init_step_index(self, timestep):
197217
self._step_index = self._begin_index
198218

199219
@property
200-
def state_in_first_order(self):
220+
def state_in_first_order(self) -> bool:
221+
"""
222+
Returns whether the scheduler is in the first-order state.
223+
"""
201224
return self.dt is None
202225

203226
def step(
@@ -219,13 +242,19 @@ def step(
219242
Args:
220243
model_output (`torch.FloatTensor`):
221244
The direct output from learned diffusion model.
222-
timestep (`float`):
245+
timestep (`float` or `torch.FloatTensor`):
223246
The current discrete timestep in the diffusion chain.
224247
sample (`torch.FloatTensor`):
225248
A current instance of a sample created by the diffusion process.
226249
s_churn (`float`):
227-
s_tmin (`float`):
228-
s_tmax (`float`):
250+
Stochasticity parameter that controls the amount of noise added during sampling. Higher values increase
251+
randomness.
252+
s_tmin (`float`):
253+
Minimum timestep threshold for applying stochasticity. Only timesteps above this value will have noise
254+
added.
255+
s_tmax (`float`):
256+
Maximum timestep threshold for applying stochasticity. Only timesteps below this value will have noise
257+
added.
229258
s_noise (`float`, defaults to 1.0):
230259
Scaling factor for noise added to the sample.
231260
generator (`torch.Generator`, *optional*):
@@ -274,7 +303,10 @@ def step(
274303

275304
if gamma > 0:
276305
noise = randn_tensor(
277-
model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
306+
model_output.shape,
307+
dtype=model_output.dtype,
308+
device=model_output.device,
309+
generator=generator,
278310
)
279311
eps = noise * s_noise
280312
sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5
@@ -320,5 +352,5 @@ def step(
320352

321353
return FlowMatchHeunDiscreteSchedulerOutput(prev_sample=prev_sample)
322354

323-
def __len__(self):
355+
def __len__(self) -> int:
324356
return self.config.num_train_timesteps

0 commit comments

Comments
 (0)