Add seqlens_k bounds validation in GroupQueryAttention to prevent GEMM OOB#28031

Merged
vraspar merged 7 commits into main from vraspar/fix-gqa-seqlens-k-oob on Apr 21, 2026

Conversation

@vraspar (Contributor) commented Apr 10, 2026

Description

Validate seqlens_k tensor values in the CPU GroupQueryAttention operator before they are used as GEMM dimensions. Without this check, a crafted model can supply negative or oversized seqlens_k values that cause out-of-bounds reads in the K/V present cache buffers.

Changes

  • group_query_attention.cc: Add validation loop in Compute() before any seqlens_k access (a sketch follows this list):
    • seqlens_k[b] >= 0 (prevents unsigned wraparound in static_cast<size_t>)
    • seqlens_k[b] + 1 <= present_kv_seqlen (prevents GEMM reading past K/V buffer)
    • For non-first-prompt: seqlens_k[b] + 1 >= sequence_length (prevents underflow in past_seqlen = total_seqlen - sequence_length)
  • group_query_attention_helper.h: Fix seqlens_k shape validation (&& to ||) so wrong-length tensors are correctly rejected
  • Tests: 4 regression tests covering negative, oversized, multi-batch, and boundary-valid seqlens_k values
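
A minimal sketch of that validation loop, assuming variable names (seqlens_k_data, batch_size, present_kv_seqlen, sequence_length, is_first_prompt) that approximate but may not match the actual kernel code:

// Hedged sketch; the PR's commits mention ORT_MAKE_STATUS with INVALID_ARGUMENT
// for consistency with CheckInputs(), so that form is used here.
const int32_t* seqlens_k_data = seqlens_k->Data<int32_t>();
for (int b = 0; b < batch_size; ++b) {
  // Reject negatives before any static_cast<size_t> can wrap around.
  if (seqlens_k_data[b] < 0) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "seqlens_k[", b, "] must be non-negative, got ", seqlens_k_data[b]);
  }
  // int64_t avoids signed overflow when seqlens_k[b] == INT32_MAX.
  const int64_t total_seqlen = static_cast<int64_t>(seqlens_k_data[b]) + 1;
  // GEMM must not read past the present K/V buffers.
  if (total_seqlen > present_kv_seqlen) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "seqlens_k[", b, "] + 1 must not exceed the present K/V sequence length");
  }
  // For non-first prompts, past_seqlen = total_seqlen - sequence_length must not underflow.
  if (!is_first_prompt && total_seqlen < sequence_length) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "seqlens_k[", b, "] + 1 must be at least sequence_length");
  }
}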

Motivation and Context

MSRC case 108962: A crafted model can set seqlens_k values that, when cast to size_t and used as GEMM N dimension, cause heap OOB reads from the present K/V cache buffers.
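
To make the failure mode concrete, here is a small standalone demonstration (not code from this PR) of the size_t wraparound:

#include <cstdint>
#include <cstdio>

int main() {
  int32_t seqlens_k = -2;  // attacker-controlled tensor value
  // Mirrors the kernel's pattern: total_seqlen = static_cast<size_t>(seqlens_k) + 1.
  size_t total_seqlen = static_cast<size_t>(seqlens_k) + 1;
  // On a 64-bit platform this prints 18446744073709551615 (SIZE_MAX):
  // an enormous value that, used as a GEMM N dimension, reads far out of bounds.
  printf("%zu\n", total_seqlen);
  return 0;
}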

Copilot AI (Contributor) left a comment

Pull request overview

This PR hardens the CPU GroupQueryAttention contrib op against malicious/invalid seqlens_k values that could otherwise drive GEMM dimensions and lead to out-of-bounds reads, and adds regression tests for the reported OOB scenario.

Changes:

  • Add runtime validation of seqlens_k values in GroupQueryAttention::Compute() before they influence GEMM dimensions.
  • Fix seqlens_k shape validation in the helper (&& → ||) so incorrect tensor shapes are rejected correctly.
  • Add CPU-focused regression tests for negative/oversized/multi-batch/boundary-valid seqlens_k inputs.

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 5 comments.

  • onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc: Adds per-batch seqlens_k bounds checks prior to GEMM/attention computations.
  • onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h: Corrects seqlens_k shape validation logic to properly reject invalid shapes.
  • onnxruntime/test/contrib_ops/group_query_attention_op_test.cc: Introduces new regression tests targeting invalid seqlens_k inputs on CPU.

vraspar and others added 3 commits April 10, 2026 14:29
… int64 overflow safety, non-first-prompt test

- Replace ORT_RETURN_IF (generic FAIL) with ORT_MAKE_STATUS INVALID_ARGUMENT
  for consistent error classification with CheckInputs()
- Add total_sequence_length validation to prevent attention_bias OOB when
  seqlens_k+1 > total_sequence_length but <= present_kv_seqlen
- Use int64_t for seqlens_k+1 computation to prevent signed overflow at INT32_MAX
- Add SubsequentPromptSeqlensKUnderflow test that independently exercises the
  !is_first_prompt underflow guard with a positive seqlens_k value
- Add SeqlensKExceedsTotalSeqLen test for the new total_sequence_length bound
- Add BoundaryValidSeqlensKWithLargerPast boundary success test
- Rename original NonPrompt test to NegativeSeqlensKWithPast to clarify intent
- Fix test helper: head_size=8 (must be multiple of 8), correct output shapes,
  update comment to note past support

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
@yuslepukhin (Member) commented:

int total_sequence_length = is_total_seqlen_on_cpu ? *((*total_seqlen).template Data<int32_t>()) : 0;

total_sequence_length Not Validated for Negativity.

A crafted model can supply total_sequence_length <= 0. This flows into present_sequence_length (which could end up 0 if no past), then into buffer allocation sizes. A negative value would be clamped by std::max only if past_sequence_length > 0. With no past and total_sequence_length = 0, present_kv_seqlen = 0, which makes the seqlens_k[b] >= present_kv_seqlen check reject any seqlens_k >= 0 (effectively forcing failure), but allocations of zero-size buffers might have platform-specific behavior.

Recommendation: Add total_sequence_length > 0 validation in CheckInputs.


Refers to: onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h:277 in 5edfbfc.
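
A hedged sketch of the recommended check, reusing the names from the quoted line (where exactly it would sit inside CheckInputs is an assumption):

// Illustrative only; rejects non-positive totals before they reach allocation sizes.
if (is_total_seqlen_on_cpu && total_sequence_length <= 0) {
  return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                         "total_sequence_length must be positive, got ", total_sequence_length);
}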

@yuslepukhin (Member) commented:

No SafeInt on seqlens_k-Derived Arithmetic in gqa_attention_base.h:209-211

const size_t total_seqlen = static_cast<size_t>(seqlens_k[batch_index]) + 1;
const size_t past_seqlen = is_prompt ? 0 : total_seqlen - sequence_length;
const size_t past_chunk_length = past_seqlen * head_size;

While the upstream validation makes these safe, defense-in-depth would use SafeInt<size_t> (already used elsewhere in the same file for packed_batch_stride, probs_matrix_bytes, etc.) to guard the cast and subsequent arithmetic. This would catch any future regression if the validation is loosened.

Same pattern at line 443-444.
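
A sketch of the SafeInt-guarded form suggested above, assuming the surrounding variables keep their current types (SafeInt is already used in this file, per the comment):

// SafeInt<size_t> traps on overflow/underflow instead of silently wrapping.
const size_t total_seqlen = SafeInt<size_t>(seqlens_k[batch_index]) + 1;
const size_t past_seqlen =
    is_prompt ? 0 : static_cast<size_t>(SafeInt<size_t>(total_seqlen) - sequence_length);
const size_t past_chunk_length = SafeInt<size_t>(past_seqlen) * head_size;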

@yuslepukhin (Member) commented:

At gqa_attention_base.h:289:

total_seqlen is size_t from the validated seqlens_k. The multiplication head_size * (sequence_length + total_seqlen) * sizeof(float) doesn't use SafeInt, unlike the similar allocation at line 98 which does. Same issue at line 480.

@yuslepukhin (Member) commented:

At attention_helper.h:172:

past_chunk_length is past_seqlen * head_size — no SafeInt on the multiply. If past_seqlen were corrupted (would require bypassing validation), this could overflow. The multiply with sizeof(T) is also unguarded.

@yuslepukhin (Member) commented:

Gaps in test coverage:

No test for seqlens_k[b] == present_kv_seqlen - 1 — The max valid boundary. The BoundaryValidSeqlensK test uses seqlens_k=0 with present=1, which is the minimum valid boundary, not the maximum.

No test for seqlens_k[b] == present_kv_seqlen — Boundary just past the valid range (off-by-one confirmation); a sketch of both boundary cases follows this list.

No test with INT32_MAX or very large seqlens_k — Would confirm the seqlens_k_data[b] + 1 overflow is harmless.

No test with total_sequence_length <= 0 — The total_sequence_length tensor is not validated for non-positive values.

No test with packed QKV — All tests use separate Q/K/V. The packed path has different pointer arithmetic.

No test with MLFloat16 type — Only float type is tested. The kernel is registered for both.

No test with do_rotary=1 — The rotary code path at line 174 also reads seqlens_k data directly (for position IDs), and this path is untested.

No functional correctness test for valid inputs — All success tests use SetOutputTolerance(1e6f) which effectively disables output checking. There should be at least one test verifying correct attention output.
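
A hedged sketch of the two missing boundary tests; RunSeqlensKBoundaryCase and kPresentSeqlen are illustrative stand-ins, not helpers that exist in group_query_attention_op_test.cc:

// Hypothetical helper: builds a CPU GQA node with a present K/V length of
// kPresentSeqlen and runs it with the given seqlens_k value.
void RunSeqlensKBoundaryCase(int32_t seqlens_k_value, bool expect_failure);

TEST(GroupQueryAttentionTest, SeqlensKPresentBoundary) {
  constexpr int32_t kPresentSeqlen = 4;
  // Maximum valid value: seqlens_k[b] + 1 == present_kv_seqlen must be accepted.
  RunSeqlensKBoundaryCase(kPresentSeqlen - 1, /*expect_failure=*/false);
  // One past the valid range (seqlens_k[b] == present_kv_seqlen) must be rejected.
  RunSeqlensKBoundaryCase(kPresentSeqlen, /*expect_failure=*/true);
}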

@yuslepukhin (Member) commented:

Is this operator implemented in other EPs, notably CUDA?

@vraspar (Contributor, Author) commented Apr 16, 2026

Addressed all feedback in the latest two commits (5d6f96d, 11f5d83): int64 cast, total_sequence_length validation, SafeInt wraps in gqa_attention_base.h and attention_helper.h, and additional boundary/INT32_MAX tests.

For the remaining test gaps (packed QKV, MLFloat16, do_rotary, functional correctness), these exercise different code paths unrelated to the seqlens_k validation and each requires non-trivial test setup (different input layouts, type dispatch, rotary cache tensors, hand-computed expected outputs). Happy to add them in a follow-up if needed.

Re CUDA: GQA exists in CUDA, WebGPU, and JS, but none read seqlens_k on the host. CUDA passes it as a device pointer to flash attention kernels, so host-side bounds validation doesn't apply there. The shared CheckInputs helper now validates total_sequence_length > 0 which benefits all EPs.

@yuslepukhin (Member) commented:

One allocation was missed at gqa_attention_base.h:273:

// Current (unguarded):
size_t bytes = head_size * (sequence_length + total_seqlen) * sizeof(float);

// Should be:
size_t bytes = SafeInt<size_t>(head_size) * (sequence_length + total_seqlen) * sizeof(float);

vraspar requested a review from yuslepukhin on April 20, 2026 22:29
@yuslepukhin (Member) left a review:

:shipit:

vraspar enabled auto-merge (squash) on April 20, 2026 23:47
vraspar merged commit 7c56fa8 into main on Apr 21, 2026
89 checks passed
vraspar deleted the vraspar/fix-gqa-seqlens-k-oob branch on April 21, 2026 00:31