
Commit b46370a

Address PR #13095 review: refactor AceStepAttention to Attention + AttnProcessor
Splits the monolithic AceStepAttention into the diffusers-standard Attention + AttnProcessor layout:

- AceStepAttention (torch.nn.Module, AttentionModuleMixin) holds the to_q/to_k/to_v/to_out projections and the norm_q/norm_k RMSNorms.
- AceStepAttnProcessor2_0 runs the attention dispatch through dispatch_attention_fn, so users can pick flash / sage / native backends via model.set_attention_backend(...) or the attention_backend context manager.

GQA (Q has 16 heads; K/V have 8) is preserved by passing enable_gqa=True to dispatch_attention_fn instead of repeat_interleave. QKV fusion is disabled (_supports_qkv_fusion = False) because Q and K/V have different output sizes.

The converter is updated to rename the six attention sub-keys (q_proj -> to_q, k_proj -> to_k, v_proj -> to_v, o_proj -> to_out.0, q_norm -> norm_q, k_norm -> norm_k) on both the DiT decoder path and the condition-encoder path, since AceStepLyricEncoder / AceStepTimbreEncoder share the same AceStepAttention class.

Addresses review comments r2785433213 and r2785450463.
1 parent c06373e commit b46370a

3 files changed

Lines changed: 139 additions & 71 deletions
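The commit message above points at runtime backend selection through the new processor. A minimal sketch of the two switching idioms it refers to, assuming a diffusers build where models expose set_attention_backend and where the attention_backend context manager lives in diffusers.models.attention_dispatch; loading the ACE-Step transformer itself is elided, since its class name is not part of this commit:

```python
from diffusers.models.attention_dispatch import attention_backend


def use_flash_attention(model) -> None:
    # Persistent switch: every AceStepAttention processor in `model` now dispatches
    # through the FlashAttention backend instead of native SDPA.
    model.set_attention_backend("flash")


def denoise_with_sage(model, *model_inputs):
    # Scoped switch: the previous backend is restored when the `with` block exits.
    with attention_backend("sage"):
        return model(*model_inputs)
```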


scripts/convert_ace_step_to_diffusers.py

Lines changed: 20 additions & 0 deletions
@@ -114,6 +114,24 @@ def convert_ace_step_weights(checkpoint_dir, dit_config, output_dir, dtype_str="
     condition_encoder_sd = {}
     other_sd = {}  # tokenizer, detokenizer (audio quantization — not used by the text2music pipeline)

+    # Rename original ACE-Step attention keys to the diffusers `Attention` +
+    # `AttnProcessor` convention (`to_q`/`to_k`/`to_v`/`to_out.0`/`norm_q`/`norm_k`).
+    # Applies uniformly to both the DiT (self-attn and cross-attn) and the
+    # condition-encoder self-attention, since both use `AceStepAttention`.
+    _ATTN_KEY_RENAMES = [
+        (".q_proj.", ".to_q."),
+        (".k_proj.", ".to_k."),
+        (".v_proj.", ".to_v."),
+        (".o_proj.", ".to_out.0."),
+        (".q_norm.", ".norm_q."),
+        (".k_norm.", ".norm_k."),
+    ]
+
+    def _rename_attn_keys(key: str) -> str:
+        for old, new in _ATTN_KEY_RENAMES:
+            key = key.replace(old, new)
+        return key
+
     for key, value in state_dict.items():
         if key.startswith("decoder."):
             # Strip "decoder." prefix for the transformer
@@ -125,10 +143,12 @@ def convert_ace_step_weights(checkpoint_dir, dit_config, output_dir, dtype_str="
             # In diffusers, we use standalone Conv1d/ConvTranspose1d named proj_in_conv/proj_out_conv.
             new_key = new_key.replace("proj_in.1.", "proj_in_conv.")
             new_key = new_key.replace("proj_out.1.", "proj_out_conv.")
+            new_key = _rename_attn_keys(new_key)
             transformer_sd[new_key] = value.to(target_dtype)
         elif key.startswith("encoder."):
             # Strip "encoder." prefix for the condition encoder
             new_key = key[len("encoder.") :]
+            new_key = _rename_attn_keys(new_key)
             condition_encoder_sd[new_key] = value.to(target_dtype)
         elif key == "null_condition_emb":
             # Learned unconditional embedding (used by the base/SFT CFG path).
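To make the key mapping concrete, here is a small, self-contained sanity check that mirrors the rename table added above. The checkpoint keys used are made-up examples in the original ACE-Step naming, not entries read from a real state dict:

```python
# Local re-implementation of the rename table from the diff above, for illustration only.
_ATTN_KEY_RENAMES = [
    (".q_proj.", ".to_q."),
    (".k_proj.", ".to_k."),
    (".v_proj.", ".to_v."),
    (".o_proj.", ".to_out.0."),
    (".q_norm.", ".norm_q."),
    (".k_norm.", ".norm_k."),
]


def _rename_attn_keys(key: str) -> str:
    for old, new in _ATTN_KEY_RENAMES:
        key = key.replace(old, new)
    return key


# Hypothetical keys in the original ACE-Step layout (module paths are illustrative).
assert _rename_attn_keys("layers.0.self_attn.q_proj.weight") == "layers.0.self_attn.to_q.weight"
assert _rename_attn_keys("layers.0.cross_attn.o_proj.weight") == "layers.0.cross_attn.to_out.0.weight"
assert _rename_attn_keys("layers.3.self_attn.k_norm.weight") == "layers.3.self_attn.norm_k.weight"
```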

src/diffusers/models/transformers/ace_step_transformer.py

Lines changed: 115 additions & 66 deletions
@@ -22,6 +22,7 @@
 ``diffusers/pipelines/ace_step/modeling_ace_step.py``.
 """

+import inspect
 import math
 from typing import List, Optional, Tuple, Union

@@ -31,7 +32,8 @@

 from ...configuration_utils import ConfigMixin, register_to_config
 from ...utils import logging
-from ..attention import AttentionMixin
+from ..attention import AttentionMixin, AttentionModuleMixin
+from ..attention_dispatch import dispatch_attention_fn
 from ..cache_utils import CacheMixin
 from ..embeddings import Timesteps, apply_rotary_emb, get_1d_rotary_pos_embed
 from ..modeling_outputs import Transformer2DModelOutput
@@ -157,87 +159,135 @@ def forward(self, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         return temb, timestep_proj


-class AceStepAttention(nn.Module):
-    """GQA attention with RMSNorm on query/key and optional sliding-window mask.
+class AceStepAttnProcessor2_0:
+    """Attention processor for ACE-Step GQA attention.

-    The block matches the original ACE-Step attention layout (``q_proj``,
-    ``k_proj``, ``v_proj``, ``o_proj``, ``q_norm``, ``k_norm``). Self-attention
-    applies RoPE on query+key; cross-attention does not.
+    Dispatches the actual attention call through ``dispatch_attention_fn`` so users
+    can pick flash / sage / native backends via ``model.set_attention_backend(...)``
+    or the ``attention_backend`` context manager. Uses the ``(B, L, H, D)`` tensor
+    layout that the diffusers attention backends consume directly.
     """

+    _attention_backend = None
+    _parallel_config = None
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(
+                "AceStepAttnProcessor2_0 requires PyTorch 2.0. Please upgrade your pytorch version."
+            )
+
+    def __call__(
+        self,
+        attn: "AceStepAttention",
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+    ) -> torch.Tensor:
+        is_cross = attn.is_cross_attention and encoder_hidden_states is not None
+        kv_input = encoder_hidden_states if is_cross else hidden_states
+
+        # Project to (B, L, H, D). Q uses ``heads``; K/V use ``kv_heads`` (GQA).
+        query = attn.to_q(hidden_states).unflatten(-1, (attn.heads, attn.head_dim))
+        key = attn.to_k(kv_input).unflatten(-1, (attn.kv_heads, attn.head_dim))
+        value = attn.to_v(kv_input).unflatten(-1, (attn.kv_heads, attn.head_dim))
+
+        query = attn.norm_q(query)
+        key = attn.norm_k(key)
+
+        # RoPE on self-attention only. Matches Qwen3 layout:
+        # freqs = cat([freq_half, freq_half], dim=-1); rotate-half splits last dim.
+        if not is_cross and image_rotary_emb is not None:
+            query = apply_rotary_emb(query, image_rotary_emb, use_real=True, use_real_unbind_dim=-2, sequence_dim=1)
+            key = apply_rotary_emb(key, image_rotary_emb, use_real=True, use_real_unbind_dim=-2, sequence_dim=1)
+
+        hidden_states = dispatch_attention_fn(
+            query,
+            key,
+            value,
+            attn_mask=attention_mask,
+            dropout_p=attn.dropout if attn.training else 0.0,
+            scale=attn.scaling,
+            enable_gqa=attn.heads != attn.kv_heads,
+            backend=self._attention_backend,
+            parallel_config=self._parallel_config,
+        )
+        hidden_states = hidden_states.flatten(2, 3).to(query.dtype)
+        hidden_states = attn.to_out[0](hidden_states)
+        hidden_states = attn.to_out[1](hidden_states)
+        return hidden_states
+
+
+class AceStepAttention(torch.nn.Module, AttentionModuleMixin):
+    """GQA attention with RMSNorm on query/key for ACE-Step 1.5.
+
+    Uses the diffusers ``Attention`` + ``AttnProcessor`` split: this module holds
+    the projections and Q/K norm; the processor runs the attention dispatch.
+    Self-attention applies RoPE on query/key; cross-attention reads K/V from
+    ``encoder_hidden_states`` and does not apply RoPE.
+
+    GQA means Q has ``heads * head_dim`` output while K/V have
+    ``kv_heads * head_dim`` — QKV fusion is therefore disabled
+    (``_supports_qkv_fusion = False``).
+    """
+
+    _default_processor_cls = AceStepAttnProcessor2_0
+    _available_processors = [AceStepAttnProcessor2_0]
+    _supports_qkv_fusion = False
+
     def __init__(
         self,
         hidden_size: int,
         num_attention_heads: int,
         num_key_value_heads: int,
         head_dim: int,
-        attention_bias: bool = False,
-        attention_dropout: float = 0.0,
+        bias: bool = False,
+        dropout: float = 0.0,
+        eps: float = 1e-6,
         is_cross_attention: bool = False,
-        sliding_window: Optional[int] = None,
-        rms_norm_eps: float = 1e-6,
+        processor: Optional[AceStepAttnProcessor2_0] = None,
     ):
         super().__init__()
+        self.heads = num_attention_heads
+        self.kv_heads = num_key_value_heads
         self.head_dim = head_dim
-        self.num_attention_heads = num_attention_heads
-        self.num_key_value_heads = num_key_value_heads
-        self.num_key_value_groups = num_attention_heads // num_key_value_heads
+        self.dropout = dropout
         self.scaling = head_dim ** -0.5
-        self.attention_dropout = attention_dropout
         self.is_cross_attention = is_cross_attention
-        self.sliding_window = sliding_window

-        self.q_proj = nn.Linear(hidden_size, num_attention_heads * head_dim, bias=attention_bias)
-        self.k_proj = nn.Linear(hidden_size, num_key_value_heads * head_dim, bias=attention_bias)
-        self.v_proj = nn.Linear(hidden_size, num_key_value_heads * head_dim, bias=attention_bias)
-        self.o_proj = nn.Linear(num_attention_heads * head_dim, hidden_size, bias=attention_bias)
-        self.q_norm = RMSNorm(head_dim, eps=rms_norm_eps)
-        self.k_norm = RMSNorm(head_dim, eps=rms_norm_eps)
+        self.to_q = nn.Linear(hidden_size, num_attention_heads * head_dim, bias=bias)
+        self.to_k = nn.Linear(hidden_size, num_key_value_heads * head_dim, bias=bias)
+        self.to_v = nn.Linear(hidden_size, num_key_value_heads * head_dim, bias=bias)
+        self.to_out = nn.ModuleList(
+            [nn.Linear(num_attention_heads * head_dim, hidden_size, bias=bias), nn.Dropout(0.0)]
+        )
+        self.norm_q = RMSNorm(head_dim, eps=eps)
+        self.norm_k = RMSNorm(head_dim, eps=eps)
+
+        if processor is None:
+            processor = self._default_processor_cls()
+        self.set_processor(processor)

     def forward(
         self,
         hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
         encoder_hidden_states: Optional[torch.Tensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        **kwargs,
     ) -> torch.Tensor:
-        input_shape = hidden_states.shape[:-1]
-        # (B, L, H, D)
-        q_shape = (*input_shape, self.num_attention_heads, self.head_dim)
-        query_states = self.q_norm(self.q_proj(hidden_states).view(q_shape)).transpose(-3, -2)
-
-        is_cross = self.is_cross_attention and encoder_hidden_states is not None
-        kv_input = encoder_hidden_states if is_cross else hidden_states
-        kv_shape = (*kv_input.shape[:-1], self.num_key_value_heads, self.head_dim)
-        key_states = self.k_norm(self.k_proj(kv_input).view(kv_shape)).transpose(-3, -2)
-        value_states = self.v_proj(kv_input).view(kv_shape).transpose(-3, -2)
-
-        if not is_cross and position_embeddings is not None:
-            cos, sin = position_embeddings
-            query_states = apply_rotary_emb(
-                query_states, (cos, sin), use_real=True, use_real_unbind_dim=-2
-            )
-            key_states = apply_rotary_emb(
-                key_states, (cos, sin), use_real=True, use_real_unbind_dim=-2
-            )
-
-        # Expand KV heads to match Q heads for grouped-query attention.
-        if self.num_key_value_groups > 1:
-            key_states = key_states.repeat_interleave(self.num_key_value_groups, dim=-3)
-            value_states = value_states.repeat_interleave(self.num_key_value_groups, dim=-3)
-
-        attn_output = F.scaled_dot_product_attention(
-            query_states,
-            key_states,
-            value_states,
-            attn_mask=attention_mask,
-            dropout_p=self.attention_dropout if self.training else 0.0,
-            scale=self.scaling,
+        attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
+        kwargs = {k: v for k, v in kwargs.items() if k in attn_parameters}
+        return self.processor(
+            self,
+            hidden_states,
+            encoder_hidden_states=encoder_hidden_states,
+            attention_mask=attention_mask,
+            image_rotary_emb=image_rotary_emb,
+            **kwargs,
         )

-        attn_output = attn_output.transpose(-3, -2).reshape(*input_shape, -1).contiguous()
-        return self.o_proj(attn_output)
-

 class AceStepTransformerBlock(nn.Module):
     """ACE-Step DiT transformer block: self-attn (AdaLN) → cross-attn → MLP (AdaLN).
@@ -266,11 +316,10 @@ def __init__(
             num_attention_heads=num_attention_heads,
             num_key_value_heads=num_key_value_heads,
             head_dim=head_dim,
-            attention_bias=attention_bias,
-            attention_dropout=attention_dropout,
+            bias=attention_bias,
+            dropout=attention_dropout,
+            eps=rms_norm_eps,
             is_cross_attention=False,
-            sliding_window=sliding_window,
-            rms_norm_eps=rms_norm_eps,
         )

         self.use_cross_attention = use_cross_attention
@@ -281,10 +330,10 @@ def __init__(
             num_attention_heads=num_attention_heads,
             num_key_value_heads=num_key_value_heads,
             head_dim=head_dim,
-            attention_bias=attention_bias,
-            attention_dropout=attention_dropout,
+            bias=attention_bias,
+            dropout=attention_dropout,
+            eps=rms_norm_eps,
             is_cross_attention=True,
-            rms_norm_eps=rms_norm_eps,
         )

         self.mlp_norm = RMSNorm(hidden_size, eps=rms_norm_eps)
@@ -311,7 +360,7 @@ def forward(
         )
         attn_output = self.self_attn(
             hidden_states=norm_hidden_states,
-            position_embeddings=position_embeddings,
+            image_rotary_emb=position_embeddings,
             attention_mask=attention_mask,
         )
         hidden_states = (hidden_states + attn_output * gate_msa).type_as(hidden_states)
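The processor above asks the dispatcher to handle grouped-query attention (enable_gqa=attn.heads != attn.kv_heads) rather than repeating K/V heads by hand. The equivalence of the two approaches can be checked directly against PyTorch SDPA; a minimal sketch, assuming PyTorch 2.5+ (where scaled_dot_product_attention accepts enable_gqa) and using the (B, H, L, D) layout SDPA expects, while dispatch_attention_fn itself consumes (B, L, H, D):

```python
import torch
import torch.nn.functional as F

# Head counts from the commit message: Q has 16 heads, K/V have 8; other sizes are illustrative.
B, L, H_Q, H_KV, D = 2, 32, 16, 8, 64
q = torch.randn(B, H_Q, L, D)
k = torch.randn(B, H_KV, L, D)
v = torch.randn(B, H_KV, L, D)

# New path: let SDPA broadcast the K/V heads internally (PyTorch >= 2.5).
out_gqa = F.scaled_dot_product_attention(q, k, v, enable_gqa=True)

# Old path: materialize the repeated K/V heads before calling SDPA.
groups = H_Q // H_KV
out_repeat = F.scaled_dot_product_attention(
    q,
    k.repeat_interleave(groups, dim=1),
    v.repeat_interleave(groups, dim=1),
)

torch.testing.assert_close(out_gqa, out_repeat)
```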

src/diffusers/pipelines/ace_step/modeling_ace_step.py

Lines changed: 4 additions & 5 deletions
@@ -91,11 +91,10 @@ def __init__(
             num_attention_heads=num_attention_heads,
             num_key_value_heads=num_key_value_heads,
             head_dim=head_dim,
-            attention_bias=attention_bias,
-            attention_dropout=attention_dropout,
+            bias=attention_bias,
+            dropout=attention_dropout,
+            eps=rms_norm_eps,
             is_cross_attention=False,
-            sliding_window=sliding_window,
-            rms_norm_eps=rms_norm_eps,
         )
         self.input_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps)
@@ -111,7 +110,7 @@ def forward(
         hidden_states = self.input_layernorm(hidden_states)
         hidden_states = self.self_attn(
             hidden_states=hidden_states,
-            position_embeddings=position_embeddings,
+            image_rotary_emb=position_embeddings,
             attention_mask=attention_mask,
         )
         hidden_states = residual + hidden_states
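Finally, a quick smoke test of the refactored attention interface that both the DiT block and this condition-encoder layer now call into. Treat it as a sketch: the import path only exists on this PR branch, and the hidden size plus the 16/8-head split are illustrative numbers taken from the commit message:

```python
import torch

# Assumed import path — ace_step_transformer.py is the file touched by this commit
# and is only available on the PR branch, not in a released diffusers.
from diffusers.models.transformers.ace_step_transformer import AceStepAttention

attn = AceStepAttention(
    hidden_size=1024,
    num_attention_heads=16,   # Q heads
    num_key_value_heads=8,    # K/V heads (GQA)
    head_dim=64,
    is_cross_attention=False,
)

x = torch.randn(2, 32, 1024)
# image_rotary_emb=None skips RoPE; the call still routes through AceStepAttnProcessor2_0.
out = attn(hidden_states=x, image_rotary_emb=None)
print(out.shape)  # expected: torch.Size([2, 32, 1024])
```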
