 ``diffusers/pipelines/ace_step/modeling_ace_step.py``.
 """

+import inspect
 import math
 from typing import List, Optional, Tuple, Union


 from ...configuration_utils import ConfigMixin, register_to_config
 from ...utils import logging
-from ..attention import AttentionMixin
+from ..attention import AttentionMixin, AttentionModuleMixin
+from ..attention_dispatch import dispatch_attention_fn
 from ..cache_utils import CacheMixin
 from ..embeddings import Timesteps, apply_rotary_emb, get_1d_rotary_pos_embed
 from ..modeling_outputs import Transformer2DModelOutput
@@ -157,87 +159,135 @@ def forward(self, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         return temb, timestep_proj


-class AceStepAttention(nn.Module):
-    """GQA attention with RMSNorm on query/key and optional sliding-window mask.
+class AceStepAttnProcessor2_0:
+    """Attention processor for ACE-Step GQA attention.

-    The block matches the original ACE-Step attention layout (``q_proj``,
-    ``k_proj``, ``v_proj``, ``o_proj``, ``q_norm``, ``k_norm``). Self-attention
-    applies RoPE on query+key; cross-attention does not.
+    Dispatches the actual attention call through ``dispatch_attention_fn`` so users
+    can pick flash / sage / native backends via ``model.set_attention_backend(...)``
+    or the ``attention_backend`` context manager. Uses the ``(B, L, H, D)`` tensor
+    layout that the diffusers attention backends consume directly.
     """

+    _attention_backend = None
+    _parallel_config = None
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(
+                "AceStepAttnProcessor2_0 requires PyTorch 2.0. Please upgrade your pytorch version."
+            )
+
+    def __call__(
+        self,
+        attn: "AceStepAttention",
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+    ) -> torch.Tensor:
+        is_cross = attn.is_cross_attention and encoder_hidden_states is not None
+        kv_input = encoder_hidden_states if is_cross else hidden_states
+
+        # Project to (B, L, H, D). Q uses ``heads``; K/V use ``kv_heads`` (GQA).
+        query = attn.to_q(hidden_states).unflatten(-1, (attn.heads, attn.head_dim))
+        key = attn.to_k(kv_input).unflatten(-1, (attn.kv_heads, attn.head_dim))
+        value = attn.to_v(kv_input).unflatten(-1, (attn.kv_heads, attn.head_dim))
+
+        query = attn.norm_q(query)
+        key = attn.norm_k(key)
+
+        # RoPE on self-attention only. Matches Qwen3 layout:
+        # freqs = cat([freq_half, freq_half], dim=-1); rotate-half splits last dim.
+        if not is_cross and image_rotary_emb is not None:
+            query = apply_rotary_emb(query, image_rotary_emb, use_real=True, use_real_unbind_dim=-2, sequence_dim=1)
+            key = apply_rotary_emb(key, image_rotary_emb, use_real=True, use_real_unbind_dim=-2, sequence_dim=1)
+
+        hidden_states = dispatch_attention_fn(
+            query,
+            key,
+            value,
+            attn_mask=attention_mask,
+            dropout_p=attn.dropout if attn.training else 0.0,
+            scale=attn.scaling,
+            enable_gqa=attn.heads != attn.kv_heads,
+            backend=self._attention_backend,
+            parallel_config=self._parallel_config,
+        )
+        hidden_states = hidden_states.flatten(2, 3).to(query.dtype)
+        hidden_states = attn.to_out[0](hidden_states)
+        hidden_states = attn.to_out[1](hidden_states)
+        return hidden_states
+
+
+class AceStepAttention(torch.nn.Module, AttentionModuleMixin):
+    """GQA attention with RMSNorm on query/key for ACE-Step 1.5.
+
+    Uses the diffusers ``Attention`` + ``AttnProcessor`` split: this module holds
+    the projections and Q/K norm; the processor runs the attention dispatch.
+    Self-attention applies RoPE on query/key; cross-attention reads K/V from
+    ``encoder_hidden_states`` and does not apply RoPE.
+
+    GQA means Q has ``heads * head_dim`` output while K/V have
+    ``kv_heads * head_dim``, so QKV fusion is disabled
+    (``_supports_qkv_fusion = False``).
+    """
+
+    _default_processor_cls = AceStepAttnProcessor2_0
+    _available_processors = [AceStepAttnProcessor2_0]
+    _supports_qkv_fusion = False
+
     def __init__(
         self,
         hidden_size: int,
         num_attention_heads: int,
         num_key_value_heads: int,
         head_dim: int,
-        attention_bias: bool = False,
-        attention_dropout: float = 0.0,
+        bias: bool = False,
+        dropout: float = 0.0,
+        eps: float = 1e-6,
         is_cross_attention: bool = False,
-        sliding_window: Optional[int] = None,
-        rms_norm_eps: float = 1e-6,
+        processor: Optional[AceStepAttnProcessor2_0] = None,
    ):
         super().__init__()
+        self.heads = num_attention_heads
+        self.kv_heads = num_key_value_heads
         self.head_dim = head_dim
-        self.num_attention_heads = num_attention_heads
-        self.num_key_value_heads = num_key_value_heads
-        self.num_key_value_groups = num_attention_heads // num_key_value_heads
+        self.dropout = dropout
         self.scaling = head_dim**-0.5
-        self.attention_dropout = attention_dropout
         self.is_cross_attention = is_cross_attention
-        self.sliding_window = sliding_window

-        self.q_proj = nn.Linear(hidden_size, num_attention_heads * head_dim, bias=attention_bias)
-        self.k_proj = nn.Linear(hidden_size, num_key_value_heads * head_dim, bias=attention_bias)
-        self.v_proj = nn.Linear(hidden_size, num_key_value_heads * head_dim, bias=attention_bias)
-        self.o_proj = nn.Linear(num_attention_heads * head_dim, hidden_size, bias=attention_bias)
-        self.q_norm = RMSNorm(head_dim, eps=rms_norm_eps)
-        self.k_norm = RMSNorm(head_dim, eps=rms_norm_eps)
+        self.to_q = nn.Linear(hidden_size, num_attention_heads * head_dim, bias=bias)
+        self.to_k = nn.Linear(hidden_size, num_key_value_heads * head_dim, bias=bias)
+        self.to_v = nn.Linear(hidden_size, num_key_value_heads * head_dim, bias=bias)
+        self.to_out = nn.ModuleList(
+            [nn.Linear(num_attention_heads * head_dim, hidden_size, bias=bias), nn.Dropout(0.0)]
+        )
+        self.norm_q = RMSNorm(head_dim, eps=eps)
+        self.norm_k = RMSNorm(head_dim, eps=eps)
+
+        if processor is None:
+            processor = self._default_processor_cls()
+        self.set_processor(processor)

     def forward(
         self,
         hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
         encoder_hidden_states: Optional[torch.Tensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        **kwargs,
     ) -> torch.Tensor:
-        input_shape = hidden_states.shape[:-1]
-        # (B, L, H, D)
-        q_shape = (*input_shape, self.num_attention_heads, self.head_dim)
-        query_states = self.q_norm(self.q_proj(hidden_states).view(q_shape)).transpose(-3, -2)
-
-        is_cross = self.is_cross_attention and encoder_hidden_states is not None
-        kv_input = encoder_hidden_states if is_cross else hidden_states
-        kv_shape = (*kv_input.shape[:-1], self.num_key_value_heads, self.head_dim)
-        key_states = self.k_norm(self.k_proj(kv_input).view(kv_shape)).transpose(-3, -2)
-        value_states = self.v_proj(kv_input).view(kv_shape).transpose(-3, -2)
-
-        if not is_cross and position_embeddings is not None:
-            cos, sin = position_embeddings
-            query_states = apply_rotary_emb(
-                query_states, (cos, sin), use_real=True, use_real_unbind_dim=-2
-            )
-            key_states = apply_rotary_emb(
-                key_states, (cos, sin), use_real=True, use_real_unbind_dim=-2
-            )
-
-        # Expand KV heads to match Q heads for grouped-query attention.
-        if self.num_key_value_groups > 1:
-            key_states = key_states.repeat_interleave(self.num_key_value_groups, dim=-3)
-            value_states = value_states.repeat_interleave(self.num_key_value_groups, dim=-3)
-
-        attn_output = F.scaled_dot_product_attention(
-            query_states,
-            key_states,
-            value_states,
-            attn_mask=attention_mask,
-            dropout_p=self.attention_dropout if self.training else 0.0,
-            scale=self.scaling,
+        attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
+        kwargs = {k: v for k, v in kwargs.items() if k in attn_parameters}
+        return self.processor(
+            self,
+            hidden_states,
+            encoder_hidden_states=encoder_hidden_states,
+            attention_mask=attention_mask,
+            image_rotary_emb=image_rotary_emb,
+            **kwargs,
         )

-        attn_output = attn_output.transpose(-3, -2).reshape(*input_shape, -1).contiguous()
-        return self.o_proj(attn_output)
-

 class AceStepTransformerBlock(nn.Module):
     """ACE-Step DiT transformer block: self-attn (AdaLN) → cross-attn → MLP (AdaLN).
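The removed forward above expanded K/V heads with repeat_interleave before calling SDPA; the new processor instead passes enable_gqa to the dispatcher whenever heads != kv_heads. A minimal sketch of that equivalence in plain PyTorch (sizes are illustrative, not the model's actual config; enable_gqa in scaled_dot_product_attention needs a recent PyTorch, roughly 2.5 or newer):

import torch
import torch.nn.functional as F

batch, seq_len, head_dim = 2, 128, 64
heads, kv_heads = 16, 4  # four query heads share each KV head

# torch SDPA takes (B, H, L, D); the diffusers dispatcher used above takes (B, L, H, D).
q = torch.randn(batch, heads, seq_len, head_dim)
k = torch.randn(batch, kv_heads, seq_len, head_dim)
v = torch.randn(batch, kv_heads, seq_len, head_dim)

# New path: let SDPA broadcast the KV heads across query-head groups.
out_gqa = F.scaled_dot_product_attention(q, k, v, enable_gqa=True)

# Old path (the code removed above): expand KV heads explicitly first.
groups = heads // kv_heads
out_ref = F.scaled_dot_product_attention(
    q, k.repeat_interleave(groups, dim=1), v.repeat_interleave(groups, dim=1)
)
assert torch.allclose(out_gqa, out_ref, atol=1e-6)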
@@ -266,11 +316,10 @@ def __init__(
             num_attention_heads=num_attention_heads,
             num_key_value_heads=num_key_value_heads,
             head_dim=head_dim,
-            attention_bias=attention_bias,
-            attention_dropout=attention_dropout,
+            bias=attention_bias,
+            dropout=attention_dropout,
+            eps=rms_norm_eps,
             is_cross_attention=False,
-            sliding_window=sliding_window,
-            rms_norm_eps=rms_norm_eps,
         )

         self.use_cross_attention = use_cross_attention
@@ -281,10 +330,10 @@ def __init__(
             num_attention_heads=num_attention_heads,
             num_key_value_heads=num_key_value_heads,
             head_dim=head_dim,
-            attention_bias=attention_bias,
-            attention_dropout=attention_dropout,
+            bias=attention_bias,
+            dropout=attention_dropout,
+            eps=rms_norm_eps,
             is_cross_attention=True,
-            rms_norm_eps=rms_norm_eps,
         )

         self.mlp_norm = RMSNorm(hidden_size, eps=rms_norm_eps)
@@ -311,7 +360,7 @@ def forward(
         )
         attn_output = self.self_attn(
             hidden_states=norm_hidden_states,
-            position_embeddings=position_embeddings,
+            image_rotary_emb=position_embeddings,
             attention_mask=attention_mask,
         )
         hidden_states = (hidden_states + attn_output * gate_msa).type_as(hidden_states)
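The processor docstring above points to set_attention_backend and the attention_backend context manager for backend selection. A hedged usage sketch, assuming the context manager is importable from diffusers.models.attention_dispatch as in recent diffusers, and using a tiny illustrative config rather than the real ACE-Step 1.5 sizes ("native" needs no extra packages; "flash" / "sage" require flash-attn / sageattention to be installed):

import torch
from diffusers.models.attention_dispatch import attention_backend

# Standalone module from this file; the head counts here are made up for the example.
attn = AceStepAttention(hidden_size=256, num_attention_heads=8, num_key_value_heads=2, head_dim=32)
x = torch.randn(1, 64, 256)

# On a full model one would call model.set_attention_backend("flash") once;
# the context manager scopes the choice to a single forward pass instead.
with attention_backend("native"):
    out = attn(hidden_states=x)
print(out.shape)  # torch.Size([1, 64, 256])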