Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/source/en/_toctree.yml
Original file line number Diff line number Diff line change
Expand Up @@ -737,6 +737,8 @@
title: LCMScheduler
- local: api/schedulers/lms_discrete
title: LMSDiscreteScheduler
- local: api/schedulers/ltx_euler_ancestral_rf
title: LTXEulerAncestralRFScheduler
- local: api/schedulers/pndm
title: PNDMScheduler
- local: api/schedulers/repaint
Expand Down
45 changes: 45 additions & 0 deletions docs/source/en/api/schedulers/ltx_euler_ancestral_rf.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->

# LTXEulerAncestralRFScheduler

The `LTXEulerAncestralRFScheduler` implements a K-diffusion-style Euler-Ancestral sampler
for flow / CONST parameterization, closely mirroring ComfyUI's `sample_euler_ancestral_RF`
implementation used for [LTX-Video](https://huggingface.co/docs/diffusers/api/pipelines/ltx_video).

The scheduler operates on a normalized sigma schedule σ ∈ [0, 1] and reconstructs the clean
estimate as `x0 = x_t − σ_t · v_t` (CONST parameterization). Stochastic noise reinjection is
controlled by `eta` (`eta=0` gives a deterministic Euler step; `eta=1` matches ComfyUI's
default RF behavior).

This scheduler is used by [`LTXPipeline`], [`LTXImageToVideoPipeline`], and
[`LTXConditionPipeline`].

The `eta` parameter must be >= 0. `eta=0` gives a deterministic (DDIM-like) Euler step;
`eta=1` matches ComfyUI's default RF behavior. Values above 1 are accepted but trigger a
one-time warning when the schedule step is too coarse to keep `sigma_down` non-negative.

<Tip>

See also [`FlowMatchEulerDiscreteScheduler`], which this scheduler delegates to for
auto-generated sigma schedules and with which it shares config compatibility via `_compatibles`.

</Tip>

## LTXEulerAncestralRFScheduler
[[autodoc]] LTXEulerAncestralRFScheduler

## LTXEulerAncestralRFSchedulerOutput
[[autodoc]] schedulers.scheduling_ltx_euler_ancestral_rf.LTXEulerAncestralRFSchedulerOutput
49 changes: 38 additions & 11 deletions src/diffusers/schedulers/scheduling_ltx_euler_ancestral_rf.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2025 Lightricks and The HuggingFace Team. All rights reserved.
# Copyright 2025 Lightricks, Vittoria Lanzo and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -65,8 +65,9 @@ class LTXEulerAncestralRFScheduler(SchedulerMixin, ConfigMixin):
num_train_timesteps (`int`, defaults to 1000):
Included for config compatibility; not used to build the schedule.
eta (`float`, defaults to 1.0):
Stochasticity parameter. `eta=0.0` yields deterministic DDIM-like sampling; `eta=1.0` matches ComfyUI's
default RF behavior.
Stochasticity parameter. Must be >= 0. `eta=0.0` yields deterministic DDIM-like sampling; `eta=1.0`
matches ComfyUI's default RF behavior. Values above 1.0 are accepted but will trigger clamping of
`sigma_down` to a minimum of 0 with a one-time warning when the schedule step is too coarse.
s_noise (`float`, defaults to 1.0):
Global scaling factor for the stochastic noise term.
"""
Expand All @@ -82,12 +83,15 @@ def __init__(
eta: float = 1.0,
s_noise: float = 1.0,
):
if eta < 0:
raise ValueError(f"`eta` must be >= 0, got {eta}.")
# Note: num_train_timesteps is kept only for config compatibility.
self.num_inference_steps: int = None
self.sigmas: torch.Tensor | None = None
self.timesteps: torch.Tensor | None = None
self._step_index: int = None
self._begin_index: int = None
self._sigma_down_warned: bool = False # deduplication flag for sigma_down clamp warning

@property
def step_index(self) -> int:
Expand Down Expand Up @@ -233,12 +237,23 @@ def set_timesteps(
if sigmas_tensor.ndim != 1:
raise ValueError(f"`sigmas` must be a 1D tensor, got shape {tuple(sigmas_tensor.shape)}.")

if sigmas_tensor[0].item() > 1.0 + 1e-6:
raise ValueError(
f"`sigmas` values must be in [0, 1] for RF/CONST parameterization, "
f"got max={sigmas_tensor[0].item():.6f}."
)

if len(sigmas_tensor) > 1 and not (sigmas_tensor[:-1] >= sigmas_tensor[1:]).all():
sig_list = sigmas_tensor.tolist()
sig_repr = str(sig_list) if len(sig_list) <= 8 else f"{sig_list[:4]} ... {sig_list[-4:]} (len={len(sig_list)})"
raise ValueError(
f"`sigmas` must be monotonically non-increasing (each entry >= the next), got {sig_repr}"
)

if sigmas_tensor[-1].abs().item() > 1e-6:
logger.warning(
"The last sigma in the schedule is not zero (%.6f). "
"For best compatibility with ComfyUI's RF sampler, the terminal sigma "
"should be 0.0.",
sigmas_tensor[-1].item(),
f"The last sigma in the schedule is not zero ({sigmas_tensor[-1].item():.6f}). "
f"For best compatibility with ComfyUI's RF sampler, the terminal sigma should be 0.0."
)

# Move to device once, then derive timesteps.
Expand All @@ -256,10 +271,8 @@ def set_timesteps(

if num_inference_steps is not None and num_inference_steps != len(sigmas) - 1:
logger.warning(
"Provided `num_inference_steps=%d` does not match `len(sigmas)-1=%d`. "
"Overriding `num_inference_steps` with `len(sigmas)-1`.",
num_inference_steps,
len(sigmas) - 1,
f"Provided `num_inference_steps={num_inference_steps}` does not match `len(sigmas)-1={len(sigmas) - 1}`. "
f"Overriding `num_inference_steps` with `len(sigmas)-1`."
)

self.num_inference_steps = len(sigmas) - 1
Expand Down Expand Up @@ -345,6 +358,20 @@ def step(
downstep_ratio = 1.0 + (sigma_next / sigma - 1.0) * eta
sigma_down = sigma_next * downstep_ratio

# sigma_down can go negative when eta > 1 on a coarse schedule step, which
# flips sigma_ratio and corrupts the Euler update. Clamp to [0, +inf) and
# emit a one-time warning so the user knows to reduce eta or refine the schedule.
# (sigma_down > sigma_next is not reachable under a valid monotone schedule.)
if sigma_down.item() < 0:
if not self._sigma_down_warned:
logger.warning(
f"`eta`={eta:.3f} caused `sigma_down`={sigma_down.item():.6f} to go negative "
f"(sigma={sigma.item():.6f}, sigma_next={sigma_next.item():.6f}). "
f"Clamping to 0. Reduce `eta` or use a finer schedule to avoid this."
)
self._sigma_down_warned = True
sigma_down = sigma_down.clamp(min=0.0)

alpha_ip1 = 1.0 - sigma_next
alpha_down = 1.0 - sigma_down

Expand Down
222 changes: 222 additions & 0 deletions tests/schedulers/test_scheduler_ltx_euler_ancestral_rf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,222 @@
# Copyright 2025 Vittoria Lanzo and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import torch

from diffusers import LTXEulerAncestralRFScheduler


def _make_scheduler(**kwargs):
    """Build an `LTXEulerAncestralRFScheduler` from default config values, with overrides."""
    params = {"num_train_timesteps": 1000, "eta": 1.0, "s_noise": 1.0, **kwargs}
    return LTXEulerAncestralRFScheduler(**params)


def _linear_sigmas(n=4):
"""Return a monotonically decreasing sigma schedule with terminal 0."""
return [round(1.0 - i / n, 6) for i in range(n + 1)]


class LTXEulerAncestralRFSchedulerTest(unittest.TestCase):
    """Unit tests for `LTXEulerAncestralRFScheduler`.

    Covers `set_timesteps` input validation (monotonicity, range, plateau schedules)
    and `step()` invariants (dtype/shape preservation, determinism at eta=0,
    terminal-step behavior, sigma_down clamping at eta>1, step-index bookkeeping).
    """

    # ------------------------------------------------------------------
    # set_timesteps: input validation
    # ------------------------------------------------------------------

    def test_set_timesteps_explicit_sigmas_valid(self):
        scheduler = _make_scheduler()
        scheduler.set_timesteps(sigmas=_linear_sigmas(4))
        self.assertEqual(scheduler.num_inference_steps, 4)
        self.assertEqual(len(scheduler.sigmas), 5)

    def test_set_timesteps_non_monotone_raises(self):
        """
        Non-monotonically-decreasing sigmas must raise ValueError.
        Without this check, step() computes sigma_down outside [0, 1]
        and sigma_ratio >> 1, silently amplifying the latent.
        """
        scheduler = _make_scheduler()
        # sigma increases at step 0 -> 1
        with self.assertRaises(ValueError):
            scheduler.set_timesteps(sigmas=[0.2, 0.8, 0.5, 0.0])

    def test_set_timesteps_fully_ascending_raises(self):
        scheduler = _make_scheduler()
        with self.assertRaises(ValueError):
            scheduler.set_timesteps(sigmas=[0.0, 0.5, 1.0])

    def test_set_timesteps_plateau_is_valid(self):
        """Equal consecutive sigmas (plateau steps) must NOT raise — used in img2img partial schedules."""
        scheduler = _make_scheduler()
        # plateau at the first two entries is intentional in some set_begin_index workflows
        scheduler.set_timesteps(sigmas=[1.0, 1.0, 0.5, 0.0])
        self.assertEqual(scheduler.num_inference_steps, 3)

    def test_set_timesteps_num_inference_steps_auto(self):
        """Auto-generated schedule (no explicit sigmas) must initialise correctly."""
        scheduler = _make_scheduler()
        scheduler.set_timesteps(num_inference_steps=10)
        self.assertEqual(scheduler.num_inference_steps, 10)
        self.assertEqual(len(scheduler.sigmas), 11)  # N steps + terminal 0
        # Verify the auto-generated schedule is itself monotone
        sigmas = scheduler.sigmas
        self.assertTrue(
            (sigmas[:-1] >= sigmas[1:]).all(),
            "Auto-generated sigma schedule is not monotonically non-increasing.",
        )

    # ------------------------------------------------------------------
    # step(): output invariants
    # ------------------------------------------------------------------

    def test_step_output_dtype_fp16_preserved(self):
        """prev_sample.dtype must equal sample.dtype for fp16 inputs."""
        scheduler = _make_scheduler()
        scheduler.set_timesteps(sigmas=_linear_sigmas(4))
        sample = torch.randn(1, 4, 8, 8, dtype=torch.float16)
        model_output = torch.randn_like(sample)
        out = scheduler.step(model_output, scheduler.timesteps[0], sample)
        self.assertEqual(out.prev_sample.dtype, torch.float16)

    def test_step_output_dtype_fp32_preserved(self):
        """prev_sample.dtype must equal sample.dtype for fp32 inputs."""
        scheduler = _make_scheduler()
        scheduler.set_timesteps(sigmas=_linear_sigmas(4))
        sample = torch.randn(1, 4, 8, 8, dtype=torch.float32)
        model_output = torch.randn_like(sample)
        out = scheduler.step(model_output, scheduler.timesteps[0], sample)
        self.assertEqual(out.prev_sample.dtype, torch.float32)

    def test_step_output_shape_preserved(self):
        """prev_sample.shape must equal sample.shape."""
        scheduler = _make_scheduler()
        scheduler.set_timesteps(sigmas=_linear_sigmas(4))
        sample = torch.randn(2, 4, 16, 16)
        model_output = torch.randn_like(sample)
        out = scheduler.step(model_output, scheduler.timesteps[0], sample)
        self.assertEqual(out.prev_sample.shape, sample.shape)

    def test_step_return_tuple(self):
        """return_dict=False must return a tuple whose first element matches return_dict=True."""
        scheduler = _make_scheduler()
        scheduler.set_timesteps(sigmas=_linear_sigmas(4))
        sample = torch.randn(1, 4, 8, 8)
        model_output = torch.randn_like(sample)
        t = scheduler.timesteps[0]

        torch.manual_seed(0)
        out_dict = scheduler.step(model_output, t, sample, return_dict=True)
        scheduler._step_index = None  # reset step index to replay the same step
        torch.manual_seed(0)
        out_tuple = scheduler.step(model_output, t, sample, return_dict=False)

        self.assertIsInstance(out_tuple, tuple)
        self.assertTrue(torch.allclose(out_dict.prev_sample, out_tuple[0]))

    def test_step_eta_zero_is_deterministic(self):
        """
        With eta=0 no noise is injected; the output must be identical regardless
        of the generator seed passed.
        """
        scheduler = _make_scheduler(eta=0.0)
        scheduler.set_timesteps(sigmas=_linear_sigmas(4))
        sample = torch.randn(1, 4, 8, 8, generator=torch.Generator().manual_seed(0))
        model_output = torch.randn(1, 4, 8, 8, generator=torch.Generator().manual_seed(1))
        t = scheduler.timesteps[0]

        out1 = scheduler.step(model_output, t, sample).prev_sample

        scheduler._step_index = None
        out2 = scheduler.step(
            model_output, t, sample, generator=torch.Generator().manual_seed(99)
        ).prev_sample

        self.assertTrue(torch.allclose(out1, out2), "eta=0 step should be fully deterministic.")

    def test_step_final_step_returns_denoised(self):
        """At sigma=0 (final denoising step) prev_sample must equal the denoised estimate."""
        scheduler = _make_scheduler(eta=1.0)
        # Two-step schedule: [0.5, 0.0]
        scheduler.set_timesteps(sigmas=[0.5, 0.0])
        sample = torch.randn(1, 4, 8, 8)
        model_output = torch.randn_like(sample)

        # First (and only real) step
        out = scheduler.step(model_output, scheduler.timesteps[0], sample)
        # At sigma_next=0 the scheduler must return the clean estimate x0 = x_t - sigma*v_t
        expected = sample - 0.5 * model_output
        self.assertTrue(torch.allclose(out.prev_sample, expected, atol=1e-5))

    def test_set_timesteps_sigma_above_one_raises(self):
        """Sigmas outside [0, 1] violate the RF/CONST parameterization assumption."""
        scheduler = _make_scheduler()
        with self.assertRaises(ValueError):
            scheduler.set_timesteps(sigmas=[2.0, 1.0, 0.5, 0.0])

    def test_step_eta_negative_raises(self):
        """eta < 0 is invalid and must raise ValueError at construction time."""
        with self.assertRaises(ValueError):
            _make_scheduler(eta=-0.1)

    def test_step_eta_greater_than_one_clamps_sigma_down(self):
        """eta > 1 on a coarse schedule pushes sigma_down < 0; must clamp, warn once, and stay finite."""
        scheduler = _make_scheduler(eta=2.0)
        # Coarse schedule: large step size maximises the chance sigma_down goes negative
        scheduler.set_timesteps(sigmas=[0.5, 0.1, 0.0])
        sample = torch.randn(1, 4, 8, 8)
        model_output = torch.randn_like(sample)
        self.assertFalse(scheduler._sigma_down_warned)

        out = scheduler.step(model_output, scheduler.timesteps[0], sample)

        # Warning flag must be set (warning was emitted)
        self.assertTrue(scheduler._sigma_down_warned)
        # Output must be finite (clamp prevented NaN/Inf from negative sigma_down)
        self.assertTrue(torch.isfinite(out.prev_sample).all())

        # Second step: the dedup flag must REMAIN set (the scheduler only warns while
        # the flag is False, so a set flag implies no re-emission) and the output
        # must stay finite.
        out2 = scheduler.step(model_output, scheduler.timesteps[1], sample)
        self.assertTrue(scheduler._sigma_down_warned)
        self.assertTrue(torch.isfinite(out2.prev_sample).all())

    def test_step_index_advances(self):
        """_step_index must increment by 1 on each call."""
        scheduler = _make_scheduler()
        scheduler.set_timesteps(sigmas=_linear_sigmas(4))
        sample = torch.randn(1, 4, 8, 8)
        model_output = torch.randn_like(sample)

        for expected_idx in range(4):
            scheduler.step(model_output, scheduler.timesteps[expected_idx], sample)
            self.assertEqual(scheduler._step_index, expected_idx + 1)

    def test_step_beyond_end_returns_sample(self):
        """Calling step() past the last index must not crash and must return a finite tensor."""
        scheduler = _make_scheduler(eta=0.0)
        scheduler.set_timesteps(sigmas=[0.5, 0.0])
        sample = torch.randn(1, 4, 8, 8)
        model_output = torch.randn_like(sample)

        # Consume all steps normally
        scheduler.step(model_output, scheduler.timesteps[0], sample)
        # Force _step_index to the clamped maximum
        scheduler._step_index = len(scheduler.sigmas) - 1
        # A further call must not crash and must return a finite tensor
        out = scheduler.step(model_output, scheduler.timesteps[-1], sample)
        self.assertTrue(torch.isfinite(out.prev_sample).all())


# Allow running this test module directly (without a pytest/unittest runner invocation).
if __name__ == "__main__":
    unittest.main()