@@ -56,11 +56,6 @@ struct BuildOptions {
5656 // If true, pre-populate the GQA node's slot 14 with a q_norm_weight initializer so the
5757 // optimizer treats the node as already fused and skips it.
5858 bool pre_fused = false ;
59- // If true, build the q_norm_weight initializer as float32 while the SLN input/output
60- // remain MLFloat16. SimplifiedLayerNormalization's schema allows scale (V) to differ
61- // from input/output (T), but the fused GQA op reuses T for the norm-weight slots, so
62- // the optimizer must reject the rewrite.
63- bool mismatched_norm_weight_dtype = false ;
6459};
6560
6661void BuildQwenQkPostNormPattern (ModelTestBuilder& builder, const BuildOptions& opts) {
@@ -82,13 +77,10 @@ void BuildQwenQkPostNormPattern(ModelTestBuilder& builder, const BuildOptions& o
8277 NodeArg* seqlens_k = builder.MakeInput <int32_t >(std::vector<int64_t >{kBatch }, std::vector<int32_t >{0 });
8378 NodeArg* total_seq_len = builder.MakeInput <int32_t >(std::vector<int64_t >{1 }, std::vector<int32_t >{1 });
8479
85- // Norm weight initializers: [head_size]. (Or non-1D when forcing a shape mismatch, or
86- // a different element type to exercise the dtype-mismatch gate.)
80+ // Norm weight initializers: [head_size]. (Or non-1D when forcing a shape mismatch.)
8781 std::vector<int64_t > q_norm_weight_shape =
8882 opts.break_q_norm_weight_shape ? std::vector<int64_t >{1 , kHeadSize } : std::vector<int64_t >{kHeadSize };
89- NodeArg* q_norm_weight = opts.mismatched_norm_weight_dtype
90- ? builder.MakeInitializer <float >(q_norm_weight_shape, 1 .0f , 1 .0f )
91- : builder.MakeInitializer <MLFloat16>(q_norm_weight_shape, MLFloat16 (1 .0f ), MLFloat16 (1 .0f ));
83+ NodeArg* q_norm_weight = builder.MakeInitializer <MLFloat16>(q_norm_weight_shape, MLFloat16 (1 .0f ), MLFloat16 (1 .0f ));
9284 NodeArg* k_norm_weight = builder.MakeInitializer <MLFloat16>({kHeadSize }, MLFloat16 (1 .0f ), MLFloat16 (1 .0f ));
9385
9486 // Reshape "shape" initializers.
@@ -262,19 +254,6 @@ TEST_F(GraphTransformationTests, GroupQueryAttentionPreNormFusionRejectsNon1DNor
262254 TransformerLevel::Level2, /* steps=*/ 1 , nullptr , CheckUnfusedGraph));
263255}
264256
265- TEST_F (GraphTransformationTests, GroupQueryAttentionPreNormFusionRejectsMismatchedNormWeightDtype) {
266- // SimplifiedLayerNormalization permits its scale (V) to differ from input/output (T),
267- // but the fused GroupQueryAttention slot reuses T for the norm-weight inputs. Wiring a
268- // float32 scale into a float16 chain must skip the rewrite to avoid changing type
269- // constraints on the fused node.
270- BuildOptions opts;
271- opts.mismatched_norm_weight_dtype = true ;
272- auto build = [opts](ModelTestBuilder& builder) { BuildQwenQkPostNormPattern (builder, opts); };
273- ASSERT_STATUS_OK (TestGraphTransformer (
274- build, /* opset_version=*/ 21 , *logger_, MakeWebGpuTransformer (),
275- TransformerLevel::Level2, /* steps=*/ 1 , nullptr , CheckUnfusedGraph));
276- }
277-
278257TEST_F (GraphTransformationTests, GroupQueryAttentionPreNormFusionSkipsCpuEp) {
279258 // Build the pattern but assign all nodes to CPU EP. The fusion is gated to WebGPU only,
280259 // so the graph must remain unfused.
0 commit comments