
Add DiT attention fusion for F5-TTS and diffusion transformer models #27999

Open

Rishi-Dave wants to merge 5 commits into microsoft:main from Rishi-Dave:rishidave/feat/dit-attention-fusion

Conversation

@Rishi-Dave
Contributor

Summary

  • Add FusionMultiHeadAttentionDiT to recognize DiT-style attention patterns (F5-TTS, etc.) and fuse them into MultiHeadAttention, enabling Flash Attention dispatch.
  • Register the new fusion as a second pass in MmditOnnxModel.fuse_multi_head_attention(), alongside the existing MMDit fusion for SD3/Flux.
  • Add test model generator and three test cases covering FP32, FP16 cast, and custom-scale variants.

Motivation

Fixes #27983

DiT models like F5-TTS use an attention pattern where Q, K, V are pre-computed (e.g., after RoPE) in BNSH format, K is pre-transposed to BNHS, and a custom scalar scale (e.g., 100.0) is applied via Mul before Softmax. Optional Cast nodes (FP16↔FP32) may wrap Softmax for mixed-precision inference.

The existing MMDit fusion (for SD3/Flux) expects a specific Mul→Sqrt→Div→Sqrt→Cast→Slice→Shape scaling path and does not match the simpler Mul(scalar_constant) pattern, so the attention is never fused and Flash Attention is never dispatched. This causes ~44 extra Cast ops per inference and ~200ms overhead per forward pass.
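
For orientation, here is a minimal NumPy sketch of the computation that this unfused subgraph expresses; the shapes and the 100.0 scale are illustrative, and the optional FP16/FP32 Casts around Softmax are omitted:

    import numpy as np

    # Illustrative shapes: batch B, heads N, sequence S, head size H.
    B, N, S, H = 1, 4, 16, 8
    q = np.random.randn(B, N, S, H).astype(np.float32)    # Q in BNSH
    k_t = np.random.randn(B, N, H, S).astype(np.float32)  # K already transposed to BNHS
    v = np.random.randn(B, N, S, H).astype(np.float32)    # V in BNSH
    scale = 100.0                                          # custom scalar scale from the model

    scores = q @ k_t                       # MatMul(Q, K^T)
    scores = scores * scale                # Mul(scalar constant), not the Sqrt/Div path MMDit expects
    probs = np.exp(scores - scores.max(-1, keepdims=True))
    probs /= probs.sum(-1, keepdims=True)  # Softmax over the last axis
    attn = probs @ v                       # MatMul(attn, V)
    out = attn.transpose(0, 2, 1, 3).reshape(B, S, N * H)  # Transpose(0,2,1,3) -> Reshape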

Changes

New files:

  • onnxruntime/python/tools/transformers/fusion_mha_dit.py — Core fusion class that matches the pattern:

    MatMul(Q, K^T) → [Cast FP16→FP32] → Mul(scale) → Softmax → [Cast FP32→FP16] → MatMul(attn, V)
        → Transpose(0,2,1,3) → Reshape → output
    

    and replaces it with a single MultiHeadAttention op (with scale attribute); a simplified matching sketch follows this list.

  • onnxruntime/test/python/transformers/dit_model_generator.py — Synthetic ONNX graph generators for testing.
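
A simplified sketch of how the matching step in such a fusion class can look, using the toolkit's match_parent_path helper; the structure and variable names here are assumptions for illustration, not the PR's code (see fusion_mha_dit.py for the real implementation):

    # Simplified matching sketch: no optional Cast handling, no num_heads detection.
    def fuse(self, node, input_name_to_nodes, output_name_to_node):
        # `node` is a candidate MatMul(attn_probs, V).
        path = self.model.match_parent_path(
            node, ["Softmax", "Mul", "MatMul"], [0, 0, 0], output_name_to_node
        )
        if path is None:
            return
        softmax, mul_scale, matmul_qk = path

        # The scalar scale may sit on either Mul input; the PR's helper handles both.
        scale = self.model.get_constant_value(mul_scale.input[1])
        if scale is None:
            return

        # Remaining steps (omitted): confirm the downstream Transpose(0,2,1,3) -> Reshape,
        # detect num_heads, emit a com.microsoft MultiHeadAttention node carrying the
        # scale attribute, and mark the matched nodes for removal.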

Modified files:

  • onnxruntime/python/tools/transformers/onnx_model_mmdit.py — Register FusionMultiHeadAttentionDiT as a second fusion pass after the existing MMDit fusion.
  • onnxruntime/test/python/transformers/test_attention_fusion.py — Three new test cases:
    • test_dit_attention_fusion — FP32 with K pre-transpose, scale=100.0
    • test_dit_attention_fusion_with_fp16_casts — FP16 Cast nodes around Softmax
    • test_dit_attention_fusion_custom_scale — Standard 1/√d_k scale

Test Plan

  • All three new DiT fusion tests pass, verifying the following (a short test-style sketch appears after this list):
    • Exactly 1 MultiHeadAttention node is produced
    • num_heads attribute is correctly detected from upstream Reshape shapes
    • scale attribute matches the original scalar constant
    • No Softmax nodes remain after fusion
  • Existing attention fusion tests remain unaffected
  • ruff check, ruff format, and lintrunner -a pass clean
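
A hedged sketch of what these node-count assertions can look like; the model path and inline counting below are illustrative, not the test file's actual code:

    import onnx

    model = onnx.load("dit_attention_optimized.onnx")  # illustrative path
    op_counts = {}
    for n in model.graph.node:
        op_counts[n.op_type] = op_counts.get(n.op_type, 0) + 1

    assert op_counts.get("MultiHeadAttention", 0) == 1  # exactly one fused node
    assert op_counts.get("Softmax", 0) == 0             # nothing left unfused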

… models

DiT models like F5-TTS use an attention pattern where Q, K, V are
pre-computed (e.g. after RoPE) in BNSH format, K is pre-transposed to
BNHS, and a custom scalar scale (e.g. 100.0) is applied before Softmax.
Optional Cast nodes (FP16<->FP32) may wrap Softmax for mixed-precision.

The existing MMDit fusion (for SD3/Flux) expects a very specific
Mul->Sqrt->Div->Sqrt->Cast->Slice->Shape scaling path and does not
match the simpler Mul(scalar) pattern used in DiT models, so Flash
Attention is never dispatched.

This commit adds FusionMultiHeadAttentionDiT which recognizes:
  MatMul(Q, K^T) -> [Cast] -> Mul(scale) -> Softmax -> [Cast] -> MatMul(attn, V)

and fuses it into a single MultiHeadAttention op with the custom scale
attribute, enabling Flash Attention dispatch.

Fixes microsoft#27983

Copilot AI left a comment


Pull request overview

This PR extends the Python transformers graph fuser to recognize and fuse DiT-style attention patterns (e.g., F5-TTS) into a com.microsoft::MultiHeadAttention node so Flash Attention can be dispatched for these diffusion transformer models.

Changes:

  • Add a new DiT attention fusion pass (FusionMultiHeadAttentionDiT) that detects MatMul → (Cast) → Mul(scale) → Softmax → (Cast) → MatMul patterns and replaces them with MultiHeadAttention (including scale).
  • Register the new fusion as a second attention fusion pass in MmditOnnxModel.fuse_multi_head_attention().
  • Add a synthetic DiT ONNX model generator and new unit tests for FP32, optional Cast-wrapped Softmax, and custom scale variants.

Reviewed changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 3 comments.

  • onnxruntime/python/tools/transformers/fusion_mha_dit.py: Implements DiT-specific pattern matching and replacement with MultiHeadAttention + scale.
  • onnxruntime/python/tools/transformers/onnx_model_mmdit.py: Runs the new DiT attention fusion pass after the existing MMDiT fusion pass.
  • onnxruntime/test/python/transformers/dit_model_generator.py: Adds synthetic DiT attention subgraph generators used by fusion tests.
  • onnxruntime/test/python/transformers/test_attention_fusion.py: Adds three unit tests validating DiT attention fusion behavior and attributes.


Cast V to FP16 in the FP16-cast test model so the attention MatMul
has type-consistent inputs. Add Softmax-count-is-zero assertions to
the FP16-cast and custom-scale tests to match the base test coverage.
@tianleiwu requested a review from Copilot on April 7, 2026 19:12

Copilot AI left a comment


Pull request overview

Copilot reviewed 4 out of 4 changed files in this pull request and generated 7 comments.



…, shrink test models

- Fix get_data_input_of_mul to handle Python int/float scalars (not just np.ndarray)
- Validate Softmax axis=-1 before fusing to avoid semantic changes
- Add detect_num_heads_from_input_shape fallback for graph inputs
- Switch to numpy_helper.from_array for initializer creation
- Reduce default test tensor sizes (num_heads=4, head_dim=8)
- Fix FP16 type consistency: cast attn output back to FP32 before o_matmul
- Add test_dit_attention_fusion_no_k_transpose for the inserted-Transpose path
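
Regarding the first item in the commit message above, a hedged sketch of what scalar handling along those lines can look like; the body below is an assumption for illustration, not the PR's actual get_data_input_of_mul:

    import numpy as np

    def get_data_input_of_mul(self, mul_node):
        """Return (data_input_name, scalar_scale), or (None, None) if no scalar constant is found."""
        for i, name in enumerate(mul_node.input):
            value = self.model.get_constant_value(name)
            if value is None:
                continue
            if isinstance(value, np.ndarray):
                if value.size != 1:
                    return None, None                      # constant, but not a scalar
                return mul_node.input[1 - i], float(value.item())
            if isinstance(value, (int, float)):            # plain Python scalar, the fixed case
                return mul_node.input[1 - i], float(value)
        return None, None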

@tianleiwu left a comment


The fusion is well scoped and follows the existing MMDiT pattern, but I think there are still a couple of correctness issues to address before merge.

Blocking items:

  • The cast-wrapped path can create a MultiHeadAttention with mixed element types for Q/K vs V. I reproduced this with the new FP16-cast synthetic model: after saving with the MS opset, ORT rejects the optimized model because MultiHeadAttention input type parameter T is bound to both tensor(float) and tensor(float16).
  • The Softmax axis guard handles explicit non-last axes, but still treats a missing axis as safe. That is only true for opset >= 13. For opset < 13, omitted Softmax axis defaults to 1, so this fusion can change semantics. I added that detail as a reply on the existing Softmax-axis thread.

Also worth addressing: add a single-consumer guard for the matched intermediate tensors, and make the tests load or run parity on the optimized model so node-count assertions do not miss invalid fused graphs.
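
Regarding the second blocking item, a hedged sketch of an opset-aware axis guard along those lines; the helper name and structure are assumptions, not the PR's code:

    def softmax_axis_is_last(self, softmax_node, rank):
        axis = None
        for attr in softmax_node.attribute:
            if attr.name == "axis":
                axis = attr.i
        if axis is None:
            # Softmax defaults to axis=-1 only from opset 13; before that the default is axis=1.
            return self.model.get_opset_version() >= 13
        return axis in (-1, rank - 1)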

…check

- Trace V back through Cast in FP16 path to recover pre-cast tensor,
  ensuring Q/K/V share the same element type for the fused MHA
- Add explicit Q/K/V dtype consistency check; bail when casts are
  present and types are unverifiable (V not traced through Cast)
- Guard against Softmax axis=None on opset < 13 where default is
  axis=1, not last-axis
- Add single-consumer guard for all matched intermediate tensors to
  prevent removing nodes that feed other consumers
- Run onnx.shape_inference in test validation so get_dtype resolves
  intermediate tensor types, catching mixed-dtype fusions that pass
  structural assertions but fail at ORT load

@tianleiwu left a comment


I still think there are two correctness issues left on the current head before this is ready to merge.

The blocking items are both in the DiT fusion itself:

  • the cast-wrapped type-safety check still does not fully verify the K path, so the fused MultiHeadAttention can still be emitted even when K's type was never actually proven compatible with Q/V;
  • the new single-consumer guard only protects the logits-side intermediates, but the pass still removes the downstream matmul_sv -> transpose_out -> reshape_out chain without checking whether matmul_sv or transpose_out feed any other consumers.

I also left one non-blocking inline suggestion on the FP16 synthetic test model. The new shape-inference-based validation is good, but the current graph still under-exercises the real mixed-precision path because the projections remain FP32.


    if use_fp16_casts:
        # Cast QK scores FP16 -> FP32 (simulating FP16 model needing FP32 Softmax)
        nodes.append(helper.make_node("Cast", ["qk_scores"], ["qk_scores_fp32"], "cast_to_fp32", to=1))


Non-blocking suggestion from the consolidated review: this still under-exercises the real mixed-precision path.

In use_fp16_casts=True, the topology now validates, but Q/K/V projections are still produced by FP32 matmuls, so cast_to_fp32 is only a placeholder no-op and the new dtype-safety checks never see a truly mixed-precision QK path. If we want this test to justify the new fusion safety logic more strongly, it would be worth adding one synthetic case where the projections themselves are FP16 (or are explicitly cast to FP16 before the QK matmul) so the pre-Softmax cast actually models the real graph class.
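
A hedged sketch of the suggested variant, building on the snippet quoted above; tensor names other than qk_scores/qk_scores_fp32 are illustrative, and the surrounding generator code is assumed:

    from onnx import TensorProto, helper

    # Cast Q and K to FP16 before the QK MatMul so the existing cast_to_fp32 node
    # is a real FP16 -> FP32 conversion rather than a placeholder no-op.
    nodes.append(helper.make_node("Cast", ["q_bnsh"], ["q_fp16"], "cast_q_fp16", to=TensorProto.FLOAT16))
    nodes.append(helper.make_node("Cast", ["k_bnhs"], ["k_fp16"], "cast_k_fp16", to=TensorProto.FLOAT16))
    nodes.append(helper.make_node("MatMul", ["q_fp16", "k_fp16"], ["qk_scores"], "matmul_qk"))
    nodes.append(helper.make_node("Cast", ["qk_scores"], ["qk_scores_fp32"], "cast_to_fp32", to=TensorProto.FLOAT))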

Trace K through any Cast on the K input path to validate its source dtype
against Q, mirroring the V-path trace already present in fuse(). The guard
bails only on positive evidence of mismatch (Cast present and source dtypes
known and differing), so benign graphs where k_bnsh comes from the
synthesised BNHS->BNSH Transpose without value_info continue to fuse.

Gate the downstream removal (matmul_sv, transpose_out, reshape_out, and
optional cast_after_softmax) with is_safe_to_fuse_nodes using
[reshape_out.output[0]] as the keep list, so side-consumers of the V path
outputs are not silently deleted by remove_nodes().
@Rishi-Dave
Contributor Author

Thanks for the second round, @tianleiwu. Pushed e6ead81 addressing both asks:

K-path dtype guard (fusion_mha_dit.py ~L410-465). K is now traced one level back through any Cast on its input path, mirroring the V-path trace already present in fuse(). The guard bails only on positive evidence of mismatch — a Cast is present on the K path and both source dtypes are known and they differ. Benign graphs where k_bnsh comes from the synthesised BNHS→BNSH Transpose (no value_info, so get_dtype returns None) continue to fuse, which keeps test_dit_attention_fusion_no_k_transpose and the FP16 cast test passing while rejecting a mixed-dtype K silently reaching matmul_qk.

Downstream single-consumer guard (~L520-530). Before nodes_to_remove is committed, self.model.is_safe_to_fuse_nodes(nodes_to_remove, [reshape_out.output[0]], input_name_to_nodes, output_name_to_node) is invoked. reshape_out.output[0] is the single value the fused MHA node produces, so the check lets that through while catching any live external consumer of matmul_sv.output[0] or transpose_out.output[0].

All four DiT tests still pass locally; lintrunner -a clean. Let me know if you'd like a dedicated test exercising the K-path mismatch bail.
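
A condensed sketch of that downstream guard as described above; variable names follow the reply, and the surrounding fuse() body is assumed:

    nodes_to_remove = [matmul_qk, mul_scale, softmax, matmul_sv, transpose_out, reshape_out]
    if cast_after_softmax is not None:
        nodes_to_remove.append(cast_after_softmax)

    # reshape_out.output[0] becomes the fused MultiHeadAttention output, so it is the only
    # value allowed to keep outside consumers; any other live side-consumer aborts the fusion.
    if not self.model.is_safe_to_fuse_nodes(
        nodes_to_remove, [reshape_out.output[0]], input_name_to_nodes, output_name_to_node
    ):
        return

    self.nodes_to_remove.extend(nodes_to_remove)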


@tianleiwu left a comment


Review Summary

The latest revision (e6ead81) addresses all three blocking concerns from prior rounds:

  1. Mixed-dtype MHA — V and K are now traced back through Cast nodes before dtype comparison, and the fusion bails on any confirmed mismatch. The conservative bail-out when casts are present but types unverifiable prevents silent type-invalid graphs.
  2. Softmax axis opset guard — axis is None is now gated on get_opset_version() >= 13, correctly handling the pre-opset-13 default of axis=1.
  3. Downstream safety — The logits-side single-consumer guard plus is_safe_to_fuse_nodes for the matmul_sv → transpose_out → reshape_out chain prevents removal of multi-consumer intermediates.

The fusion is well-structured, defensive, and follows the existing MMDit fusion patterns. Test coverage includes FP32, FP16 casts, custom scale, and no-K-transpose variants, with the _validate_mha_input_types helper catching type-invalid fused graphs.

I resolved the three previously-open blocking threads since the latest head addresses all of them.

One suggestion below on detect_num_heads — non-blocking.

    if len(node.input) >= 2:
        shape_value = self.model.get_constant_value(node.input[1])
        if shape_value is not None and isinstance(shape_value, np.ndarray) and shape_value.size == 4:
            return int(shape_value[2])


Suggestion (non-blocking): This hardcodes N at shape index 2, which is correct for the BSNH→BNSH path used by DiT models. However, if a model reshapes directly to BNSH (N at index 1), this would silently return the wrong num_heads. Consider either:

  • checking the Transpose perm in the caller before trusting index 2, or
  • adding a brief comment noting this assumption ("assumes Reshape produces BSNH, not BNSH").

Unlikely to hit in practice for the target model class, but worth documenting.
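
A hedged sketch of the first option above; the helper name and the fallback branch are assumptions, not the PR's code:

    import numpy as np

    def detect_num_heads_from_reshape(self, reshape_node, transpose_node):
        shape_value = self.model.get_constant_value(reshape_node.input[1])
        if shape_value is None or not isinstance(shape_value, np.ndarray) or shape_value.size != 4:
            return 0
        perm = []
        for attr in transpose_node.attribute:
            if attr.name == "perm":
                perm = list(attr.ints)
        if perm == [0, 2, 1, 3]:
            return int(shape_value[2])   # Reshape produced BSNH; heads live at index 2
        return int(shape_value[1])       # otherwise assume a direct BNSH reshape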

@tianleiwu enabled auto-merge (squash) on May 1, 2026 02:45


Development

Successfully merging this pull request may close these issues.

Flash Attention not dispatched for DiT-style attention pattern (diffusion transformers)
