Fix CPU GQA NaN output for right-padded batched prompts with rotary embeddings#29069
Conversation
There was a problem hiding this comment.
Pull request overview
This PR fixes a CPU GroupQueryAttention (GQA) correctness bug for batched right-padded prefill with rotary embeddings, where the softmax step could read beyond the QK GEMM-initialized region and propagate NaNs into the output. The fix caps the causal length used to compute the softmax window to the per-batch total_seqlen, and adds a regression test that validates per-batch consistency versus single-prompt execution.
Changes:
- Cap per-token causal length at
total_seqlenin both the float and quantized/MLAS softmax+masking paths to prevent out-of-bounds/uninitialized reads. - Add a regression test for batched right-padded packed-QKV rotary prefill, comparing each batch item’s real-last-token output to a batch=1 reference (CPU and CUDA EPs).
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h |
Caps the effective causal length used by the softmax/masking window to stay within the GEMM-written [0, total_seqlen) region (float + quantized paths). |
onnxruntime/test/contrib_ops/group_query_attention_op_test.cc |
Adds a right-padded batched rotary prefill regression test that compares each prompt’s real-last-token output to a single-prompt reference. |
|
@copilot - Please resolve conflicts |
tianleiwu
left a comment
There was a problem hiding this comment.
Reviewed carefully — this is a genuine bug and the fix is correct, minimal, and complete.
Root cause (confirmed). attention_probs is allocated with allocator->Alloc(bytes) and is not zero-initialized, and the QK GEMM only fills columns [0, total_seqlen) per row, where total_seqlen = seqlens_k[b] + 1. Per the GQA spec, seqlens_k is "Equivalent to (total_sequence_lengths - 1)" and is the per-batch KV length excluding padding — so a right-padded batch legitimately has total_seqlen[b] = real_len[b] smaller than the padded query dimension S. The softmax loop iterates over the full padded S, so padding positions get seq_causal_length > total_seqlen[b] and the pre-fix code read the unfilled [total_seqlen, seq_causal_length) region (uninitialized memory → NaN/garbage). This is a valid, supported input, not misuse.
Fix (correct). Capping effective_causal_length = min(seq_causal_length, total_seqlen) keeps the softmax window, the local-window start offset, and both masking loops inside the GEMM-filled region. It is underflow-safe (the … - local_window_size_ subtractions only run under the apply_local guard that ensures effective_causal_length > local_window_size_), covers both softmax paths (quantized MLAS ~L436 and float/fp16 ~L1097), and leaves valid (non-padding) rows bit-identical since seq_causal_length <= total_seqlen there.
Approving. Left two non-blocking test-robustness suggestions inline.
- Reword the tolerance-justification comment to name the actual failure modes per EP (CPU: uninitialized attention-probs reads; WebGPU: u32 underflow on rotary past_seqlen, see PR #29002) instead of calling it an "underflow bug" generically. - Add an explicit std::isfinite check over the full batched output so the regression is caught deterministically regardless of whether the allocator returns zeroed pages.
tianleiwu
left a comment
There was a problem hiding this comment.
✅ APPROVE — the fix is correct, minimal, and complete, and the new test now deterministically guards it.
Bug is real. For right-padded batched prefill, padding query positions have seq_causal_length = causal_past_seqlen + seq + 1 > total_seqlen[b] (where total_seqlen[b] = real_len[b] excludes padding). The QK GEMM only fills columns [0, total_seqlen), while attention_probs is allocated un-zeroed, so the softmax window read into uninitialized memory → NaN propagated through the S*V GEMM into padding output rows. This is a supported configuration (variable-length batched prefill via seqlens_k).
Fix is correct. effective_causal_length = std::min(seq_causal_length, total_seqlen) keeps the softmax window, the local-window start offset, and both masking loops inside the GEMM-filled region. Valid (non-padding) rows are bit-identical since seq_causal_length <= total_seqlen there. The underflow effective_causal_length - local_window_size_ is only evaluated under the apply_local guard, and both softmax/masking paths (quantized MLAS ~L430 and float/fp16 templated ~L1090) are patched — no other path indexes the probs buffer with seq_causal_length.
Test now deterministic. The added per-element ASSERT_TRUE(std::isfinite(batched_output[i])) loop over every element (including padding rows) anchors the test to the fixed behavior and removes the prior reliance on the allocator returning non-zero pages or on the loose magnitude tolerance.
No security or ABI impact; the added std::min is a per-row scalar with no hot-path allocation. Recommend merging.
…mbeddings (#29069) ### Description Fixes NaN output in the CPU GQA kernel when running batched right-padded prefill. For padding token positions where `seq_causal_length > total_seqlen`, the softmax loop was reading beyond the GEMM-filled region of the attention probs buffer into uninitialized memory, producing NaN values that propagated through the V GEMM to the output. **Root cause:** In `ComputeAttentionProbs`, `seq_causal_length = causal_past_seqlen + seq + 1` grows with each query position. For right-padded batches, a batch item with `real_len < sequence_length` has `total_seqlen = real_len`, but padding positions still iterate up to `sequence_length`, giving `seq_causal_length > total_seqlen`. The QK GEMM only fills columns `[0, total_seqlen)` — positions beyond that are uninitialized. **Fix:** Cap the effective causal length at `total_seqlen` before computing the softmax window: ```cpp // gqa_attention_base.h - both float and quantized paths const size_t effective_causal_length = std::min(seq_causal_length, total_seqlen); // use effective_causal_length for: local window check, start_offset, window_size, masking loops ``` Applied to both the non-quantized float path (~line 1097) and the quantized MLAS path (~line 436). ### Motivation and Context The new test `GroupQueryAttentionTest.BatchedRightPaddedRotaryPrefill_CPU` (added in this PR) exercises batched GQA with heterogeneous real sequence lengths `{4, 2, 6}` padded to `sequence_length=6`. Batch item 1 (`real_len=2`) has padding tokens at positions 2–5; position 3 triggered the NaN via uninitialized attention probs memory. --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: Jiajia Qin <jiajiaqin@microsoft.com>
Description
Fixes NaN output in the CPU GQA kernel when running batched right-padded prefill. For padding token positions where
seq_causal_length > total_seqlen, the softmax loop was reading beyond the GEMM-filled region of the attention probs buffer into uninitialized memory, producing NaN values that propagated through the V GEMM to the output.Root cause: In
ComputeAttentionProbs,seq_causal_length = causal_past_seqlen + seq + 1grows with each query position. For right-padded batches, a batch item withreal_len < sequence_lengthhastotal_seqlen = real_len, but padding positions still iterate up tosequence_length, givingseq_causal_length > total_seqlen. The QK GEMM only fills columns[0, total_seqlen)— positions beyond that are uninitialized.Fix: Cap the effective causal length at
total_seqlenbefore computing the softmax window:Applied to both the non-quantized float path (~line 1097) and the quantized MLAS path (~line 436).
Motivation and Context
The new test
GroupQueryAttentionTest.BatchedRightPaddedRotaryPrefill_CPU(added in this PR) exercises batched GQA with heterogeneous real sequence lengths{4, 2, 6}padded tosequence_length=6. Batch item 1 (real_len=2) has padding tokens at positions 2–5; position 3 triggered the NaN via uninitialized attention probs memory.