Commit 1626f4c

ChuxiJ and claude committed

Address PR #13095 review: migrate to FlowMatchEulerDiscreteScheduler
Replace the hand-rolled flow-matching Euler loop with `FlowMatchEulerDiscreteScheduler`. ACE-Step still computes its own shifted / turbo sigma schedule via `_get_timestep_schedule`, but now passes it to `scheduler.set_timesteps(sigmas=...)` and delegates the ODE step to `scheduler.step()`.

The scheduler is configured with `num_train_timesteps=1` and `shift=1.0` so that `scheduler.timesteps` stays in `[0, 1]` (the convention the DiT was trained on) and the scheduler doesn't re-shift already-shifted sigmas. The scheduler's appended terminal `sigma=0` reproduces the old loop's final-step "project to x0" case exactly: `prev = x + (0 - t_curr) * v`.

Parity on jieyue (seed=42, bf16 + flash-attn, turbo text2music, 8 steps):

- waveform Pearson = 0.999999
- spectral Pearson = 1.000000
- max |diff| = 2.5e-3 (fp32 step math vs. previous bf16 step math)

An fp32 Euler-loop A/B against the hand-rolled path gives max |diff| = 3.6e-7.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent b46370a commit 1626f4c
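To make the scheduler configuration concrete, here is a minimal standalone sketch of the behavior the commit relies on (not part of the diff; the four sigma values are made up for illustration, not the real turbo schedule):

    import torch
    from diffusers import FlowMatchEulerDiscreteScheduler

    scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1, shift=1.0)

    # Stand-in for the already-shifted sigmas from `_get_timestep_schedule`.
    scheduler.set_timesteps(sigmas=[1.0, 0.75, 0.5, 0.25])

    # timesteps == sigmas, because timesteps = sigmas * num_train_timesteps = sigmas * 1,
    # and shift=1.0 makes the scheduler's internal re-shift a no-op.
    print(scheduler.timesteps)  # tensor([1.0000, 0.7500, 0.5000, 0.2500])

    # A terminal sigma=0 is appended, so the last Euler step lands exactly on x0.
    print(scheduler.sigmas)     # tensor([1.0000, 0.7500, 0.5000, 0.2500, 0.0000])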

2 files changed: 34 additions & 16 deletions

scripts/convert_ace_step_to_diffusers.py

Lines changed: 12 additions & 0 deletions
@@ -329,6 +329,18 @@ def _rename_attn_keys(key: str) -> str:
     shutil.copy2(silence_latent_src, os.path.join(output_dir, "silence_latent.pt"))
     print(f"Baked silence_latent into condition_encoder + kept raw copy at {output_dir}/silence_latent.pt")
 
+    # Save scheduler config. ACE-Step drives the DiT with t ∈ [0, 1] and computes its own
+    # shifted / turbo sigma schedule, which it passes to
+    # `scheduler.set_timesteps(sigmas=...)` at sampling time. So the scheduler itself
+    # needs `num_train_timesteps=1` (so `scheduler.timesteps == sigmas`) and `shift=1.0`
+    # (so it doesn't re-shift already-shifted sigmas). All other defaults are fine.
+    from diffusers import FlowMatchEulerDiscreteScheduler
+
+    scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1, shift=1.0)
+    scheduler_output_dir = os.path.join(output_dir, "scheduler")
+    scheduler.save_pretrained(scheduler_output_dir)
+    print(f"Saved scheduler config -> {scheduler_output_dir}")
+
     # Report other keys that were not saved to transformer or condition_encoder
     if other_sd:
         print(f"\nNote: {len(other_sd)} keys were dropped (tokenizer / detokenizer weights):")

src/diffusers/pipelines/ace_step/pipeline_ace_step.py

Lines changed: 22 additions & 16 deletions
@@ -22,6 +22,7 @@
 
 from ...models import AutoencoderOobleck
 from ...models.transformers.ace_step_transformer import AceStepTransformer1DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
 from ...utils import logging, replace_example_docstring
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline
@@ -233,6 +234,11 @@ class AceStepPipeline(DiffusionPipeline):
             The Diffusion Transformer (DiT) model for denoising audio latents.
         condition_encoder ([`AceStepConditionEncoder`]):
             Condition encoder that combines text, lyric, and timbre embeddings for cross-attention.
+        scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+            Flow-matching Euler scheduler. ACE-Step feeds the DiT timesteps in `[0, 1]`, so the
+            scheduler is configured with `num_train_timesteps=1` and `shift=1.0` — the pipeline
+            computes its shifted / turbo sigma schedule itself and passes it via
+            `set_timesteps(sigmas=...)`.
     """
 
     model_cpu_offload_seq = "text_encoder->condition_encoder->transformer->vae"
@@ -244,6 +250,7 @@ def __init__(
         tokenizer: PreTrainedTokenizerFast,
         transformer: AceStepTransformer1DModel,
         condition_encoder: AceStepConditionEncoder,
+        scheduler: FlowMatchEulerDiscreteScheduler,
     ):
         super().__init__()
 
@@ -253,6 +260,7 @@ def __init__(
             tokenizer=tokenizer,
             transformer=transformer,
             condition_encoder=condition_encoder,
+            scheduler=scheduler,
         )
 
         # ACE-Step is designed for variable-length audio up to 10 minutes. Enable
@@ -1306,23 +1314,27 @@ def __call__(
             device=encoder_hidden_states.device, dtype=encoder_hidden_states.dtype
         ).expand_as(encoder_hidden_states)
 
-        # 9. Get timestep schedule
+        # 9. Configure scheduler with ACE-Step's custom sigma schedule. `_get_timestep_schedule`
+        # already returns the shifted / turbo sigmas in `[0, 1]`; the scheduler was
+        # registered with `num_train_timesteps=1` and `shift=1.0` so it consumes them
+        # verbatim (and appends the terminal 0 used on the final Euler step).
         t_schedule = self._get_timestep_schedule(
             num_inference_steps=num_inference_steps,
             shift=shift,
             device=device,
-            dtype=dtype,
+            dtype=torch.float32,
             timesteps=timesteps,
         )
-        num_steps = len(t_schedule)
+        self.scheduler.set_timesteps(sigmas=t_schedule.tolist(), device=device)
+        num_steps = len(self.scheduler.timesteps)
 
         # 10. Denoising loop (flow matching ODE)
         xt = latents
         # APG momentum is stateful across steps, so instantiate once before the loop.
         momentum_buffer = _APGMomentumBuffer() if do_cfg else None
         with self.progress_bar(total=num_steps) as progress_bar:
-            for step_idx in range(num_steps):
-                current_timestep = t_schedule[step_idx].item()
+            for step_idx, t_sched in enumerate(self.scheduler.timesteps):
+                current_timestep = float(t_sched)
                 t_curr_tensor = current_timestep * torch.ones((batch_size,), device=device, dtype=dtype)
 
                 # Determine if CFG should be applied at this timestep
@@ -1384,17 +1396,11 @@ def __call__(
                 # Blend: strength * cover_vt + (1 - strength) * text2music_vt
                 vt = audio_cover_strength * vt + (1.0 - audio_cover_strength) * vt_nc
 
-                # On final step, directly compute x0
-                if step_idx == num_steps - 1:
-                    xt = xt - vt * t_curr_tensor.unsqueeze(-1).unsqueeze(-1)
-                    progress_bar.update()
-                    break
-
-                # Euler ODE step: x_{t-1} = x_t - v_t * dt
-                next_timestep = t_schedule[step_idx + 1].item()
-                dt = current_timestep - next_timestep
-                dt_tensor = dt * torch.ones((batch_size,), device=device, dtype=dtype).unsqueeze(-1).unsqueeze(-1)
-                xt = xt - vt * dt_tensor
+                # Euler ODE step via the scheduler. The scheduler appends a terminal
+                # sigma=0, so on the last step `dt = 0 - t_curr = -t_curr` and
+                # `prev = x + dt * v = x - t_curr * v` — the "project to x0" step the
+                # hand-rolled loop did as a special case.
+                xt = self.scheduler.step(vt, t_sched, xt, return_dict=False)[0]
 
                 progress_bar.update()
 
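To see the final-step equivalence claimed in the comment above in isolation, a small fp32 A/B sketch with dummy tensors (the two-point sigma schedule is made up for brevity):

    import torch
    from diffusers import FlowMatchEulerDiscreteScheduler

    scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1, shift=1.0)
    scheduler.set_timesteps(sigmas=[1.0, 0.5])  # internal sigmas: [1.0, 0.5, 0.0]

    xt = torch.randn(1, 8, 16)
    vt = torch.randn(1, 8, 16)

    # Old loop's final-step special case at t_curr = 0.5: project to x0.
    by_hand = xt - 0.5 * vt

    # step() locates t = 0.5 in the schedule and pairs it with the appended
    # terminal sigma = 0, so prev = x + (0 - 0.5) * v, i.e. the same update.
    by_scheduler = scheduler.step(vt, scheduler.timesteps[-1], xt, return_dict=False)[0]

    print(torch.allclose(by_hand, by_scheduler))  # True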
