@@ -31,14 +31,18 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
3131 Args:
3232 num_train_timesteps (`int`, defaults to 1000):
3333 The number of diffusion steps to train the model.
34- trained_betas (`np.ndarray`, *optional*):
34+ trained_betas (`np.ndarray` or `List[float]` , *optional*):
3535 Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
3636 """
3737
3838 order = 1
3939
4040 @register_to_config
41- def __init__ (self , num_train_timesteps : int = 1000 , trained_betas : np .ndarray | list [float ] | None = None ):
41+ def __init__ (
42+ self ,
43+ num_train_timesteps : int = 1000 ,
44+ trained_betas : np .ndarray | list [float ] | None = None ,
45+ ):
4246 # set `betas`, `alphas`, `timesteps`
4347 self .set_timesteps (num_train_timesteps )
4448
@@ -56,21 +60,29 @@ def __init__(self, num_train_timesteps: int = 1000, trained_betas: np.ndarray |
5660 self ._begin_index = None
5761
5862 @property
59- def step_index (self ):
63+ def step_index (self ) -> int | None :
6064 """
6165 The index counter for current timestep. It will increase 1 after each scheduler step.
66+
67+ Returns:
68+ `int` or `None`:
69+ The index counter for current timestep.
6270 """
6371 return self ._step_index
6472
6573 @property
66- def begin_index (self ):
74+ def begin_index (self ) -> int | None :
6775 """
6876 The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
77+
78+ Returns:
79+ `int` or `None`:
80+ The index for the first timestep.
6981 """
7082 return self ._begin_index
7183
7284 # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
73- def set_begin_index (self , begin_index : int = 0 ):
85+ def set_begin_index (self , begin_index : int = 0 ) -> None :
7486 """
7587 Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
7688
@@ -169,7 +181,7 @@ def step(
169181 Args:
170182 model_output (`torch.Tensor`):
171183 The direct output from learned diffusion model.
172- timestep (`int`):
184+ timestep (`int` or `torch.Tensor` ):
173185 The current discrete timestep in the diffusion chain.
174186 sample (`torch.Tensor`):
175187 A current instance of a sample created by the diffusion process.
@@ -228,7 +240,30 @@ def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tens
228240 """
229241 return sample
230242
231- def _get_prev_sample (self , sample , timestep_index , prev_timestep_index , ets ):
243+ def _get_prev_sample (
244+ self ,
245+ sample : torch .Tensor ,
246+ timestep_index : int ,
247+ prev_timestep_index : int ,
248+ ets : torch .Tensor ,
249+ ) -> torch .Tensor :
250+ """
251+ Predicts the previous sample based on the current sample, timestep indices, and running model outputs.
252+
253+ Args:
254+ sample (`torch.Tensor`):
255+ The current sample.
256+ timestep_index (`int`):
257+ Index of the current timestep in the schedule.
258+ prev_timestep_index (`int`):
259+ Index of the previous timestep in the schedule.
260+ ets (`torch.Tensor`):
261+ The running sequence of model outputs.
262+
263+ Returns:
264+ `torch.Tensor`:
265+ The predicted previous sample.
266+ """
232267 alpha = self .alphas [timestep_index ]
233268 sigma = self .betas [timestep_index ]
234269
@@ -240,5 +275,5 @@ def _get_prev_sample(self, sample, timestep_index, prev_timestep_index, ets):
240275
241276 return prev_sample
242277
243- def __len__ (self ):
278+ def __len__ (self ) -> int :
244279 return self .config .num_train_timesteps
0 commit comments