Skip to content

Commit d5e208b

Browse files
committed
[Test] Fix uniform_int_distribution assertion in GQA pre-norm fusion test
MakeInput<int>(shape, min, max) calls Uniform(min, max - 1) which asserts when min == max. Switch the seqlens_k and total_seq_len inputs to the explicit-data overload so the (0,0) and (1,1) ranges are no longer treated as integer distributions.
1 parent 7f69c46 commit d5e208b

1 file changed

Lines changed: 4 additions & 2 deletions

File tree

onnxruntime/test/optimizer/group_query_attention_pre_norm_fusion_test.cc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,10 @@ void BuildQwenQkPostNormPattern(ModelTestBuilder& builder, const BuildOptions& o
7272
std::vector<int64_t>{kBatch, kKvNumHeads, kMaxSeq, kHeadSize}, MLFloat16(0.0f), MLFloat16(0.0f));
7373
NodeArg* past_value = builder.MakeInput<MLFloat16>(
7474
std::vector<int64_t>{kBatch, kKvNumHeads, kMaxSeq, kHeadSize}, MLFloat16(0.0f), MLFloat16(0.0f));
75-
NodeArg* seqlens_k = builder.MakeInput<int32_t>(std::vector<int64_t>{kBatch}, int32_t{0}, int32_t{0});
76-
NodeArg* total_seq_len = builder.MakeInput<int32_t>(std::vector<int64_t>{1}, int32_t{1}, int32_t{1});
75+
// Note: ModelTestBuilder::MakeInput<int>(shape, min, max) calls Uniform(min, max - 1)
76+
// internally, which asserts on min == max. Use the explicit-data overload instead.
77+
NodeArg* seqlens_k = builder.MakeInput<int32_t>(std::vector<int64_t>{kBatch}, std::vector<int32_t>{0});
78+
NodeArg* total_seq_len = builder.MakeInput<int32_t>(std::vector<int64_t>{1}, std::vector<int32_t>{1});
7779

7880
// Norm weight initializers: [head_size]. (Or non-1D when forcing a shape mismatch.)
7981
std::vector<int64_t> q_norm_weight_shape =

0 commit comments

Comments
 (0)