Fix CPU GQA NaN output for right-padded batched prompts with rotary embeddings #29069

Merged: qjia7 merged 6 commits into main from copilot/fix-wasm-debug-build-job on Jun 23, 2026
Conversation

Copilot AI commented Jun 16, 2026

Contributor

Description

Fixes NaN output in the CPU GQA kernel when running batched right-padded prefill. For padding token positions where seq_causal_length > total_seqlen, the softmax loop was reading beyond the GEMM-filled region of the attention probs buffer into uninitialized memory, producing NaN values that propagated through the V GEMM to the output.

Root cause: In ComputeAttentionProbs, seq_causal_length = causal_past_seqlen + seq + 1 grows with each query position. For right-padded batches, a batch item with real_len < sequence_length has total_seqlen = real_len, but the softmax loop still visits every query position up to sequence_length, so padding positions get seq_causal_length > total_seqlen. The QK GEMM only fills columns [0, total_seqlen); columns beyond that are uninitialized.

Fix: Cap the effective causal length at total_seqlen before computing the softmax window:

// gqa_attention_base.h - both float and quantized paths
const size_t effective_causal_length = std::min(seq_causal_length, total_seqlen);
// use effective_causal_length for: local window check, start_offset, window_size, masking loops

Applied to both the non-quantized float path (~line 1097) and the quantized MLAS path (~line 436).
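
To make the effect of the cap concrete, here is a minimal, self-contained sketch of a single-row softmax that clamps its window the same way. It is not the ONNX Runtime kernel: the helper name CappedCausalSoftmaxRow, the use of std::vector, and the zero-fill of masked columns are assumptions made for illustration.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Hypothetical helper (not the ORT kernel): in-place softmax over one row of
// attention scores. Only columns [0, total_seqlen) are assumed to hold valid
// QK scores; anything beyond may be garbage.
inline void CappedCausalSoftmaxRow(std::vector<float>& row,
                                   std::size_t seq_causal_length,
                                   std::size_t total_seqlen) {
  assert(seq_causal_length > 0);
  assert(total_seqlen > 0 && total_seqlen <= row.size());

  // Clamp the window so it never extends past the GEMM-filled columns.
  const std::size_t effective_causal_length = std::min(seq_causal_length, total_seqlen);

  float max_score = -std::numeric_limits<float>::infinity();
  for (std::size_t i = 0; i < effective_causal_length; ++i) {
    max_score = std::max(max_score, row[i]);
  }

  float sum = 0.0f;
  for (std::size_t i = 0; i < effective_causal_length; ++i) {
    row[i] = std::exp(row[i] - max_score);
    sum += row[i];
  }
  for (std::size_t i = 0; i < effective_causal_length; ++i) {
    row[i] /= sum;
  }

  // Mask every column outside the window so a later probs x V product ignores it.
  for (std::size_t i = effective_causal_length; i < row.size(); ++i) {
    row[i] = 0.0f;
  }
}
```

With the clamp removed, a padding row whose seq_causal_length exceeds total_seqlen would feed the unfilled columns into the max and exp steps, which is the failure described above.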

Motivation and Context

The new test GroupQueryAttentionTest.BatchedRightPaddedRotaryPrefill_CPU (added in this PR) exercises batched GQA with heterogeneous real sequence lengths {4, 2, 6} padded to sequence_length=6. Batch item 1 (real_len=2) has padding tokens at positions 2–5; position 3 triggered the NaN via uninitialized attention probs memory.
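
The arithmetic behind the test case can be checked with a short sketch that lists the query positions whose causal length overshoots the filled region. The function name OverreachingPositions is hypothetical, and the default causal_past_seqlen = 0 is an assumption for a first-prompt prefill; neither comes from the PR.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: returns the query positions for which
// seq_causal_length = causal_past_seqlen + seq + 1 exceeds total_seqlen.
// causal_past_seqlen defaults to 0, which is assumed here for prefill.
inline std::vector<std::size_t> OverreachingPositions(std::size_t sequence_length,
                                                      std::size_t total_seqlen,
                                                      std::size_t causal_past_seqlen = 0) {
  assert(total_seqlen <= causal_past_seqlen + sequence_length);
  std::vector<std::size_t> positions;
  for (std::size_t seq = 0; seq < sequence_length; ++seq) {
    const std::size_t seq_causal_length = causal_past_seqlen + seq + 1;
    if (seq_causal_length > total_seqlen) {
      positions.push_back(seq);  // pre-fix, this row read unfilled columns
    }
  }
  return positions;
}
```

For sequence_length = 6, a batch item with real_len = 2 yields positions 2 through 5, real_len = 4 yields positions 4 and 5, and real_len = 6 yields none. Whether a given overreaching position actually produces NaN depends on what the uninitialized memory happens to contain.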

Copilot AI changed the title from "[WIP] Fix failing GitHub Actions job for wasm_Debug" to "Fix CPU GQA NaN output for right-padded batched prompts with rotary embeddings" Jun 16, 2026
Copilot finished work on behalf of qjia7 June 16, 2026 08:18
Copilot AI requested a review from qjia7 June 16, 2026 08:18
Comment thread onnxruntime/test/contrib_ops/group_query_attention_op_test.cc

Copilot AI left a comment

Contributor

Pull request overview

This PR fixes a CPU GroupQueryAttention (GQA) correctness bug for batched right-padded prefill with rotary embeddings, where the softmax step could read beyond the QK GEMM-initialized region and propagate NaNs into the output. The fix caps the causal length used to compute the softmax window to the per-batch total_seqlen, and adds a regression test that validates per-batch consistency versus single-prompt execution.

Changes:

  • Cap per-token causal length at total_seqlen in both the float and quantized/MLAS softmax+masking paths to prevent out-of-bounds/uninitialized reads.
  • Add a regression test for batched right-padded packed-QKV rotary prefill, comparing each batch item’s real-last-token output to a batch=1 reference (CPU and CUDA EPs).
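
The "real-last-token" comparison in the second bullet relies on picking the right row out of a padded output. A rough sketch of that indexing follows, assuming a row-major [batch, sequence_length, hidden] layout; the helper name RealLastTokenRow and the layout are illustrative assumptions, not code from the test file.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: copies the output row of the last real (non-padding)
// token of one batch item. Assumes `output` is row-major with shape
// [batch_size, sequence_length, hidden_size].
inline std::vector<float> RealLastTokenRow(const std::vector<float>& output,
                                           std::size_t batch_size,
                                           std::size_t sequence_length,
                                           std::size_t hidden_size,
                                           std::size_t batch_index,
                                           std::size_t real_len) {
  assert(output.size() == batch_size * sequence_length * hidden_size);
  assert(batch_index < batch_size);
  assert(real_len >= 1 && real_len <= sequence_length);

  // Row index of the last real token inside the padded batch.
  const std::size_t row = batch_index * sequence_length + (real_len - 1);
  const std::size_t offset = row * hidden_size;
  return std::vector<float>(output.begin() + offset,
                            output.begin() + offset + hidden_size);
}
```

The row returned for each batch item can then be compared element-wise against the last row of a batch=1 run over the same unpadded prompt.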

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.

File | Description
onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h | Caps the effective causal length used by the softmax/masking window to stay within the GEMM-written [0, total_seqlen) region (float + quantized paths).
onnxruntime/test/contrib_ops/group_query_attention_op_test.cc | Adds a right-padded batched rotary prefill regression test that compares each prompt’s real-last-token output to a single-prompt reference.

Comment thread onnxruntime/test/contrib_ops/group_query_attention_op_test.cc Outdated
@hariharans29

Member

@copilot - Please resolve conflicts

Copilot AI commented Jun 22, 2026

Contributor Author

@copilot - Please resolve conflicts

Conflicts resolved in merge commit c4a3254. The test file now adopts the GqaTargetEp enum and MakeExecutionProviderForGqaTest helper that landed in main via PR #29002, and includes all three EP tests (CPU, CUDA, WebGPU).

tianleiwu previously approved these changes Jun 23, 2026

@tianleiwu tianleiwu left a comment

Contributor

Reviewed carefully — this is a genuine bug and the fix is correct, minimal, and complete.

Root cause (confirmed). attention_probs is allocated with allocator->Alloc(bytes) and is not zero-initialized, and the QK GEMM only fills columns [0, total_seqlen) per row, where total_seqlen = seqlens_k[b] + 1. Per the GQA spec, seqlens_k is "Equivalent to (total_sequence_lengths - 1)" and is the per-batch KV length excluding padding — so a right-padded batch legitimately has total_seqlen[b] = real_len[b] smaller than the padded query dimension S. The softmax loop iterates over the full padded S, so padding positions get seq_causal_length > total_seqlen[b] and the pre-fix code read the unfilled [total_seqlen, seq_causal_length) region (uninitialized memory → NaN/garbage). This is a valid, supported input, not misuse.

Fix (correct). Capping effective_causal_length = min(seq_causal_length, total_seqlen) keeps the softmax window, the local-window start offset, and both masking loops inside the GEMM-filled region. It is underflow-safe (the … - local_window_size_ subtractions only run under the apply_local guard that ensures effective_causal_length > local_window_size_), covers both softmax paths (quantized MLAS ~L436 and float/fp16 ~L1097), and leaves valid (non-padding) rows bit-identical since seq_causal_length <= total_seqlen there.
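
The underflow-safety argument can be illustrated with a simplified sketch of the window computation on unsigned sizes. The struct, the function name ComputeWindow, and the exact offset formula are assumptions chosen to mirror the guard described in this review; the real kernel's offsets may differ.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>

// Hypothetical result type for the softmax window of one query row.
struct SoftmaxWindow {
  std::size_t start_offset;
  std::size_t window_size;
};

// Simplified sketch: a negative local_window_size is treated as "disabled".
// The size_t subtraction only runs when the guard has established
// effective_causal_length > local_window_size, so it cannot wrap around.
inline SoftmaxWindow ComputeWindow(std::size_t seq_causal_length,
                                   std::size_t total_seqlen,
                                   int local_window_size) {
  const std::size_t effective_causal_length = std::min(seq_causal_length, total_seqlen);
  const bool apply_local =
      local_window_size >= 0 &&
      effective_causal_length > static_cast<std::size_t>(local_window_size);
  if (!apply_local) {
    return {0, effective_causal_length};
  }
  const std::size_t window = static_cast<std::size_t>(local_window_size);
  return {effective_causal_length - window, window};
}
```

In this sketch the guard is evaluated on the capped length, so a padding row with total_seqlen = 2 and a window of 4 falls back to the full [0, 2) range instead of computing 2 - 4 on an unsigned type.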

Approving. Left two non-blocking test-robustness suggestions inline.

Comment thread onnxruntime/test/contrib_ops/group_query_attention_op_test.cc
@qjia7 qjia7 marked this pull request as ready for review June 23, 2026 03:07
- Reword the tolerance-justification comment to name the actual failure
  modes per EP (CPU: uninitialized attention-probs reads; WebGPU: u32
  underflow on rotary past_seqlen, see PR #29002) instead of calling it
  an "underflow bug" generically.
- Add an explicit std::isfinite check over the full batched output so the
  regression is caught deterministically regardless of whether the
  allocator returns zeroed pages.

@tianleiwu tianleiwu left a comment

Contributor

APPROVE — the fix is correct, minimal, and complete, and the new test now deterministically guards it.

Bug is real. For right-padded batched prefill, padding query positions have seq_causal_length = causal_past_seqlen + seq + 1 > total_seqlen[b] (where total_seqlen[b] = real_len[b] excludes padding). The QK GEMM only fills columns [0, total_seqlen), while attention_probs is allocated un-zeroed, so the softmax window read into uninitialized memory → NaN propagated through the S*V GEMM into padding output rows. This is a supported configuration (variable-length batched prefill via seqlens_k).
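
As a small illustration of the propagation step, the sketch below computes one output element as a weighted sum of values. It is a scalar stand-in for the probs x V GEMM, not the MLAS call; the name WeightedSum is hypothetical.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical scalar stand-in for one output element of the probs x V product.
inline float WeightedSum(const std::vector<float>& probs,
                         const std::vector<float>& values) {
  assert(probs.size() == values.size());
  float acc = 0.0f;
  for (std::size_t j = 0; j < probs.size(); ++j) {
    // Under IEEE 754, NaN times anything (including 0) is NaN, and NaN plus
    // anything is NaN, so a single bad probability poisons the whole sum.
    acc += probs[j] * values[j];
  }
  return acc;
}
```

A single NaN probability is enough to turn the output element into NaN even when the matching value is zero, which is why the bad reads have to be prevented at the softmax stage rather than masked afterwards.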

Fix is correct. effective_causal_length = std::min(seq_causal_length, total_seqlen) keeps the softmax window, the local-window start offset, and both masking loops inside the GEMM-filled region. Valid (non-padding) rows are bit-identical since seq_causal_length <= total_seqlen there. The subtraction effective_causal_length - local_window_size_, which could otherwise underflow, is only evaluated under the apply_local guard, and both softmax/masking paths (quantized MLAS ~L430 and float/fp16 templated ~L1090) are patched; no other path indexes the probs buffer with seq_causal_length.

Test now deterministic. The added per-element ASSERT_TRUE(std::isfinite(batched_output[i])) loop over every element (including padding rows) anchors the test to the fixed behavior and removes the prior reliance on the allocator returning non-zero pages or on the loose magnitude tolerance.
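
A finiteness scan of this kind might look like the following sketch outside of GoogleTest. The helper FirstNonFiniteIndex is hypothetical; in the actual test the check is described as a per-element ASSERT_TRUE(std::isfinite(...)) loop.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical helper: returns the index of the first NaN or infinite element,
// or -1 when every element is finite. Scanning the full buffer, padding rows
// included, makes the check independent of what the allocator returned.
inline std::ptrdiff_t FirstNonFiniteIndex(const std::vector<float>& values) {
  for (std::size_t i = 0; i < values.size(); ++i) {
    if (!std::isfinite(values[i])) {
      return static_cast<std::ptrdiff_t>(i);
    }
  }
  return -1;
}
```

Returning the index rather than a bool makes a failure message more useful, since the offending batch item and position can be recovered from it.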

No security or ABI impact; the added std::min is a per-row scalar with no hot-path allocation. Recommend merging.

@qjia7 qjia7 merged commit 701d88f into main Jun 23, 2026
85 checks passed
@qjia7 qjia7 deleted the copilot/fix-wasm-debug-build-job branch June 23, 2026 07:11
tianleiwu pushed a commit that referenced this pull request Jun 23, 2026
…mbeddings (#29069)

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: Jiajia Qin <jiajiaqin@microsoft.com>