Skip to content

Commit 2510dba

Browse files
committed
fix ci failures
1 parent bbc919e commit 2510dba

2 files changed

Lines changed: 24 additions & 3 deletions

File tree

src/diffusers/modular_pipelines/stable_diffusion_3/before_denoise.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,29 @@ def retrieve_timesteps(
5050
sigmas: list[float] | None = None,
5151
**kwargs,
5252
):
53+
r"""
54+
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
55+
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
56+
57+
Args:
58+
scheduler (`SchedulerMixin`):
59+
The scheduler to get timesteps from.
60+
num_inference_steps (`int`):
61+
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
62+
must be `None`.
63+
device (`str` or `torch.device`, *optional*):
64+
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
65+
timesteps (`list[int]`, *optional*):
66+
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
67+
`num_inference_steps` and `sigmas` must be `None`.
68+
sigmas (`list[float]`, *optional*):
69+
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
70+
`num_inference_steps` and `timesteps` must be `None`.
71+
72+
Returns:
73+
`tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
74+
second element is the number of inference steps.
75+
"""
5376
if timesteps is not None and sigmas is not None:
5477
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
5578
if timesteps is not None:

src/diffusers/modular_pipelines/stable_diffusion_3/encoders.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,7 @@
3535

3636
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
3737
def retrieve_latents(
38-
encoder_output: torch.Tensor,
39-
generator: torch.Generator | None = None,
40-
sample_mode: str = "sample",
38+
encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample"
4139
):
4240
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
4341
return encoder_output.latent_dist.sample(generator)

0 commit comments

Comments
 (0)