@@ -50,6 +50,29 @@ def retrieve_timesteps(
5050 sigmas : list [float ] | None = None ,
5151 ** kwargs ,
5252):
53+ r"""
54+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
55+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
56+
57+ Args:
58+ scheduler (`SchedulerMixin`):
59+ The scheduler to get timesteps from.
60+ num_inference_steps (`int`):
61+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
62+ must be `None`.
63+ device (`str` or `torch.device`, *optional*):
64+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
65+ timesteps (`list[int]`, *optional*):
66+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
67+ `num_inference_steps` and `sigmas` must be `None`.
68+ sigmas (`list[float]`, *optional*):
69+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
70+ `num_inference_steps` and `timesteps` must be `None`.
71+
72+ Returns:
73+ `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
74+ second element is the number of inference steps.
75+ """
5376 if timesteps is not None and sigmas is not None :
5477 raise ValueError ("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values" )
5578 if timesteps is not None :
0 commit comments