Skip to content

Commit c2fdd2d

Browse files
authored
docs: improve docstring scheduling_ipndm.py (#13198)
Improve docstring scheduling ipndm
1 parent 84ff061 commit c2fdd2d

File tree

1 file changed

+43
-8
lines changed

1 file changed

+43
-8
lines changed

src/diffusers/schedulers/scheduling_ipndm.py

Lines changed: 43 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -31,14 +31,18 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
3131
Args:
3232
num_train_timesteps (`int`, defaults to 1000):
3333
The number of diffusion steps to train the model.
34-
trained_betas (`np.ndarray`, *optional*):
34+
trained_betas (`np.ndarray` or `List[float]`, *optional*):
3535
Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
3636
"""
3737

3838
order = 1
3939

4040
@register_to_config
41-
def __init__(self, num_train_timesteps: int = 1000, trained_betas: np.ndarray | list[float] | None = None):
41+
def __init__(
42+
self,
43+
num_train_timesteps: int = 1000,
44+
trained_betas: np.ndarray | list[float] | None = None,
45+
):
4246
# set `betas`, `alphas`, `timesteps`
4347
self.set_timesteps(num_train_timesteps)
4448

@@ -56,21 +60,29 @@ def __init__(self, num_train_timesteps: int = 1000, trained_betas: np.ndarray |
5660
self._begin_index = None
5761

5862
@property
59-
def step_index(self):
63+
def step_index(self) -> int | None:
6064
"""
6165
The index counter for current timestep. It will increase 1 after each scheduler step.
66+
67+
Returns:
68+
`int` or `None`:
69+
The index counter for current timestep.
6270
"""
6371
return self._step_index
6472

6573
@property
66-
def begin_index(self):
74+
def begin_index(self) -> int | None:
6775
"""
6876
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
77+
78+
Returns:
79+
`int` or `None`:
80+
The index for the first timestep.
6981
"""
7082
return self._begin_index
7183

7284
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
73-
def set_begin_index(self, begin_index: int = 0):
85+
def set_begin_index(self, begin_index: int = 0) -> None:
7486
"""
7587
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
7688
@@ -169,7 +181,7 @@ def step(
169181
Args:
170182
model_output (`torch.Tensor`):
171183
The direct output from learned diffusion model.
172-
timestep (`int`):
184+
timestep (`int` or `torch.Tensor`):
173185
The current discrete timestep in the diffusion chain.
174186
sample (`torch.Tensor`):
175187
A current instance of a sample created by the diffusion process.
@@ -228,7 +240,30 @@ def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tens
228240
"""
229241
return sample
230242

231-
def _get_prev_sample(self, sample, timestep_index, prev_timestep_index, ets):
243+
def _get_prev_sample(
244+
self,
245+
sample: torch.Tensor,
246+
timestep_index: int,
247+
prev_timestep_index: int,
248+
ets: torch.Tensor,
249+
) -> torch.Tensor:
250+
"""
251+
Predicts the previous sample based on the current sample, timestep indices, and running model outputs.
252+
253+
Args:
254+
sample (`torch.Tensor`):
255+
The current sample.
256+
timestep_index (`int`):
257+
Index of the current timestep in the schedule.
258+
prev_timestep_index (`int`):
259+
Index of the previous timestep in the schedule.
260+
ets (`torch.Tensor`):
261+
The running sequence of model outputs.
262+
263+
Returns:
264+
`torch.Tensor`:
265+
The predicted previous sample.
266+
"""
232267
alpha = self.alphas[timestep_index]
233268
sigma = self.betas[timestep_index]
234269

@@ -240,5 +275,5 @@ def _get_prev_sample(self, sample, timestep_index, prev_timestep_index, ets):
240275

241276
return prev_sample
242277

243-
def __len__(self):
278+
def __len__(self) -> int:
244279
return self.config.num_train_timesteps

0 commit comments

Comments
 (0)