Skip to content

Commit 199d5f7

Browse files
committed
Revert "[Test] Add dtype-mismatch rejection test for GQA pre-norm fusion"
This reverts commit cef037d.
1 parent d3928db commit 199d5f7

1 file changed

Lines changed: 2 additions & 23 deletions

File tree

onnxruntime/test/optimizer/group_query_attention_pre_norm_fusion_test.cc

Lines changed: 2 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -56,11 +56,6 @@ struct BuildOptions {
5656
// If true, pre-populate the GQA node's slot 14 with a q_norm_weight initializer so the
5757
// optimizer treats the node as already fused and skips it.
5858
bool pre_fused = false;
59-
// If true, build the q_norm_weight initializer as float32 while the SLN input/output
60-
// remain MLFloat16. SimplifiedLayerNormalization's schema allows scale (V) to differ
61-
// from input/output (T), but the fused GQA op reuses T for the norm-weight slots, so
62-
// the optimizer must reject the rewrite.
63-
bool mismatched_norm_weight_dtype = false;
6459
};
6560

6661
void BuildQwenQkPostNormPattern(ModelTestBuilder& builder, const BuildOptions& opts) {
@@ -82,13 +77,10 @@ void BuildQwenQkPostNormPattern(ModelTestBuilder& builder, const BuildOptions& o
8277
NodeArg* seqlens_k = builder.MakeInput<int32_t>(std::vector<int64_t>{kBatch}, std::vector<int32_t>{0});
8378
NodeArg* total_seq_len = builder.MakeInput<int32_t>(std::vector<int64_t>{1}, std::vector<int32_t>{1});
8479

85-
// Norm weight initializers: [head_size]. (Or non-1D when forcing a shape mismatch, or
86-
// a different element type to exercise the dtype-mismatch gate.)
80+
// Norm weight initializers: [head_size]. (Or non-1D when forcing a shape mismatch.)
8781
std::vector<int64_t> q_norm_weight_shape =
8882
opts.break_q_norm_weight_shape ? std::vector<int64_t>{1, kHeadSize} : std::vector<int64_t>{kHeadSize};
89-
NodeArg* q_norm_weight = opts.mismatched_norm_weight_dtype
90-
? builder.MakeInitializer<float>(q_norm_weight_shape, 1.0f, 1.0f)
91-
: builder.MakeInitializer<MLFloat16>(q_norm_weight_shape, MLFloat16(1.0f), MLFloat16(1.0f));
83+
NodeArg* q_norm_weight = builder.MakeInitializer<MLFloat16>(q_norm_weight_shape, MLFloat16(1.0f), MLFloat16(1.0f));
9284
NodeArg* k_norm_weight = builder.MakeInitializer<MLFloat16>({kHeadSize}, MLFloat16(1.0f), MLFloat16(1.0f));
9385

9486
// Reshape "shape" initializers.
@@ -262,19 +254,6 @@ TEST_F(GraphTransformationTests, GroupQueryAttentionPreNormFusionRejectsNon1DNor
262254
TransformerLevel::Level2, /*steps=*/1, nullptr, CheckUnfusedGraph));
263255
}
264256

265-
TEST_F(GraphTransformationTests, GroupQueryAttentionPreNormFusionRejectsMismatchedNormWeightDtype) {
266-
// SimplifiedLayerNormalization permits its scale (V) to differ from input/output (T),
267-
// but the fused GroupQueryAttention slot reuses T for the norm-weight inputs. Wiring a
268-
// float32 scale into a float16 chain must skip the rewrite to avoid changing type
269-
// constraints on the fused node.
270-
BuildOptions opts;
271-
opts.mismatched_norm_weight_dtype = true;
272-
auto build = [opts](ModelTestBuilder& builder) { BuildQwenQkPostNormPattern(builder, opts); };
273-
ASSERT_STATUS_OK(TestGraphTransformer(
274-
build, /*opset_version=*/21, *logger_, MakeWebGpuTransformer(),
275-
TransformerLevel::Level2, /*steps=*/1, nullptr, CheckUnfusedGraph));
276-
}
277-
278257
TEST_F(GraphTransformationTests, GroupQueryAttentionPreNormFusionSkipsCpuEp) {
279258
// Build the pattern but assign all nodes to CPU EP. The fusion is gated to WebGPU only,
280259
// so the graph must remain unfused.

0 commit comments

Comments
 (0)