Skip to content

Commit 1c8e435

Browse files
committed
Fix denoising_start/end with higher orders
1 parent 0fff459 commit 1c8e435

2 files changed

Lines changed: 7 additions & 12 deletions

File tree

src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1175,8 +1175,10 @@ def __call__(
11751175
- (self.denoising_end * self.scheduler.config.num_train_timesteps)
11761176
)
11771177
)
1178-
num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
1179-
timesteps = timesteps[:num_inference_steps]
1178+
num_inference_steps = (
1179+
(torch.as_tensor(timesteps)[:: self.scheduler.order] >= discrete_timestep_cutoff).sum().item()
1180+
)
1181+
timesteps = timesteps[: num_inference_steps * self.scheduler.order]
11801182

11811183
# 9. Optionally get Guidance Scale Embedding
11821184
timestep_cond = None

src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -666,18 +666,11 @@ def get_timesteps(self, num_inference_steps, strength, device, denoising_start=N
666666
)
667667
)
668668

669-
num_inference_steps = (self.scheduler.timesteps < discrete_timestep_cutoff).sum().item()
670-
if self.scheduler.order == 2 and num_inference_steps % 2 == 0:
671-
# if the scheduler is a 2nd order scheduler we might have to do +1
672-
# because `num_inference_steps` might be even given that every timestep
673-
# (except the highest one) is duplicated. If `num_inference_steps` is even it would
674-
# mean that we cut the timesteps in the middle of the denoising step
675-
# (between 1st and 2nd derivative) which leads to incorrect results. By adding 1
676-
# we ensure that the denoising process always ends after the 2nd derivative step of the scheduler
677-
num_inference_steps = num_inference_steps + 1
669+
real_timesteps = self.scheduler.timesteps[:: self.scheduler.order]
670+
num_inference_steps = (real_timesteps < discrete_timestep_cutoff).sum().item()
678671

679672
# because t_n+1 >= t_n, we slice the timesteps starting from the end
680-
t_start = len(self.scheduler.timesteps) - num_inference_steps
673+
t_start = (len(real_timesteps) - num_inference_steps) * self.scheduler.order
681674
timesteps = self.scheduler.timesteps[t_start:]
682675
if hasattr(self.scheduler, "set_begin_index"):
683676
self.scheduler.set_begin_index(t_start)

0 commit comments

Comments
 (0)