Add DiT attention fusion for F5-TTS and diffusion transformer models #27999
Rishi-Dave wants to merge 5 commits into microsoft:main
Conversation
… models

DiT models like F5-TTS use an attention pattern where Q, K, V are pre-computed (e.g. after RoPE) in BNSH format, K is pre-transposed to BNHS, and a custom scalar scale (e.g. 100.0) is applied before Softmax. Optional Cast nodes (FP16<->FP32) may wrap Softmax for mixed-precision.

The existing MMDit fusion (for SD3/Flux) expects a very specific Mul->Sqrt->Div->Sqrt->Cast->Slice->Shape scaling path and does not match the simpler Mul(scalar) pattern used in DiT models, so Flash Attention is never dispatched.

This commit adds FusionMultiHeadAttentionDiT, which recognizes MatMul(Q, K^T) -> [Cast] -> Mul(scale) -> Softmax -> [Cast] -> MatMul(attn, V) and fuses it into a single MultiHeadAttention op with the custom scale attribute, enabling Flash Attention dispatch.

Fixes microsoft#27983
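For reference, a minimal sketch of the unfused subgraph this pass targets; tensor names, sizes, and the scale value are illustrative and not taken from the actual model or the PR's test generator:

```python
# Minimal sketch of the pre-fusion DiT attention subgraph described above.
# Tensor names, shapes, and the scale value are illustrative only.
import numpy as np
from onnx import TensorProto, helper, numpy_helper

B, N, S, H = 1, 4, 16, 8  # batch, num_heads, sequence length, head_dim

nodes = [
    # Q is BNSH and K is already transposed to BNHS, so MatMul yields BNSS logits.
    helper.make_node("MatMul", ["q_bnsh", "k_bnhs"], ["qk"], "matmul_qk"),
    # Custom scalar scale (e.g. 100.0) applied before Softmax; optional Cast
    # nodes (FP16 <-> FP32) may wrap the Softmax in mixed-precision models.
    helper.make_node("Mul", ["qk", "scale"], ["qk_scaled"], "mul_scale"),
    helper.make_node("Softmax", ["qk_scaled"], ["probs"], "softmax", axis=-1),
    helper.make_node("MatMul", ["probs", "v_bnsh"], ["attn_bnsh"], "matmul_sv"),
]

graph = helper.make_graph(
    nodes,
    "dit_attention_sketch",
    inputs=[
        helper.make_tensor_value_info("q_bnsh", TensorProto.FLOAT, [B, N, S, H]),
        helper.make_tensor_value_info("k_bnhs", TensorProto.FLOAT, [B, N, H, S]),
        helper.make_tensor_value_info("v_bnsh", TensorProto.FLOAT, [B, N, S, H]),
    ],
    outputs=[helper.make_tensor_value_info("attn_bnsh", TensorProto.FLOAT, [B, N, S, H])],
    initializer=[numpy_helper.from_array(np.array(100.0, dtype=np.float32), "scale")],
)
model = helper.make_model(graph)
```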
Pull request overview
This PR extends the Python transformers graph fuser to recognize and fuse DiT-style attention patterns (e.g., F5-TTS) into a com.microsoft::MultiHeadAttention node so Flash Attention can be dispatched for these diffusion transformer models.
Changes:
- Add a new DiT attention fusion pass (`FusionMultiHeadAttentionDiT`) that detects `MatMul → (Cast) → Mul(scale) → Softmax → (Cast) → MatMul` patterns and replaces them with `MultiHeadAttention` (including `scale`).
- Register the new fusion as a second attention fusion pass in `MmditOnnxModel.fuse_multi_head_attention()`.
- Add a synthetic DiT ONNX model generator and new unit tests for FP32, optional Cast-wrapped Softmax, and custom scale variants.
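For orientation, a rough standalone sketch (plain onnx, not the actual fusion-framework code) of the structural walk such a pass performs; the real `FusionMultiHeadAttentionDiT` additionally applies the dtype, Softmax-axis, and consumer-count guards discussed later in this thread:

```python
# Sketch: trace producers backwards from the second MatMul through optional
# Cast nodes to Softmax, Mul(scale), and the QK MatMul. Illustrative only.
import onnx


def match_dit_attention(graph: onnx.GraphProto):
    producers = {out: node for node in graph.node for out in node.output}

    def producer(name, skip_cast=False):
        node = producers.get(name)
        if skip_cast and node is not None and node.op_type == "Cast":
            node = producers.get(node.input[0])
        return node

    for matmul_sv in graph.node:
        if matmul_sv.op_type != "MatMul":
            continue
        softmax = producer(matmul_sv.input[0], skip_cast=True)  # [Cast] <- Softmax
        if softmax is None or softmax.op_type != "Softmax":
            continue
        mul = producer(softmax.input[0])  # Mul(scale)
        if mul is None or mul.op_type != "Mul":
            continue
        matmul_qk = None
        for mul_input in mul.input:  # the real pass checks which input is the scalar scale
            candidate = producer(mul_input, skip_cast=True)  # [Cast] <- MatMul(Q, K^T)
            if candidate is not None and candidate.op_type == "MatMul":
                matmul_qk = candidate
                break
        if matmul_qk is None:
            continue
        yield matmul_qk, mul, softmax, matmul_sv
```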
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| onnxruntime/python/tools/transformers/fusion_mha_dit.py | Implements DiT-specific pattern matching and replacement with MultiHeadAttention + scale. |
| onnxruntime/python/tools/transformers/onnx_model_mmdit.py | Runs the new DiT attention fusion pass after the existing MMDiT fusion pass. |
| onnxruntime/test/python/transformers/dit_model_generator.py | Adds synthetic DiT attention subgraph generators used by fusion tests. |
| onnxruntime/test/python/transformers/test_attention_fusion.py | Adds three unit tests validating DiT attention fusion behavior and attributes. |
Cast V to FP16 in the FP16-cast test model so the attention MatMul has type-consistent inputs. Add Softmax-count-is-zero assertions to the FP16-cast and custom-scale tests to match the base test coverage.
Pull request overview
Copilot reviewed 4 out of 4 changed files in this pull request and generated 7 comments.
…, shrink test models

- Fix get_data_input_of_mul to handle Python int/float scalars (not just np.ndarray)
- Validate Softmax axis=-1 before fusing to avoid semantic changes
- Add detect_num_heads_from_input_shape fallback for graph inputs
- Switch to numpy_helper.from_array for initializer creation
- Reduce default test tensor sizes (num_heads=4, head_dim=8)
- Fix FP16 type consistency: cast attn output back to FP32 before o_matmul
- Add test_dit_attention_fusion_no_k_transpose for the inserted-Transpose path
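A minimal sketch of the first item; the helper name comes from the commit message, but the signature and body below are assumptions for illustration:

```python
# Sketch: find the scalar-scale input of a Mul, tolerating constants that fold
# to plain Python int/float as well as np.ndarray. Illustrative only.
import numpy as np


def get_data_input_of_mul(mul_node, get_constant_value):
    """Return (data_input_name, scale_value) or (None, None) if no scalar scale."""
    for scale_index in (0, 1):
        value = get_constant_value(mul_node.input[scale_index])
        if value is None:
            continue
        # Constant folding may yield a 0-d ndarray or a plain Python scalar.
        if isinstance(value, np.ndarray) and value.size == 1:
            return mul_node.input[1 - scale_index], float(value.reshape(-1)[0])
        if isinstance(value, (int, float)):
            return mul_node.input[1 - scale_index], float(value)
    return None, None
```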
The fusion is well scoped and follows the existing MMDiT pattern, but I think there are still a couple of correctness issues to address before merge.
Blocking items:
- The cast-wrapped path can create a `MultiHeadAttention` with mixed element types for Q/K vs V. I reproduced this with the new FP16-cast synthetic model: after saving with the MS opset, ORT rejects the optimized model because `MultiHeadAttention` input type parameter T is bound to both tensor(float) and tensor(float16).
- The Softmax axis guard handles explicit non-last axes, but still treats a missing axis as safe. That is only true for opset >= 13; for opset < 13, an omitted Softmax axis defaults to 1, so this fusion can change semantics. I added that detail as a reply on the existing Softmax-axis thread.
Also worth addressing: add a single-consumer guard for the matched intermediate tensors, and make the tests load or run parity on the optimized model so node-count assertions do not miss invalid fused graphs.
…check

- Trace V back through Cast in FP16 path to recover the pre-cast tensor, ensuring Q/K/V share the same element type for the fused MHA
- Add explicit Q/K/V dtype consistency check; bail when casts are present and types are unverifiable (V not traced through Cast)
- Guard against Softmax axis=None on opset < 13, where the default is axis=1, not the last axis
- Add single-consumer guard for all matched intermediate tensors to prevent removing nodes that feed other consumers
- Run onnx.shape_inference in test validation so get_dtype resolves intermediate tensor types, catching mixed-dtype fusions that pass structural assertions but fail at ORT load
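A sketch of the opset-aware Softmax axis guard described in this commit; the function name and exact placement are illustrative:

```python
# Sketch: only a last-axis Softmax is safe to fuse. A missing axis attribute is
# last-axis only on opset >= 13; on older opsets the default is axis=1.
def softmax_axis_is_safe(softmax_node, rank: int, opset_version: int) -> bool:
    axis = None
    for attr in softmax_node.attribute:
        if attr.name == "axis":
            axis = attr.i
    if axis is None:
        return opset_version >= 13
    return axis == -1 or axis == rank - 1
```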
I still think there are two correctness issues left on the current head before this is ready to merge.
The blocking items are both in the DiT fusion itself:
- the cast-wrapped type-safety check still does not fully verify the K path, so the fused `MultiHeadAttention` can still be emitted even when K's type was never actually proven compatible with Q/V;
- the new single-consumer guard only protects the logits-side intermediates, but the pass still removes the downstream `matmul_sv -> transpose_out -> reshape_out` chain without checking whether `matmul_sv` or `transpose_out` feed any other consumers.
I also left one non-blocking inline suggestion on the FP16 synthetic test model. The new shape-inference-based validation is good, but the current graph still under-exercises the real mixed-precision path because the projections remain FP32.
```python
if use_fp16_casts:
    # Cast QK scores FP16 -> FP32 (simulating FP16 model needing FP32 Softmax)
    nodes.append(helper.make_node("Cast", ["qk_scores"], ["qk_scores_fp32"], "cast_to_fp32", to=1))
```
Non-blocking suggestion from the consolidated review: this still under-exercises the real mixed-precision path.
With `use_fp16_casts=True`, the topology now validates, but the Q/K/V projections are still produced by FP32 matmuls, so `cast_to_fp32` is effectively a no-op and the new dtype-safety checks never see a truly mixed-precision QK path. If we want this test to justify the new fusion safety logic more strongly, it would be worth adding one synthetic case where the projections themselves are FP16 (or are explicitly cast to FP16 before the QK matmul), so that the pre-Softmax cast actually models the real graph class; see the sketch below.
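One possible shape for such a case; tensor names are illustrative and do not come from the actual generator:

```python
# Sketch: cast the Q/K inputs to FP16 before the QK MatMul so the FP16 -> FP32
# Cast in front of Softmax is no longer a no-op. Illustrative names only.
from onnx import TensorProto, helper

nodes = [
    helper.make_node("Cast", ["q_bnsh"], ["q_fp16"], "cast_q_fp16", to=TensorProto.FLOAT16),
    helper.make_node("Cast", ["k_bnhs"], ["k_fp16"], "cast_k_fp16", to=TensorProto.FLOAT16),
    helper.make_node("MatMul", ["q_fp16", "k_fp16"], ["qk_scores"], "matmul_qk"),
    # Now this Cast genuinely converts FP16 logits to FP32 for the Softmax.
    helper.make_node("Cast", ["qk_scores"], ["qk_scores_fp32"], "cast_to_fp32", to=TensorProto.FLOAT),
]
```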
Trace K through any Cast on the K input path to validate its source dtype against Q, mirroring the V-path trace already present in fuse(). The guard bails only on positive evidence of mismatch (Cast present and source dtypes known and differing), so benign graphs where k_bnsh comes from the synthesised BNHS->BNSH Transpose without value_info continue to fuse. Gate the downstream removal (matmul_sv, transpose_out, reshape_out, and optional cast_after_softmax) with is_safe_to_fuse_nodes using [reshape_out.output[0]] as the keep list, so side-consumers of the V path outputs are not silently deleted by remove_nodes().
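A standalone sketch of the K-path trace this commit describes; helper names are illustrative, the real logic lives in fusion_mha_dit.py:

```python
# Sketch: trace a tensor through a possible Cast back to its source and compare
# source dtypes, bailing only on positive evidence of a mismatch.
def trace_through_cast(tensor_name, output_name_to_node):
    node = output_name_to_node.get(tensor_name)
    if node is not None and node.op_type == "Cast":
        return node.input[0]
    return tensor_name


def k_dtype_compatible(q_name, k_name, output_name_to_node, get_dtype):
    q_src = trace_through_cast(q_name, output_name_to_node)
    k_src = trace_through_cast(k_name, output_name_to_node)
    q_dtype, k_dtype = get_dtype(q_src), get_dtype(k_src)
    # Unknown dtypes (e.g. a synthesised BNHS->BNSH Transpose without
    # value_info) are not treated as a mismatch, so benign graphs keep fusing.
    return q_dtype is None or k_dtype is None or q_dtype == k_dtype
```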
Thanks for the second round, @tianleiwu. Pushed e6ead81 addressing both asks: the K-path dtype guard and the downstream single-consumer guard (~L520-530). All four DiT tests still pass locally.
tianleiwu left a comment
Review Summary
The latest revision (e6ead81) addresses all three blocking concerns from prior rounds:
- Mixed-dtype MHA — V and K are now traced back through Cast nodes before dtype comparison, and the fusion bails on any confirmed mismatch. The conservative bail-out when casts are present but types unverifiable prevents silent type-invalid graphs.
- Softmax axis opset guard — `axis is None` is now gated on `get_opset_version() >= 13`, correctly handling the pre-opset-13 default of axis=1.
- Downstream safety — The logits-side single-consumer guard plus `is_safe_to_fuse_nodes` for the `matmul_sv → transpose_out → reshape_out` chain prevents removal of multi-consumer intermediates.
The fusion is well-structured, defensive, and follows the existing MMDit fusion patterns. Test coverage includes FP32, FP16 casts, custom scale, and no-K-transpose variants, with the _validate_mha_input_types helper catching type-invalid fused graphs.
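A rough sketch of what such a validation helper can look like; the name _validate_mha_input_types comes from the review above, the body is illustrative:

```python
# Sketch: run ONNX shape inference so intermediate value_info is populated, then
# check that the fused MultiHeadAttention has consistent Q/K/V element types.
import onnx


def _validate_mha_input_types(model_path: str) -> None:
    model = onnx.shape_inference.infer_shapes(onnx.load(model_path))
    dtypes = {}
    for vi in list(model.graph.value_info) + list(model.graph.input) + list(model.graph.output):
        dtypes[vi.name] = vi.type.tensor_type.elem_type
    for init in model.graph.initializer:
        dtypes[init.name] = init.data_type

    for node in model.graph.node:
        if node.op_type == "MultiHeadAttention":
            qkv = [dtypes.get(name) for name in node.input[:3] if name]
            resolved = {d for d in qkv if d is not None}
            assert len(resolved) <= 1, f"Mixed Q/K/V element types: {qkv}"
```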
I resolved the three previously-open blocking threads since the latest head addresses all of them.
One suggestion below on detect_num_heads — non-blocking.
```python
if len(node.input) >= 2:
    shape_value = self.model.get_constant_value(node.input[1])
    if shape_value is not None and isinstance(shape_value, np.ndarray) and shape_value.size == 4:
        return int(shape_value[2])
```
Suggestion (non-blocking): This hardcodes N at shape index 2, which is correct for the BSNH→BNSH path used by DiT models. However, if a model reshapes directly to BNSH (N at index 1), this would silently return the wrong num_heads. Consider either:
- checking the Transpose perm in the caller before trusting index 2, or
- adding a brief comment noting this assumption ("assumes Reshape produces BSNH, not BNSH").
Unlikely to hit in practice for the target model class, but worth documenting.
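A sketch of the first option; the function name and surrounding call are assumptions:

```python
# Sketch: only trust shape index 2 as num_heads when the downstream Transpose
# perm is the BSNH -> BNSH permutation [0, 2, 1, 3].
def num_heads_from_reshape(reshape_shape, transpose_node):
    perm = []
    for attr in transpose_node.attribute:
        if attr.name == "perm":
            perm = list(attr.ints)
    if perm != [0, 2, 1, 3]:
        return 0  # unexpected layout; fall back to other detection paths
    if reshape_shape is not None and len(reshape_shape) == 4:
        return int(reshape_shape[2])  # BSNH: N lives at index 2
    return 0
```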
Summary

- Add `FusionMultiHeadAttentionDiT` to recognize DiT-style attention patterns (F5-TTS, etc.) and fuse them into `MultiHeadAttention`, enabling Flash Attention dispatch.
- Register it in `MmditOnnxModel.fuse_multi_head_attention()`, alongside the existing MMDit fusion for SD3/Flux.

Motivation
Fixes #27983
DiT models like F5-TTS use an attention pattern where Q, K, V are pre-computed (e.g., after RoPE) in BNSH format, K is pre-transposed to BNHS, and a custom scalar scale (e.g., 100.0) is applied via `Mul` before `Softmax`. Optional `Cast` nodes (FP16↔FP32) may wrap `Softmax` for mixed-precision inference.

The existing MMDit fusion (for SD3/Flux) expects a specific `Mul→Sqrt→Div→Sqrt→Cast→Slice→Shape` scaling path and does not match the simpler `Mul(scalar_constant)` pattern, so the attention is never fused and Flash Attention is never dispatched. This causes ~44 extra Cast ops per inference and ~200ms overhead per forward pass.

Changes
New files:
- `onnxruntime/python/tools/transformers/fusion_mha_dit.py` — Core fusion class that matches the pattern `MatMul(Q, K^T) → [Cast] → Mul(scale) → Softmax → [Cast] → MatMul(attn, V)` and replaces it with a single `MultiHeadAttention` op (with `scale` attribute).
- `onnxruntime/test/python/transformers/dit_model_generator.py` — Synthetic ONNX graph generators for testing.

Modified files:
- `onnxruntime/python/tools/transformers/onnx_model_mmdit.py` — Register `FusionMultiHeadAttentionDiT` as a second fusion pass after the existing MMDit fusion.
- `onnxruntime/test/python/transformers/test_attention_fusion.py` — Three new test cases:
  - `test_dit_attention_fusion` — FP32 with K pre-transpose, scale=100.0
  - `test_dit_attention_fusion_with_fp16_casts` — FP16 Cast nodes around Softmax
  - `test_dit_attention_fusion_custom_scale` — Standard 1/√d_k scale

Test Plan
- A single `MultiHeadAttention` node is produced
- `num_heads` attribute is correctly detected from upstream Reshape shapes
- `scale` attribute matches the original scalar constant
- No `Softmax` nodes remain after fusion
- `ruff check`, `ruff format`, and `lintrunner -a` pass clean
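For illustration, a minimal sketch of the kind of assertions the test plan describes; helper names are assumptions, not the actual test code:

```python
# Sketch: load the optimized model and assert on fused node counts and the
# MultiHeadAttention scale attribute. Illustrative only.
import onnx


def count_ops(model: onnx.ModelProto, op_type: str) -> int:
    return sum(1 for node in model.graph.node if node.op_type == op_type)


def check_fused_model(optimized_path: str, expected_scale: float) -> None:
    model = onnx.load(optimized_path)
    assert count_ops(model, "MultiHeadAttention") == 1
    assert count_ops(model, "Softmax") == 0  # pattern fully absorbed

    mha = next(n for n in model.graph.node if n.op_type == "MultiHeadAttention")
    scale = next((a.f for a in mha.attribute if a.name == "scale"), None)
    assert scale is not None and abs(scale - expected_scale) < 1e-6
```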