Skip to content

Commit cef037d

Browse files
committed
[Test] Add dtype-mismatch rejection test for GQA pre-norm fusion
Covers the new gate that requires SimplifiedLayerNormalization input/scale/output element types to match before fusing into GroupQueryAttention.
1 parent 21e09aa commit cef037d

1 file changed

Lines changed: 23 additions & 2 deletions

File tree

onnxruntime/test/optimizer/group_query_attention_pre_norm_fusion_test.cc

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,11 @@ struct BuildOptions {
5656
// If true, pre-populate the GQA node's slot 14 with a q_norm_weight initializer so the
5757
// optimizer treats the node as already fused and skips it.
5858
bool pre_fused = false;
59+
// If true, build the q_norm_weight initializer as float32 while the SLN input/output
60+
// remain MLFloat16. SimplifiedLayerNormalization's schema allows scale (V) to differ
61+
// from input/output (T), but the fused GQA op reuses T for the norm-weight slots, so
62+
// the optimizer must reject the rewrite.
63+
bool mismatched_norm_weight_dtype = false;
5964
};
6065

6166
void BuildQwenQkPostNormPattern(ModelTestBuilder& builder, const BuildOptions& opts) {
@@ -77,10 +82,13 @@ void BuildQwenQkPostNormPattern(ModelTestBuilder& builder, const BuildOptions& o
7782
NodeArg* seqlens_k = builder.MakeInput<int32_t>(std::vector<int64_t>{kBatch}, std::vector<int32_t>{0});
7883
NodeArg* total_seq_len = builder.MakeInput<int32_t>(std::vector<int64_t>{1}, std::vector<int32_t>{1});
7984

80-
// Norm weight initializers: [head_size]. (Or non-1D when forcing a shape mismatch.)
85+
// Norm weight initializers: [head_size]. (Or non-1D when forcing a shape mismatch, or
86+
// a different element type to exercise the dtype-mismatch gate.)
8187
std::vector<int64_t> q_norm_weight_shape =
8288
opts.break_q_norm_weight_shape ? std::vector<int64_t>{1, kHeadSize} : std::vector<int64_t>{kHeadSize};
83-
NodeArg* q_norm_weight = builder.MakeInitializer<MLFloat16>(q_norm_weight_shape, MLFloat16(1.0f), MLFloat16(1.0f));
89+
NodeArg* q_norm_weight = opts.mismatched_norm_weight_dtype
90+
? builder.MakeInitializer<float>(q_norm_weight_shape, 1.0f, 1.0f)
91+
: builder.MakeInitializer<MLFloat16>(q_norm_weight_shape, MLFloat16(1.0f), MLFloat16(1.0f));
8492
NodeArg* k_norm_weight = builder.MakeInitializer<MLFloat16>({kHeadSize}, MLFloat16(1.0f), MLFloat16(1.0f));
8593

8694
// Reshape "shape" initializers.
@@ -254,6 +262,19 @@ TEST_F(GraphTransformationTests, GroupQueryAttentionPreNormFusionRejectsNon1DNor
254262
TransformerLevel::Level2, /*steps=*/1, nullptr, CheckUnfusedGraph));
255263
}
256264

265+
TEST_F(GraphTransformationTests, GroupQueryAttentionPreNormFusionRejectsMismatchedNormWeightDtype) {
266+
// SimplifiedLayerNormalization permits its scale (V) to differ from input/output (T),
267+
// but the fused GroupQueryAttention slot reuses T for the norm-weight inputs. Wiring a
268+
// float32 scale into a float16 chain must skip the rewrite to avoid changing type
269+
// constraints on the fused node.
270+
BuildOptions opts;
271+
opts.mismatched_norm_weight_dtype = true;
272+
auto build = [opts](ModelTestBuilder& builder) { BuildQwenQkPostNormPattern(builder, opts); };
273+
ASSERT_STATUS_OK(TestGraphTransformer(
274+
build, /*opset_version=*/21, *logger_, MakeWebGpuTransformer(),
275+
TransformerLevel::Level2, /*steps=*/1, nullptr, CheckUnfusedGraph));
276+
}
277+
257278
TEST_F(GraphTransformationTests, GroupQueryAttentionPreNormFusionSkipsCpuEp) {
258279
// Build the pattern but assign all nodes to CPU EP. The fusion is gated to WebGPU only,
259280
// so the graph must remain unfused.

0 commit comments

Comments
 (0)