Add ACE-Step pipeline for text-to-music generation #13095

Open

ChuxiJ wants to merge 32 commits into huggingface:main from ChuxiJ:add-ace-step-pipeline

Conversation

ChuxiJ commented Feb 7, 2026

What does this PR do?

This PR adds the ACE-Step 1.5 pipeline to Diffusers — a text-to-music generation model that produces high-quality stereo music with lyrics at 48kHz from text prompts.

New Components

  • AceStepDiTModel (src/diffusers/models/transformers/ace_step_transformer.py): A Diffusion Transformer (DiT) model with RoPE, GQA, sliding window attention, and flow matching for denoising audio latents. Includes custom components: AceStepRMSNorm, AceStepRotaryEmbedding, AceStepMLP, AceStepTimestepEmbedding, AceStepAttention, AceStepEncoderLayer, and AceStepDiTLayer.

  • AceStepConditionEncoder (src/diffusers/pipelines/ace_step/modeling_ace_step.py): Condition encoder that fuses text, lyric, and timbre embeddings into a unified cross-attention conditioning signal. Includes AceStepLyricEncoder and AceStepTimbreEncoder sub-modules.

  • AceStepPipeline (src/diffusers/pipelines/ace_step/pipeline_ace_step.py): The main pipeline supporting 6 task types:

    • text2music — generate music from text and lyrics
    • cover — generate from audio semantic codes or with timbre transfer via reference audio
    • repaint — regenerate a time region within existing audio
    • extract — extract a specific track (vocals, drums, etc.) from audio
    • lego — generate a specific track given audio context
    • complete — complete audio with additional tracks
  • Conversion script (scripts/convert_ace_step_to_diffusers.py): Converts original ACE-Step 1.5 checkpoint weights to Diffusers format.

Key Features

  • Multi-task support: 6 task types with automatic instruction routing via _get_task_instruction
  • Music metadata conditioning: Optional bpm, keyscale, timesignature parameters formatted into the SFT prompt template
  • Audio-to-audio tasks: Source audio (src_audio) and reference audio (reference_audio) inputs with VAE encoding
  • Tiled VAE encode/decode: Memory-efficient chunked encoding (_tiled_encode) and decoding (_tiled_decode) for long audio
  • Classifier-free guidance (CFG): Dual forward pass with configurable guidance_scale, cfg_interval_start, and cfg_interval_end (primarily for base/SFT models; turbo models have guidance distilled into weights)
  • Audio cover strength blending: Smooth interpolation between cover-conditioned and text-only-conditioned outputs via audio_cover_strength (see the sketch after this list)
  • Audio code parsing: _parse_audio_code_string extracts semantic codes from <|audio_code_N|> tokens for cover tasks
  • Chunk masking: _build_chunk_mask creates time-region masks for repaint/lego tasks
  • Anti-clipping normalization: Post-decode normalization to prevent audio clipping
  • Multi-language lyrics: 50+ languages including English, Chinese, Japanese, Korean, French, German, Spanish, etc.
  • Variable-length generation: Configurable duration from 10 seconds to 10+ minutes
  • Custom timestep schedules: Pre-defined shifted schedules for shift=1.0/2.0/3.0, or user-provided timesteps
  • Turbo model variant: Optimized for 8 inference steps with shift=3.0
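
As a quick illustration of the multi-task API described above, here is a sketch of a cover / timbre-transfer call. Only src_audio, reference_audio, and audio_cover_strength are named in this PR description; the task-selection kwarg and the audio tensor layout shown here are assumptions, and `pipe` is an AceStepPipeline loaded as in the Usage section below.

import torch

# Dummy 10-second stereo reference clip at 48 kHz (layout assumed: batch, channels, samples)
reference_audio = torch.randn(1, 2, 48_000 * 10)

audio = pipe(
    task="cover",                      # assumed task-selection kwarg, routed via _get_task_instruction
    prompt="Acoustic folk cover with fingerpicked guitar",
    reference_audio=reference_audio,   # timbre reference, encoded through the VAE
    audio_cover_strength=0.6,          # blends cover-conditioned and text-only-conditioned outputs
    num_inference_steps=8,
).audios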

Architecture

ACE-Step 1.5 comprises three main components:

  1. Oobleck autoencoder (VAE): Compresses 48kHz stereo waveforms into 25Hz latent representations (see the latent-rate sketch after this list)
  2. Qwen3-Embedding-0.6B text encoder: Encodes text prompts and lyrics for conditioning
  3. Diffusion Transformer (DiT): Denoises audio latents using flow matching with an Euler ODE solver
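
To make the VAE compression concrete, here is a small back-of-the-envelope sketch based only on the rates stated above (48 kHz waveform in, 25 Hz latent frames out); the latent channel count is not specified here.

SAMPLE_RATE = 48_000  # Hz, stereo waveform fed to the Oobleck VAE
LATENT_RATE = 25      # Hz, latent frames produced by the VAE

def num_latent_frames(audio_duration_s: float) -> int:
    # One latent frame per 1/25 s of audio
    return int(audio_duration_s * LATENT_RATE)

print(num_latent_frames(30.0))   # 750 frames for a 30-second clip
print(num_latent_frames(600.0))  # 15000 frames for a 10-minute clip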

Tests

  • Pipeline tests (tests/pipelines/ace_step/test_ace_step.py):
    • AceStepDiTModelTests — forward shape, return dict, gradient checkpointing
    • AceStepConditionEncoderTests — forward shape, save/load config
    • AceStepPipelineFastTests (extends PipelineTesterMixin) — 39 tests covering basic generation, batch processing, latent output, save/load, float16 inference, CPU/model offloading, encode_prompt, prepare_latents, timestep_schedule, format_prompt, and more
  • Model tests (tests/models/transformers/test_models_transformer_ace_step.py):
    • TestAceStepDiTModel (extends ModelTesterMixin) — forward pass, dtype inference, save/load, determinism
    • TestAceStepDiTModelMemory (extends MemoryTesterMixin) — layerwise casting, group offloading
    • TestAceStepDiTModelTraining (extends TrainingTesterMixin) — training, EMA, gradient checkpointing, mixed precision

All 70 tests pass (39 pipeline + 31 model).

Documentation

  • docs/source/en/api/pipelines/ace_step.md — Pipeline API documentation with usage examples
  • docs/source/en/api/models/ace_step_transformer.md — Transformer model documentation

Usage

import torch
import soundfile as sf
from diffusers import AceStepPipeline

pipe = AceStepPipeline.from_pretrained("ACE-Step/ACE-Step-v1-5-turbo", torch_dtype=torch.bfloat16)
pipe = pipe.to("cuda")

# Text-to-music generation
audio = pipe(
    prompt="A beautiful piano piece with soft melodies",
    lyrics="[verse]\nSoft notes in the morning light\n[chorus]\nMusic fills the air tonight",
    audio_duration=30.0,
    num_inference_steps=8,
    bpm=120,
    keyscale="C major",
).audios

sf.write("output.wav", audio[0, 0].cpu().numpy(), 48000)

Before submitting

Who can review?

References

@ChuxiJ ChuxiJ marked this pull request as draft February 7, 2026 11:38
@ChuxiJ ChuxiJ marked this pull request as ready for review February 7, 2026 14:29
@dg845 dg845 requested review from dg845 and yiyixuxu February 8, 2026 03:21
dg845 (Collaborator) commented Feb 9, 2026

Hi @ChuxiJ, thanks for the PR! As a preliminary comment, I tried the test script given above but got an error, which I think is because the ACE-Step/ACE-Step-v1-5-turbo repo doesn't currently exist on the HF Hub.

If I convert the checkpoint from a local snapshot of ACE-Step/Ace-Step1.5 at /path/to/acestep-v15-repo using

python scripts/convert_ace_step_to_diffusers.py \
    --checkpoint_dir /path/to/acestep-v15-repo \
    --dit_config acestep-v15-turbo \
    --output_dir /path/to/acestep-v15-diffusers \
    --dtype bf16

and then test it using the following script:

import torch
import soundfile as sf
from diffusers import AceStepPipeline

OUTPUT_SAMPLE_RATE = 48000
model_id = "/path/to/acestep-v15-diffusers"
device = "cuda"
dtype = torch.bfloat16
seed = 42

pipe = AceStepPipeline.from_pretrained(model_id, torch_dtype=dtype)
pipe = pipe.to(device)

generator = torch.Generator(device=device).manual_seed(seed)

# Text-to-music generation
audio = pipe(
    prompt="A beautiful piano piece with soft melodies",
    lyrics="[verse]\nSoft notes in the morning light\n[chorus]\nMusic fills the air tonight",
    audio_duration=30.0,
    num_inference_steps=8,
    bpm=120,
    keyscale="C major",
    generator=generator,
).audios

sf.write("acestep_t2m.wav", audio[0, 0].cpu().numpy(), OUTPUT_SAMPLE_RATE)

I get the following sample:

acestep_t2m.wav

The sample quality is lower than expected, so there is probably a bug. Could you look into it?

yiyixuxu (Collaborator) commented:
@ChuxiJ are you still working on this, or should we do a final review now?

yiyixuxu added a commit that referenced this pull request Apr 27, 2026
* [agents docs] add pipelines.md and restructure review rules

- Add .ai/pipelines.md: pipeline conventions and gotchas (config-derived
  values, no_grad discipline, reinventing scheduler logic, subclassing
  variants, # Copied from annotations).
- models.md: add Attention masks subsection inside Attention pattern;
  fold reference-implementations skim into conventions; consolidate
  __init__.py / _import_structure gotchas; trim gotchas covered by
  AGENTS.md (silent fallbacks, config serialization gap) or pipelines.md
  (no_grad, guider/scheduler reuse).
- review-rules.md: collapse to a short reviewer checklist that points
  into AGENTS / models / pipelines / modular gotchas; only LLM-specific
  pattern (ephemeral context) lives here directly.
- AGENTS.md: collapse defensive-code / unused-params / backwards-compat
  / deprecation rules into one umbrella bullet; replace inline pipeline
  bullet list with a pointer to pipelines.md.
- SKILL.md (model-integration): trim pre-PR self-review to a one-line
  pointer.

Sourced from the ACE-Step PR (#13095) review.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* Apply suggestions from code review

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Apply suggestion from @yiyixuxu

* Apply suggestions from code review

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* fix capability-flags gotcha: drop fake attrs, tighten to real failure modes

`_supports_quantization` and `_supports_cache_class` don't exist in diffusers
(sayak flagged the first; the second was also fabricated). Replaced with the
two flags where the "advertised but unbacked" pattern is a real mistake:
`_supports_gradient_checkpointing` (needs `if self.gradient_checkpointing:`
branches in forward) and `_no_split_modules` (needs correct block class names
for `device_map`). Dropped `_supports_group_offloading` — its realistic
failure mode is forgetting to opt out, not opt in.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

---------

Co-authored-by: yiyi@huggingface.co <yiyi@ip-26-0-160-103.ec2.internal>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
yiyixuxu (Collaborator) commented:
@claude
can you do a review here again? thanks

github-actions Bot (Contributor) commented Apr 27, 2026

Claude finished @yiyixuxu's task in 2m 26s.


Review of ACE-Step Pipeline PR

  • Read review guidelines (AGENTS.md, models.md, pipelines.md)
  • Read transformer model (ace_step_transformer.py)
  • Read condition encoder (modeling_ace_step.py)
  • Read pipeline (pipeline_ace_step.py)
  • Read VAE changes (autoencoder_oobleck.py)
  • Read tests
  • Read conversion script
  • Check prior review comments status
  • Write comprehensive review

Good progress since the last round. Many prior review comments have been addressed — the backward-compat aliases are gone, PR-process comments are removed, SHIFT_TIMESTEPS/VALID_SHIFTS lookup tables are eliminated, the scheduler is now integrated, parity scripts are cleaned up. Here's what remains:


Blockers

1. Missing _no_split_modules on AceStepTransformer1DModel

Per models.md: "Wrong _no_split_modules — these class attributes control device placement." Every other transformer in the repo defines this. Without it, device_map="auto" and accelerate integration won't split the model correctly.

ace_step_transformer.py:382 — add:

_supports_gradient_checkpointing = True
_no_split_modules = ["AceStepTransformerBlock"]

2. Nested torch.no_grad() inside methods called from __call__

Per pipelines.md: "Don't nest with torch.no_grad() inside helpers — the decorator covers all descendants, and nested context forecloses users calling encode_prompt with grads enabled."

Three locations still have inner torch.no_grad():

  • _encode_audio_to_latents (line 611): Prevents users from calling encode_audio() with gradients enabled (e.g., for fine-tuning/optimization).
  • _prepare_reference_audio_latents (line 664): Same issue.
  • _prepare_src_audio_and_latents (line 808): Same issue.

Remove the with torch.no_grad() from all three. The __call__ decorator already covers inference; users who call these helpers directly in a training context should be able to get gradients.
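
A sketch of the suggested change (the method body here is illustrative, not the exact pipeline code; the vae.encode(...).latent_dist.sample() call is taken from item 6 below):

# Before: the helper forces no_grad even when a caller wants gradients
def _encode_audio_to_latents(self, audio, device, dtype):
    with torch.no_grad():
        latents = self.vae.encode(audio.to(device=device, dtype=dtype)).latent_dist.sample()
    return latents

# After: rely on the @torch.no_grad() decorator on __call__ for inference;
# direct callers (e.g. fine-tuning loops) keep the option of gradients
def _encode_audio_to_latents(self, audio, device, dtype):
    latents = self.vae.encode(audio.to(device=device, dtype=dtype)).latent_dist.sample()
    return latents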

3. _variant_defaults() pattern — should follow Flux2 Klein distilled model pattern

The _variant_defaults() method (line 244) is unique to this pipeline. The Flux2 Klein pattern is cleaner and more standard:

  1. Warn when guidance_scale > 1.0 on a turbo/distilled model (already done at line 1035 ✅)
  2. Have do_classifier_free_guidance property check not self.is_turbo (not done)

Currently do_classifier_free_guidance (line 193) only checks gs > 1.0 — it doesn't check self.is_turbo. This means if guidance_scale gets set before the warning coerces it, the property is inconsistent. The property should be:

@property
def do_classifier_free_guidance(self) -> bool:
    gs = getattr(self, "_guidance_scale", 1.0)
    return gs is not None and gs > 1.0 and not self.is_turbo

For the default num_inference_steps and shift, just use inline defaults in __call__ parameters or fall back inside __call__:

if num_inference_steps is None:
    num_inference_steps = 8
if shift is None:
    shift = 1.0

This removes the _variant_defaults() method entirely.

4. No tests for audio-to-audio tasks

There are no tests for cover, repaint, extract, lego, or complete tasks. At minimum, add:

  • test_ace_step_cover — exercises reference_audio + timbre encoding path
  • test_ace_step_repaint — exercises src_audio, chunk_mask, and repaint window logic

These code paths are complex and untested — regressions will be silent.
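
A rough shape such a test could take, using the usual fast-test fixtures (get_dummy_components / get_dummy_inputs from PipelineTesterMixin-style suites, with torch, AceStepPipeline, and torch_device assumed to be imported in the test module); the task kwarg and tensor layout are assumptions:

def test_ace_step_cover(self):
    # Exercises the reference_audio / timbre encoding path (sketch only)
    components = self.get_dummy_components()
    pipe = AceStepPipeline(**components).to(torch_device)
    inputs = self.get_dummy_inputs(torch_device)
    inputs["task"] = "cover"                               # assumed task-selection kwarg
    inputs["reference_audio"] = torch.randn(1, 2, 48_000)  # 1 s of dummy stereo audio
    audio = pipe(**inputs).audios
    self.assertEqual(audio.shape[0], 1)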


Non-blockers (should fix before merge)

5. _silence_latent_tiled fallback to zeros (line 687–693)

if sl is None or sl.abs().sum() == 0:
    return torch.zeros(...)

This fallback should logger.warning(...) so users know they're getting degraded output. The previous review asked to remove an explicit warning, but that was about a different warning (line 833 in the old revision about the conversion script). A runtime degradation deserves a log message. Alternatively, just raise ValueError(...) since a missing silence_latent means the model was improperly converted.

Similarly, the fallback at line 1162–1169 (zeros for missing timbre silence latent) should warn.
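
A fragment-level sketch of surfacing the fallback (variable names follow the excerpt above; the zeros call is left elided as in the original):

if sl is None or sl.abs().sum() == 0:
    logger.warning(
        "silence_latent is missing or all-zero; falling back to a zero latent. "
        "Output may be degraded; check that the checkpoint was converted correctly."
    )
    return torch.zeros(...)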

6. _prepare_reference_audio_latents duplicates VAE encode logic (line 663–668)

This method manually does vae.encode(audio).latent_dist.sample() + transpose, which is exactly what _encode_audio_to_latents does. It should call self._encode_audio_to_latents(reference_audio, device, dtype) instead. DRY within the pipeline is fine.
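
That is, something along these lines (signature assumed from the encode_audio discussion in item 8 below):

reference_latents = self._encode_audio_to_latents(reference_audio, device, dtype)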

7. _normalize_audio_to_stereo_48k is a @staticmethod-like method but uses self.sample_rate

This is fine as-is, but the docstring example in EXAMPLE_DOC_STRING (line 89–90) calls it as pipe._normalize_audio_to_stereo_48k(...) — a private method in user-facing docs is a smell. Consider either making it public (normalize_audio_to_stereo) or removing it from the docstring example.

8. encode_audio public API calls _encode_audio_to_latents with self.transformer.dtype

pipeline_ace_step.py:601:

dtype = dtype if dtype is not None else self.transformer.dtype

If the transformer is offloaded, self.transformer.dtype may not reflect the runtime dtype. The Flux2 pattern uses self.dtype (from DiffusionPipeline), which is more robust. Low risk but worth aligning.
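
The alignment would then be a one-line change (sketch):

dtype = dtype if dtype is not None else self.dtype  # pipeline-level dtype rather than self.transformer.dtype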

9. full_attn_mask in AceStepLyricEncoder.forward is still materialized (line 198–199)

In the DiT, all-zeros masks were already replaced with None (good!). But in AceStepLyricEncoder.forward(), full_attn_mask is still being created via _create_4d_mask. When attention_mask is all-ones (no padding) and is_causal=False, is_sliding_window=False, this produces an all-zeros mask — a no-op that wastes memory and breaks flash/sage backends per models.md: "No mask needed → pass None, never an all-zero dense tensor."

When attention_mask has actual padding (mask values of 0), the full mask is meaningful and should stay. So the fix is conditional:

# Only materialize full_attn_mask when there's actual padding
needs_padding_mask = not attention_mask.all()
full_attn_mask = _create_4d_mask(...) if needs_padding_mask else None

The sliding mask in the lyric encoder also includes padding info via attention_mask, so it stays.

10. timbre_fix_frame = 750 hardcoded (line 1153)

This magic number is derived from 30s * 25Hz = 750 but is hardcoded rather than computed from self.latents_per_second. Should be:

timbre_fix_frame = int(30 * self.latents_per_second)  # 30 seconds of latents

Nits

  • pipeline_ace_step.py:255: The audio_duration docstring in __call__ says the default is 60.0, which matches the actual default kwarg — fine, but the variant defaults table should state this explicitly (not all users read the code).

  • pipeline_ace_step.py:1348: Legacy callback receives t_curr_tensor (a tensor), but the convention is to pass float(t_curr_tensor). The Flux2 pipeline passes scalar timestep values.

  • test_ace_step.py:173: tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-Embedding-0.6B") — this downloads a tokenizer from HF Hub in every test run. For CI stability, consider using a self.get_dummy_tokenizer() pattern or ensuring this is cached / mock-able. If the test suite runs in an environment without network access, this will fail.

  • modeling_ace_step.py:812–816: The placeholder silence_latent is (1, 15000, timbre_hidden_dim). The 15000 is hardcoded. Consider deriving from config or at least adding a comment explaining 15000 = 10 minutes * 25 Hz.


What's looking good

  • Attention refactor is done properly: AceStepAttention + AceStepAttnProcessor2_0 with dispatch_attention_fn and AttentionModuleMixin — this is the correct pattern.
  • Scheduler integration: FlowMatchEulerDiscreteScheduler.set_timesteps(sigmas=...) + .step() is the right approach. The inline Euler step is gone.
  • APG via normalized_guidance: Correctly reusing the shared function with use_original_formulation=True and norm_dim=(1,) to match ACE-Step's formula. The guidance_scale - 1.0 offset is correct.
  • VAE tiling: Clean implementation in AutoencoderOobleck with enable_tiling()/disable_tiling() pattern. Not force-enabled in pipeline __init__ (addressed from prior review).
  • RoPE reuse: Using get_1d_rotary_pos_embed + apply_rotary_emb with the correct Qwen3 layout flags.
  • Model test suite: Using ModelTesterMixin with BaseModelTesterConfig — correct pattern for new models.

Summary

Category       Count
Blockers       4
Non-blockers   6
Nits           4

The main items to address before merge are: (1) _no_split_modules, (2) remove nested torch.no_grad(), (3) clean up _variant_defaults → Flux2 pattern, and (4) add audio-to-audio task tests.

yiyixuxu (Collaborator) left a comment:

Thanks! I left a few more comments; I think we don't have any merge blockers left.

HuggingFaceDocBuilderDev commented:

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.


Labels

documentation (Improvements or additions to documentation), guiders, models, pipelines, size/L (PR with diff > 200 LOC), tests, utils

Projects

Status: In Progress


4 participants