Skip to content

Commit 51906b8

Browse files
committed
[schedulers] fix RecursionError in CosineDPMSolverMultistepScheduler
`CosineDPMSolverMultistepScheduler.step` initialised `BrownianTreeNoiseSampler` with `sigma_min`/`sigma_max` from the config, but the sampler is queried with `self.sigmas[step_index]` values that drift outside those bounds: the Karras/exponential reconstruction of the endpoints in fp32 lands a few ULPs off, and `final_sigmas_type="zero"` makes the last `sigmas` entry strictly below `config.sigma_min`. Out-of-range queries push torchsde into unbounded recursive interval splitting and trip Python's recursion limit (#13274). Initialise the sampler with the actual `self.sigmas` extrema instead, matching the pattern in `scheduling_dpmsolver_sde.py`. Adds a regression test covering both Karras and exponential schedules with `final_sigmas_type="zero"`.
1 parent 8f14cde commit 51906b8

2 files changed

Lines changed: 48 additions & 2 deletions

File tree

src/diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -653,10 +653,16 @@ def step(
653653
seed = (
654654
[g.initial_seed() for g in generator] if isinstance(generator, list) else generator.initial_seed()
655655
)
656+
# Use the actual extrema of `self.sigmas` rather than the config bounds:
657+
# the Karras/exponential reconstruction of the endpoints in fp32 can drift
658+
# by a few ULPs, and `final_sigmas_type="zero"` makes `sigmas[-1] == 0`,
659+
# both of which fall outside `[config.sigma_min, config.sigma_max]`. An
660+
# out-of-range query drives `torchsde` into unbounded recursive splitting
661+
# of its Brownian interval and eventually raises `RecursionError` (#13274).
656662
self.noise_sampler = BrownianTreeNoiseSampler(
657663
model_output,
658-
sigma_min=self.config.sigma_min,
659-
sigma_max=self.config.sigma_max,
664+
sigma_min=self.sigmas.min().item(),
665+
sigma_max=self.sigmas.max().item(),
660666
seed=seed,
661667
)
662668
noise = self.noise_sampler(self.sigmas[self.step_index], self.sigmas[self.step_index + 1]).to(
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
import unittest
2+
import warnings
3+
4+
import torch
5+
6+
from diffusers import CosineDPMSolverMultistepScheduler
7+
8+
from ..testing_utils import require_torchsde
9+
10+
11+
@require_torchsde
12+
class CosineDPMSolverMultistepSchedulerTest(unittest.TestCase):
13+
"""Regression tests for `CosineDPMSolverMultistepScheduler` (used by Stable Audio Open)."""
14+
15+
def _run_loop(self, **scheduler_kwargs):
16+
scheduler = CosineDPMSolverMultistepScheduler(**scheduler_kwargs)
17+
scheduler.set_timesteps(num_inference_steps=10, device="cpu")
18+
sample = torch.randn(1, 4, 8)
19+
generator = torch.Generator().manual_seed(0)
20+
for t in scheduler.timesteps:
21+
model_output = torch.randn_like(sample)
22+
sample = scheduler.step(model_output, t, sample, generator=generator).prev_sample
23+
return sample
24+
25+
def test_step_does_not_recurse_with_zero_final_sigma(self):
26+
# See https://github.com/huggingface/diffusers/issues/13274. With the defaults
27+
# used by Stable Audio Open (sigma_min=0.3, sigma_max=500, final_sigmas_type="zero")
28+
# querying the Brownian sampler at sigma_next=0 used to fall below the configured
29+
# `sigma_min` interval and recurse until Python's recursion limit was hit.
30+
for sigma_schedule in ("exponential", "karras"):
31+
with self.subTest(sigma_schedule=sigma_schedule):
32+
with warnings.catch_warnings():
33+
warnings.simplefilter("ignore")
34+
sample = self._run_loop(
35+
sigma_schedule=sigma_schedule,
36+
final_sigmas_type="zero",
37+
sigma_min=0.3,
38+
sigma_max=500.0,
39+
)
40+
self.assertFalse(torch.isnan(sample).any().item())

0 commit comments

Comments
 (0)