Commit 25bbf32

Authored by kappacommit, with Your Name, claude, and JPPhoto
feat(model): Add ER SDE / DPM++ 2M Scheduler Support For Anima (invoke-ai#9125)
* refactor(anima): reshape ANIMA_SCHEDULER_MAP to (class, kwargs) tuples
  Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
* refactor(anima): address Task 1 review feedback
  Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
* feat(anima): add dpmpp_2m and dpmpp_2m_sde schedulers
* refactor(anima): unify ANIMA_SHIFT in schedulers.py and add Literal-coverage test
* fix(anima): seed generator into scheduler.step for SDE reproducibility
* feat(anima): add pure ancestral-Euler step helper
  Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
* docs(anima): fix _anima_euler_ancestral_step docstring formula to match code
  Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
* refactor(anima): address Task 4 review feedback
* feat(anima): add euler_a (rectified-flow ancestral Euler) scheduler
* fix(anima): sample euler_a noise in float32 to avoid bfloat16 quantization
* chore(anima): bump anima_denoise to v1.3.0 and regen schema
* fix(frontend): revert Windows path-separator drift in schema regen
* fix(anima): gate step_generator construction to schedulers that need it
  Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
* chore(anima): apply ruff lint and format fixes
  Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
* feat(frontend): expose new Anima schedulers in dropdown and metadata recall
* chore(frontend): apply prettier wrap to setAnimaScheduler PayloadAction
* fix(anima): correct euler_a math — variance-preserving noise mix, not biased Euler
* revert(anima): remove euler_a scheduler — quality not worth the complexity
* chore(anima): apply ruff format (trim trailing blank lines)
* feat(rectified-flow): add order-1 ER-SDE stepper for rectified flow
  Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
* docs(rectified-flow): document lambda_next>0 precondition; tighten terminal-step test

  Address code review feedback on order-1 ER-SDE stepper:
  - Add docstring preconditions to integral helpers noting the logarithmic singularity at lam=0 (callers must guard sigma_next>0).
  - Tighten terminal-step test from atol=1e-5 to torch.equal — the algebra is exact when sigma_next=0, not approximate.
* feat(rectified-flow): add 2nd-order Taylor extension to ER-SDE stepper
* test(rectified-flow): tighten Task 2 state-mutation test with value assertion

  Address code review feedback on the 2nd-order Taylor extension tests:
  - Assert state.old_d_x0 equals the analytically-computed d_x0 = 0.2v / (1.5 - 4.0) rather than just checking it's non-None.
  - Document that x_t is intentionally re-used across calls (state threading test, not trajectory correctness).
  - Document the order-2 correction coefficient magnitude that justifies the atol=1e-3 threshold in the engagement test.
* feat(rectified-flow): add 3rd-order Taylor extension to ER-SDE stepper
* docs(rectified-flow): document have_two_back invariant + order-3 test margin

  Address code review feedback on Task 3:
  - Comment why have_two_back checks both old_d_x0 and sigma_prev_prev (the sigma~=1 boundary path can break the joint invariant).
  - Document the analytically verified ~0.0004 per-element correction magnitude that justifies the atol=1e-3 threshold in the order-3 test.
* feat(anima): register er_sde scheduler choice
* docs(anima): document custom-code-path scheduler convention; tighten test

  Address code review feedback on Task 4:
  - Add an in-file comment above ANIMA_SCHEDULER_MAP explaining the convention: schedulers with custom code paths (er_sde) live in the Literal+labels only, not the map.
  - Hoist `import typing` to module-level in test_anima_schedulers.py (was inline-imported in two test functions).
  - Pin the er_sde label value (== "ER-SDE"), not just key existence.
* feat(anima): wire er_sde scheduler into denoise loop
* docs(anima): document float32 noise dtype and sigma_next/sigma_prev naming

  Address code review feedback on Task 5:
  - Comment explaining why fresh_noise is float32 (matches er_sde_rf_step's dtype contract with x_t.to(float32)).
  - Bridging comment at the inpaint extension call clarifying that sigma_next here means the same thing as sigma_prev in the Euler branch and the AnimaInpaintExtension API.
* chore(anima): bump anima_denoise to v1.4.0 and regen schema
* feat(frontend): expose er_sde scheduler in dropdown and metadata recall

  Address code review on Task 6 — er_sde was registered in the OpenAPI schema but missing from the frontend's own Zod enums and Redux PayloadAction types, so:
  - The combobox dropdown didn't list it.
  - setAnimaScheduler('er_sde') would fail TypeScript at the call site.
  - Metadata recall for er_sde-generated images would silently no-op (the scheduler value couldn't pass zParameterScheduler validation).

  Changes:
  - Add er_sde to zAnimaSchedulerField (the per-Anima Zod enum).
  - Add er_sde to the animaScheduler state-shape Zod enum.
  - Widen setAnimaScheduler's PayloadAction union.
  - Add ER-SDE option to the ParamAnimaScheduler combobox.
  - Make the metadata Scheduler handler accept ParameterAnimaScheduler too, with a fallback parse and a narrowing guard before the SD/SDXL dispatch.
* fix(rectified-flow): guard 2nd-order branch against sigma_prev_curr=1.0

  When step 0 goes through the sigma_curr=1 closed-form limit branch it writes state.sigma_prev_curr=1.0. On step 1, have_one_back was True and the 2nd-order path called _lambda(1.0) = 1.0/(1.0-1.0), crashing with ZeroDivisionError in every real denoise run.

  Fix: extend the have_one_back guard to require that sigma_prev_curr is more than _SIGMA_ONE_TOLERANCE below 1.0. The finite-difference derivative across the limit step is not meaningful, so skipping the 2nd-order term on that transition is correct. Order-3 is already gated behind old_d_x0 being set, which this path never sets, so no additional guard is needed there.

  Adds a regression test that runs the full sigma=1.0->0.9->0.7->0.0 sequence and asserts no ZeroDivisionError and all-finite output. The driver now flags it via sigma_prev==0, and the <= total_steps clamp that papered over the off-by-one is gone.
  Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
* feat(scheduler): add ERSDEScheduler available to SD/SDXL

  ER-SDE solver (Cui et al., arXiv:2309.06169) usable across SD/SDXL (VP-SDE) and rectified-flow models. Anima migration follows in subsequent commits.
  - ERSDEScheduler(SchedulerMixin, ConfigMixin) with prediction_type (epsilon | v_prediction | flow_prediction), use_flow_sigmas, solver_order (1/2/3 with auto-warmup), and stochastic toggle
  - set_timesteps(sigmas=) for pre-shifted Anima/FLUX/Z-Image schedules
  - Closed-form limit at sigma=1 in flow mode
  - Unit tests + VP smoke + 5/5 Anima parity vs er_sde_rf_step (worst delta 5.137e-07)
  - Frontend wiring: zSchedulerField, SCHEDULER_OPTIONS, OpenAPI regen
  - parsing.tsx cleanup: removes the AnyScheduler widening since er_sde is now a first-class general scheduler
  Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
* feat(anima): wire ERSDEScheduler into ANIMA_SCHEDULER_MAP

  Adds er_sde to the standard scheduler dispatch map with rectified-flow kwargs (flow_prediction, use_flow_sigmas=True, flow_shift=3.0, solver_order=3, stochastic=True). Anima still routes through the legacy elif is_er_sde: branch — that's removed in a follow-up commit. This is the additive prerequisite that lets the cutover happen without a window where Anima can't use ER-SDE.
  Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
* test(anima): add ER-SDE dispatch integration tests

  Verifies the ANIMA_SCHEDULER_MAP['er_sde'] entry instantiates correctly, accepts pre-shifted sigmas via set_timesteps(sigmas=...), and resets multistep state. Catches wiring regressions that the algorithm-level parity test cannot.
  Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
* refactor(anima): remove elif is_er_sde branch, dispatch through ANIMA_SCHEDULER_MAP

  Anima ER-SDE now flows through the same standard scheduler path as dpmpp_2m_sde — pre-shifted sigmas via scheduler.set_timesteps(sigmas=...), inpaint extension via inpaint_extension.merge_intermediate_latents_with_init_latents, step_callback per-step. The custom code path was the only thing keeping ER-SDE off the universal pipeline. Bumps invocation version to 1.5.0.
  Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
* docs(er_sde): mark module as internal reference and parity oracle

  ERSDEScheduler is now the production code path. er_sde_rf_step is kept as the comparison oracle for the scheduler's parity test, and as a self-contained mathematical reference for the rectified-flow algebra.
  Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
* chore(rectified-flow): remove er_sde.py reference helper

  ERSDEScheduler is the production code path. The pure-function helper was retained as a parity oracle but YAGNI — keeping ~200 lines of code purely as a regression net for hypothetical future drift isn't worth the maintenance signal it generates.

  Removes:
  - invokeai/backend/rectified_flow/er_sde.py
  - tests/backend/rectified_flow/test_er_sde.py
  - tests/backend/rectified_flow/test_er_sde_scheduler_anima_parity.py

  ERSDEScheduler's own tests (test_er_sde_scheduler.py) remain — they exercise both VP-SDE and rectified-flow paths directly.
  Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
* fix(schema): restore forward slashes in @default cache dir paths

  Windows-side typegen run flipped these to backslashes. Restore the canonical forward-slash form to match upstream.
  Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
* style: apply ruff lint and format to ER-SDE files

  Sort imports + format per project ruff config (line-length 120).
  Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
* refactor(state): use zParameterAnimaScheduler in state shape

  The animaScheduler field inlined its enum (originally to add er_sde when the shared schema didn't have it yet). Now that zAnimaSchedulerField already includes er_sde, reference the shared zParameterAnimaScheduler to match the pattern used by scheduler/fluxScheduler/zImageScheduler. Drops the redundant .default('euler') — initial value comes from getInitialParamsState.
  Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
* style(types): sort imports per simple-import-sort
  Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
* fix(anima): honour clipped sigma schedule for DPM++ img2img/inpaint

  DPMSolverMultistepScheduler doesn't accept sigmas= in diffusers 0.35.1, so the fallback previously called set_timesteps(num_inference_steps=total_steps), which regenerated a full schedule from sigma_max, ignoring denoising_start/end. When the scheduler supports set_begin_index, call set_timesteps with the full step count and offset into it, so the internal flow_shift applies correctly and denoising starts at the right sigma.

  Also fixes the inpaint sigma_prev lookup and the timestep loop to use the same offset, and corrects the false parity-test reference in the ER-SDE dispatch test docstring.
  Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
* style: apply ruff format to anima_denoise dispatch block
  Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
* refactor(anima): extract scheduler driver and fix Heun progress/inpaint bugs

  Encapsulate per-scheduler dispatch quirks (sigmas= vs num_inference_steps=, Heun's doubled-array index, set_begin_index path) in AnimaSchedulerDriver and fix two latent bugs in the Heun path of anima_denoise:
  * Heun's terminal first-order step never reported a user-step completion, so progress capped at N-1 of N. The driver now flags it via sigma_prev==0, and the <= total_steps clamp that papered over the off-by-one is gone.
  * The inpaint mix ran after every Heun half-step, corrupting the second-order corrector's input (RectifiedFlowInpaintExtension's docstring says it should be called after each denoising step — i.e. once per user step). Mix is now gated on completes_user_step, which is unconditionally True for non-Heun.

  Also: Heun shift kwarg switched to ANIMA_SHIFT (its set_timesteps doesn't accept sigmas=, so it builds its own internal schedule); narrative comments in scheduler_driver and er_sde_scheduler trimmed; new tests covering driver iteration counts, terminal sigma_prev, seed determinism, and the begin_index fallback for clipped DPM++ schedules.
  Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
* style: apply ruff lint and format to anima scheduler driver
  Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
* ci: re-run after flaky token-expiration test

  The 1-second JWT token-expiration test in test_token_service.py is timing sensitive — passes locally on retry. Empty commit to retrigger CI.
  Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

---------

Co-authored-by: Your Name <you@example.com>
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-authored-by: Jonathan <34005131+JPPhoto@users.noreply.github.com>
1 parent fcc0881 commit 25bbf32

17 files changed

Lines changed: 1765 additions & 93 deletions

File tree

invokeai/app/invocations/anima_denoise.py

Lines changed: 41 additions & 68 deletions
@@ -16,14 +16,12 @@
 - Anima uses 3D latents directly, Z-Image converts 4D -> list of 5D
 """
 
-import inspect
 import math
 from contextlib import ExitStack
 from typing import Callable, Iterator, Optional, Tuple
 
 import torch
 import torchvision.transforms as tv_transforms
-from diffusers.schedulers.scheduling_utils import SchedulerMixin
 from torchvision.transforms.functional import resize as tv_resize
 from tqdm import tqdm
 
@@ -42,7 +40,12 @@
 from invokeai.backend.anima.anima_transformer_patch import patch_anima_for_regional_prompting
 from invokeai.backend.anima.conditioning_data import AnimaRegionalTextConditioning, AnimaTextConditioning
 from invokeai.backend.anima.regional_prompting import AnimaRegionalPromptingExtension
-from invokeai.backend.flux.schedulers import ANIMA_SCHEDULER_LABELS, ANIMA_SCHEDULER_MAP, ANIMA_SCHEDULER_NAME_VALUES
+from invokeai.backend.anima.scheduler_driver import AnimaSchedulerDriver
+from invokeai.backend.flux.schedulers import (
+    ANIMA_SCHEDULER_LABELS,
+    ANIMA_SCHEDULER_NAME_VALUES,
+    ANIMA_SHIFT,
+)
 from invokeai.backend.model_manager.taxonomy import BaseModelType
 from invokeai.backend.patches.layer_patcher import LayerPatcher
 from invokeai.backend.patches.lora_conversions.anima_lora_constants import ANIMA_LORA_TRANSFORMER_PREFIX
@@ -59,8 +62,6 @@
 ANIMA_LATENT_SCALE_FACTOR = 8
 # Anima uses 16 latent channels
 ANIMA_LATENT_CHANNELS = 16
-# Anima uses fixed shift=3.0 for the rectified flow schedule
-ANIMA_SHIFT = 3.0
 # Anima uses raw sigma values as timesteps (no rescaling)
 ANIMA_MULTIPLIER = 1.0
 
@@ -165,7 +166,7 @@ def merge_intermediate_latents_with_init_latents(
     title="Denoise - Anima",
     tags=["image", "anima"],
     category="image",
-    version="1.2.0",
+    version="1.5.0",
     classification=Classification.Prototype,
 )
 class AnimaDenoiseInvocation(BaseInvocation):
@@ -491,22 +492,19 @@ def _run_diffusion(self, context: InvocationContext) -> torch.Tensor:
 
         step_callback = self._build_step_callback(context)
 
-        # Initialize diffusers scheduler if not using built-in Euler
-        scheduler: SchedulerMixin | None = None
+        # Initialize scheduler driver if not using built-in Euler.
         use_scheduler = self.scheduler != "euler"
-
+        driver: AnimaSchedulerDriver | None = None
         if use_scheduler:
-            scheduler_class = ANIMA_SCHEDULER_MAP[self.scheduler]
-            scheduler = scheduler_class(num_train_timesteps=1000, shift=1.0)
-            is_lcm = self.scheduler == "lcm"
-            set_timesteps_sig = inspect.signature(scheduler.set_timesteps)
-            if not is_lcm and "sigmas" in set_timesteps_sig.parameters:
-                scheduler.set_timesteps(sigmas=sigmas, device=device)
-            else:
-                scheduler.set_timesteps(num_inference_steps=total_steps, device=device)
-            num_scheduler_steps = len(scheduler.timesteps)
-        else:
-            num_scheduler_steps = total_steps
+            driver = AnimaSchedulerDriver(
+                scheduler_name=self.scheduler,
+                sigmas=sigmas,
+                steps=self.steps,
+                denoising_start=self.denoising_start,
+                denoising_end=self.denoising_end,
+                device=device,
+                seed=self.seed,
+            )
 
         with ExitStack() as exit_stack:
             (cached_weights, transformer) = exit_stack.enter_context(transformer_info.model_on_device())
@@ -587,19 +585,12 @@ def _run_transformer(ctx: torch.Tensor, x: torch.Tensor, t: torch.Tensor) -> tor
                     # t5xxl_ids=None skips the LLM Adapter — context is already pre-computed
                 )
 
-            if use_scheduler and scheduler is not None:
-                # Scheduler-based denoising
+            if driver is not None:
                 user_step = 0
                 pbar = tqdm(total=total_steps, desc="Denoising (Anima)")
-                for step_index in range(num_scheduler_steps):
-                    sched_timestep = scheduler.timesteps[step_index]
-                    sigma_curr = sched_timestep.item() / scheduler.config.num_train_timesteps
-
-                    is_heun = hasattr(scheduler, "state_in_first_order")
-                    in_first_order = scheduler.state_in_first_order if is_heun else True
-
+                for it in driver.iterations():
                     timestep = torch.tensor(
-                        [sigma_curr * ANIMA_MULTIPLIER], device=device, dtype=inference_dtype
+                        [it.sigma_curr * ANIMA_MULTIPLIER], device=device, dtype=inference_dtype
                     ).expand(latents.shape[0])
 
                     noise_pred_cond = _run_transformer(pos_context, latents, timestep).float()
610601
else:
611602
noise_pred = noise_pred_cond
612603

613-
step_output = scheduler.step(model_output=noise_pred, timestep=sched_timestep, sample=latents)
614-
latents = step_output.prev_sample
615-
616-
if step_index + 1 < len(scheduler.sigmas):
617-
sigma_prev = scheduler.sigmas[step_index + 1].item()
618-
else:
619-
sigma_prev = 0.0
604+
latents = driver.step(model_output=noise_pred, timestep=it.sched_timestep, sample=latents)
620605

621-
if inpaint_extension is not None:
622-
latents_4d = latents.squeeze(2)
623-
latents_4d = inpaint_extension.merge_intermediate_latents_with_init_latents(
624-
latents_4d, sigma_prev
625-
)
626-
latents = latents_4d.unsqueeze(2)
606+
if it.completes_user_step:
607+
# RectifiedFlowInpaintExtension expects this once per user step (its
608+
# docstring), so for Heun we skip the FO half of each pair to avoid
609+
# corrupting the second-order corrector's input.
610+
if inpaint_extension is not None:
611+
latents_4d = latents.squeeze(2)
612+
latents_4d = inpaint_extension.merge_intermediate_latents_with_init_latents(
613+
latents_4d, it.sigma_prev
614+
)
615+
latents = latents_4d.unsqueeze(2)
627616

628-
if is_heun:
629-
if not in_first_order:
630-
user_step += 1
631-
if user_step <= total_steps:
632-
pbar.update(1)
633-
step_callback(
634-
PipelineIntermediateState(
635-
step=user_step,
636-
order=2,
637-
total_steps=total_steps,
638-
timestep=int(sigma_curr * 1000),
639-
latents=latents.squeeze(2),
640-
)
641-
)
642-
else:
643617
user_step += 1
644-
if user_step <= total_steps:
645-
pbar.update(1)
646-
step_callback(
647-
PipelineIntermediateState(
648-
step=user_step,
649-
order=1,
650-
total_steps=total_steps,
651-
timestep=int(sigma_curr * 1000),
652-
latents=latents.squeeze(2),
653-
)
618+
pbar.update(1)
619+
step_callback(
620+
PipelineIntermediateState(
621+
step=user_step,
622+
order=it.order,
623+
total_steps=total_steps,
624+
timestep=int(it.sigma_curr * 1000),
625+
latents=latents.squeeze(2),
654626
)
627+
)
655628
pbar.close()
656629
else:
657630
# Built-in Euler implementation (default for Anima)
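The `completes_user_step` rule the refactored loop relies on can be mirrored as a standalone predicate. This is a minimal sketch of the rule as the commit describes it (the function name and argument shapes here are illustrative, not the driver's API):

```python
def completes_user_step(is_heun: bool, in_first_order: bool, sigma_prev: float) -> bool:
    # Non-Heun schedulers finish a user-visible step on every iteration.
    # Heun finishes one only on the second-order half of each pair, except
    # for the unpaired terminal first-order step, detected by sigma_prev == 0.
    is_terminal = sigma_prev == 0.0
    return (not is_heun) or (not in_first_order) or is_terminal


# Heun, first-order half of an interior pair: no user step completed yet.
assert not completes_user_step(True, True, 0.5)
# Heun, second-order half: completes the user step.
assert completes_user_step(True, False, 0.5)
# Heun, terminal first-order step: completes it too, which is what fixes the
# progress bar that previously capped at N-1 of N.
assert completes_user_step(True, True, 0.0)
# Any non-Heun scheduler: always True.
assert completes_user_step(False, True, 0.5)
```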
invokeai/backend/anima/scheduler_driver.py

Lines changed: 150 additions & 0 deletions

@@ -0,0 +1,150 @@
+"""Anima scheduler driver.
+
+Encapsulates the per-scheduler API quirks that ``anima_denoise._run_diffusion``
+would otherwise have to know about:
+
+* Schedulers that accept ``set_timesteps(sigmas=...)`` get the pre-shifted
+  Anima schedule passed directly.
+* Schedulers that don't accept ``sigmas=`` use ``set_begin_index()`` over their
+  own internal flow-shifted schedule. For Heun, the doubled-array index
+  translation (logical step ``k`` → doubled index ``2k``) is handled here.
+* SDE-style schedulers receive a seeded ``torch.Generator`` on every step.
+
+The denoise loop iterates :meth:`AnimaSchedulerDriver.iterations` and calls
+:meth:`AnimaSchedulerDriver.step` per iteration; the driver yields the
+``sigma_prev`` and ``completes_user_step`` flags the caller needs for inpaint
+mixing and progress reporting.
+"""
+
+from __future__ import annotations
+
+import inspect
+from dataclasses import dataclass
+from typing import Iterator
+
+import torch
+from diffusers import FlowMatchHeunDiscreteScheduler
+from diffusers.schedulers.scheduling_utils import SchedulerMixin
+
+from invokeai.backend.flux.schedulers import ANIMA_SCHEDULER_MAP
+
+
+@dataclass(frozen=True)
+class AnimaSchedulerIteration:
+    """Per-iteration metadata yielded by :meth:`AnimaSchedulerDriver.iterations`.
+
+    ``sigma_prev`` is the noise level the latents will be at after this iteration's
+    :meth:`AnimaSchedulerDriver.step` call. ``completes_user_step`` is True when
+    this iteration finishes a user-visible step — for Heun, the second-order
+    half of each pair plus the unpaired terminal first-order step; for every
+    other scheduler, always True.
+    """
+
+    sched_timestep: torch.Tensor
+    sigma_curr: float
+    sigma_prev: float
+    completes_user_step: bool
+    order: int
+
+
+class AnimaSchedulerDriver:
+    """Drives a diffusers scheduler over Anima's pre-shifted sigma schedule."""
+
+    def __init__(
+        self,
+        scheduler_name: str,
+        sigmas: list[float],
+        steps: int,
+        denoising_start: float,
+        denoising_end: float,
+        device: torch.device,
+        seed: int,
+    ):
+        scheduler_class, scheduler_kwargs = ANIMA_SCHEDULER_MAP[scheduler_name]
+        self.scheduler: SchedulerMixin = scheduler_class(num_train_timesteps=1000, **scheduler_kwargs)
+        # Heun toggles state_in_first_order during step(); detect by class so we
+        # can read it before set_timesteps has run.
+        self.is_heun: bool = isinstance(self.scheduler, FlowMatchHeunDiscreteScheduler)
+        self._begin_index: int = 0
+        self._step_generator = torch.Generator(device=device).manual_seed(seed)
+
+        is_lcm = scheduler_name == "lcm"
+        accepts_sigmas = "sigmas" in inspect.signature(self.scheduler.set_timesteps).parameters
+        clipped = denoising_start > 0 or denoising_end < 1
+
+        if not is_lcm and accepts_sigmas:
+            self.scheduler.set_timesteps(sigmas=sigmas, device=device)
+            self._num_iterations = len(self.scheduler.timesteps)
+        elif not is_lcm and clipped and hasattr(self.scheduler, "set_begin_index"):
+            k_start = int(denoising_start * steps)
+            k_end = int(denoising_end * steps)
+            self.scheduler.set_timesteps(num_inference_steps=steps, device=device)
+            if self.is_heun:
+                # Heun's timesteps array is 2N-1 entries; logical step k maps to
+                # doubled index 2k. min() clamps denoising_end=1.0 to the
+                # unpaired terminal first-order step.
+                self._begin_index = 2 * k_start
+                self._num_iterations = min(
+                    2 * (k_end - k_start),
+                    len(self.scheduler.timesteps) - self._begin_index,
+                )
+            else:
+                self._begin_index = k_start
+                self._num_iterations = k_end - self._begin_index
+            self.scheduler.set_begin_index(self._begin_index)
+        else:
+            self.scheduler.set_timesteps(num_inference_steps=len(sigmas) - 1, device=device)
+            self._num_iterations = len(self.scheduler.timesteps)
+
+    @property
+    def num_iterations(self) -> int:
+        """Total :meth:`step` calls. For Heun this is roughly 2× the user-visible step count."""
+        return self._num_iterations
+
+    @property
+    def begin_index(self) -> int:
+        return self._begin_index
+
+    def iterations(self) -> Iterator[AnimaSchedulerIteration]:
+        for i in range(self._num_iterations):
+            sched_idx = i + self._begin_index
+            sched_timestep = self.scheduler.timesteps[sched_idx]
+            sigma_curr = sched_timestep.item() / self.scheduler.config.num_train_timesteps
+
+            # Read state_in_first_order before step (Heun toggles it inside step()).
+            in_first_order = self.scheduler.state_in_first_order if self.is_heun else True
+
+            next_idx = sched_idx + 1
+            sigma_prev = self.scheduler.sigmas[next_idx].item() if next_idx < len(self.scheduler.sigmas) else 0.0
+
+            # For Heun, a user step completes on the second-order half of each
+            # pair AND on the unpaired terminal first-order step (sigma_prev==0).
+            is_terminal = sigma_prev == 0.0
+            completes_user_step = (not self.is_heun) or (not in_first_order) or is_terminal
+            order = 2 if self.is_heun else 1
+
+            yield AnimaSchedulerIteration(
+                sched_timestep=sched_timestep,
+                sigma_curr=sigma_curr,
+                sigma_prev=sigma_prev,
+                completes_user_step=completes_user_step,
+                order=order,
+            )
+
+    def step(
+        self,
+        model_output: torch.Tensor,
+        timestep: torch.Tensor,
+        sample: torch.Tensor,
+    ) -> torch.Tensor:
+        step_output = self.scheduler.step(
+            model_output=model_output,
+            timestep=timestep,
+            sample=sample,
+            generator=self._step_generator,
+        )
+        return step_output.prev_sample
+
+    @property
+    def step_generator(self) -> torch.Generator:
+        return self._step_generator