@@ -56,6 +56,11 @@ struct BuildOptions {
5656 // If true, pre-populate the GQA node's slot 14 with a q_norm_weight initializer so the
5757 // optimizer treats the node as already fused and skips it.
5858 bool pre_fused = false ;
59+ // If true, build the q_norm_weight initializer as float32 while the SLN input/output
60+ // remain MLFloat16. SimplifiedLayerNormalization's schema allows scale (V) to differ
61+ // from input/output (T), but the fused GQA op reuses T for the norm-weight slots, so
62+ // the optimizer must reject the rewrite.
63+ bool mismatched_norm_weight_dtype = false ;
5964};
6065
6166void BuildQwenQkPostNormPattern (ModelTestBuilder& builder, const BuildOptions& opts) {
@@ -77,10 +82,13 @@ void BuildQwenQkPostNormPattern(ModelTestBuilder& builder, const BuildOptions& o
7782 NodeArg* seqlens_k = builder.MakeInput <int32_t >(std::vector<int64_t >{kBatch }, std::vector<int32_t >{0 });
7883 NodeArg* total_seq_len = builder.MakeInput <int32_t >(std::vector<int64_t >{1 }, std::vector<int32_t >{1 });
7984
80- // Norm weight initializers: [head_size]. (Or non-1D when forcing a shape mismatch.)
85+ // Norm weight initializers: [head_size]. (Or non-1D when forcing a shape mismatch, or
86+ // a different element type to exercise the dtype-mismatch gate.)
8187 std::vector<int64_t > q_norm_weight_shape =
8288 opts.break_q_norm_weight_shape ? std::vector<int64_t >{1 , kHeadSize } : std::vector<int64_t >{kHeadSize };
83- NodeArg* q_norm_weight = builder.MakeInitializer <MLFloat16>(q_norm_weight_shape, MLFloat16 (1 .0f ), MLFloat16 (1 .0f ));
89+ NodeArg* q_norm_weight = opts.mismatched_norm_weight_dtype
90+ ? builder.MakeInitializer <float >(q_norm_weight_shape, 1 .0f , 1 .0f )
91+ : builder.MakeInitializer <MLFloat16>(q_norm_weight_shape, MLFloat16 (1 .0f ), MLFloat16 (1 .0f ));
8492 NodeArg* k_norm_weight = builder.MakeInitializer <MLFloat16>({kHeadSize }, MLFloat16 (1 .0f ), MLFloat16 (1 .0f ));
8593
8694 // Reshape "shape" initializers.
@@ -254,6 +262,19 @@ TEST_F(GraphTransformationTests, GroupQueryAttentionPreNormFusionRejectsNon1DNor
254262 TransformerLevel::Level2, /* steps=*/ 1 , nullptr , CheckUnfusedGraph));
255263}
256264
265+ TEST_F (GraphTransformationTests, GroupQueryAttentionPreNormFusionRejectsMismatchedNormWeightDtype) {
266+ // SimplifiedLayerNormalization's schema lets its scale type (V) differ from its
267+ // input/output type (T), whereas the fused GroupQueryAttention op reuses T for the
268+ // norm-weight inputs. With a float32 scale feeding a float16 chain, the optimizer must
269+ // skip the rewrite so the fused node's type constraints are not changed.
270+ BuildOptions opts;
271+ opts.mismatched_norm_weight_dtype = true ;  // only q_norm_weight is built as float; k_norm_weight stays MLFloat16
272+ auto build = [opts](ModelTestBuilder& builder) { BuildQwenQkPostNormPattern (builder, opts); };
273+ ASSERT_STATUS_OK (TestGraphTransformer (
274+ build, /* opset_version=*/ 21 , *logger_, MakeWebGpuTransformer (),
275+ TransformerLevel::Level2, /* steps=*/ 1 , nullptr , CheckUnfusedGraph));
276+ }
277+
257278TEST_F (GraphTransformationTests, GroupQueryAttentionPreNormFusionSkipsCpuEp) {
258279 // Build the pattern but assign all nodes to CPU EP. The fusion is gated to WebGPU only,
259280 // so the graph must remain unfused.
0 commit comments