Commit 84d9370

Address review feedback on batched right-padded rotary prefill test
- Reword the tolerance-justification comment to name the actual failure modes per EP (CPU: uninitialized attention-probs reads; WebGPU: u32 underflow on rotary past_seqlen, see PR #29002) instead of calling it an "underflow bug" generically.
- Add an explicit std::isfinite check over the full batched output so the regression is caught deterministically regardless of whether the allocator returns zeroed pages.
1 parent c4a3254 commit 84d9370
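
The WebGPU failure mode named in the commit message can be illustrated with a minimal C++ sketch. It assumes, which the commit itself does not show, that the rotary past length is derived as total_seqlen - sequence_length in 32-bit unsigned arithmetic. Both function names are hypothetical and are not taken from onnxruntime.

```cpp
#include <cstdint>

// Hypothetical: past length computed as (total tokens - new tokens) in u32.
// Unsigned subtraction is modular, so total < sequence_length wraps to ~4.29e9
// instead of going negative.
constexpr uint32_t PastSeqLenWrapping(uint32_t total_seqlen, uint32_t sequence_length) {
  return total_seqlen - sequence_length;
}

// Hypothetical guard: clamp at zero so a right-padded row never produces a
// huge position that would index past the end of the cos/sin cache.
constexpr uint32_t PastSeqLenClamped(uint32_t total_seqlen, uint32_t sequence_length) {
  return total_seqlen > sequence_length ? total_seqlen - sequence_length : 0u;
}

// Right-padded prefill: the batch is padded to 8 tokens, this row has only 5.
static_assert(PastSeqLenWrapping(5u, 8u) == 4294967293u, "wraps modulo 2^32");
static_assert(PastSeqLenClamped(5u, 8u) == 0u, "clamped to zero");
// An unpadded row is unaffected by the guard.
static_assert(PastSeqLenClamped(8u, 8u) == 0u, "no past tokens during prefill");
```

Under this assumption, only rows whose real length is shorter than the padded sequence length hit the wrap, which would be consistent with the regression appearing only in batched right-padded prefill.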

1 file changed

Lines changed: 15 additions & 3 deletions

onnxruntime/test/contrib_ops/group_query_attention_op_test.cc

@@ -2567,10 +2567,22 @@ static void RunBatchedRightPaddedRotaryPrefillForEP(GqaTargetEp target_ep) {
       batch_size, sequence_length, num_heads, kv_num_heads, head_size,
       seqlens_k_data, packed_batched, target_ep);
 
+  // Guard the regression deterministically: every element of the batched output
+  // (including padding rows) must be finite. The CPU root cause is uninitialized
+  // attention-probs memory, so a NaN/Inf at any padding position would otherwise
+  // depend on the allocator returning non-zero pages.
+  for (size_t i = 0; i < batched_output.size(); ++i) {
+    ASSERT_TRUE(std::isfinite(batched_output[i]))
+        << "non-finite value at index " << i << " in batched GQA output";
+  }
+
   // Each batch's real-last-token output (used to predict next token) must match
-  // its single-prompt reference. The tolerance is loose enough for fp16 rounding
-  // while still catching the underflow bug (which produces values that differ
-  // by orders of magnitude or are NaN/Inf).
+  // its single-prompt reference. Tolerance is loose enough for fp16 rounding,
+  // tight enough to catch the right-padding regressions across EPs:
+  //   - CPU: uninitialized attention-probs reads at padding positions -> NaN.
+  //   - WebGPU: u32 underflow on rotary past_seqlen -> out-of-range cos/sin
+  //     index -> garbage Q/K (see PR #29002).
+  // Both manifest as NaN/Inf or values differing by orders of magnitude.
   constexpr float tolerance = 5e-3f;
   for (int b = 0; b < batch_size; ++b) {
     const int real_len = real_lens[b];
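
The finiteness guard added in this hunk could also be factored into a small helper. The sketch below is illustrative only: FirstNonFiniteIndex is a hypothetical name, not an onnxruntime or GoogleTest API, and the example uses plain assert in place of ASSERT_TRUE so that it stands alone.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Hypothetical helper: returns the index of the first NaN/Inf element,
// or values.size() when every element is finite.
inline std::size_t FirstNonFiniteIndex(const std::vector<float>& values) {
  for (std::size_t i = 0; i < values.size(); ++i) {
    if (!std::isfinite(values[i])) {
      return i;
    }
  }
  return values.size();
}

// Example: a NaN at a "padding" position is found even though a comparison
// restricted to the real tokens (indices 0 and 1 here) would never look at it.
inline void FirstNonFiniteIndexExample() {
  const float nan = std::numeric_limits<float>::quiet_NaN();
  const std::vector<float> padded_output = {0.25f, -1.5f, nan, 3.0f};
  const std::vector<float> clean_output = {0.25f, -1.5f, 0.0f, 3.0f};
  assert(FirstNonFiniteIndex(padded_output) == 2u);
  assert(FirstNonFiniteIndex(clean_output) == clean_output.size());
}
```

In a GoogleTest body, the same check could be written as ASSERT_EQ(FirstNonFiniteIndex(batched_output), batched_output.size()), at the cost of a less specific failure message than the per-index loop in the diff.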
