Add seqlens_k bounds validation in GroupQueryAttention to prevent GEMM OOB #28031
Conversation
Pull request overview
This PR hardens the CPU GroupQueryAttention contrib op against malicious/invalid seqlens_k values that could otherwise drive GEMM dimensions and lead to out-of-bounds reads, and adds regression tests for the reported OOB scenario.
Changes:
- Add runtime validation of `seqlens_k` values in `GroupQueryAttention::Compute()` before they influence GEMM dimensions.
- Fix `seqlens_k` shape validation in the helper (`&&` → `||`) so incorrect tensor shapes are rejected correctly.
- Add CPU-focused regression tests for negative/oversized/multi-batch/boundary-valid `seqlens_k` inputs.
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc | Adds per-batch seqlens_k bounds checks prior to GEMM/attention computations. |
| onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h | Corrects seqlens_k shape validation logic to properly reject invalid shapes. |
| onnxruntime/test/contrib_ops/group_query_attention_op_test.cc | Introduces new regression tests targeting invalid seqlens_k inputs on CPU. |
… int64 overflow safety, non-first-prompt test

- Replace `ORT_RETURN_IF` (generic FAIL) with `ORT_MAKE_STATUS` INVALID_ARGUMENT for consistent error classification with `CheckInputs()`
- Add `total_sequence_length` validation to prevent `attention_bias` OOB when `seqlens_k + 1 > total_sequence_length` but `<= present_kv_seqlen`
- Use `int64_t` for the `seqlens_k + 1` computation to prevent signed overflow at INT32_MAX
- Add `SubsequentPromptSeqlensKUnderflow` test that independently exercises the `!is_first_prompt` underflow guard with a positive `seqlens_k` value
- Add `SeqlensKExceedsTotalSeqLen` test for the new `total_sequence_length` bound
- Add `BoundaryValidSeqlensKWithLargerPast` boundary success test
- Rename the original NonPrompt test to `NegativeSeqlensKWithPast` to clarify intent
- Fix test helper: `head_size = 8` (must be a multiple of 8), correct output shapes, update comment to note past support

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
A crafted model can supply `total_sequence_length <= 0`. This flows into `present_sequence_length` (which could end up 0 if there is no past), then into buffer allocation sizes. A negative value would be clamped by `std::max` only if `past_sequence_length > 0`. With no past and `total_sequence_length = 0`, `present_kv_seqlen = 0`, which makes the `seqlens_k[b] >= present_kv_seqlen` check reject any `seqlens_k >= 0` (effectively forcing failure), but zero-size buffer allocations may have platform-specific behavior. Recommendation: add `total_sequence_length > 0` validation in `CheckInputs`. Refers to: onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h:277 in 5edfbfc.
No SafeInt on `seqlens_k`-derived arithmetic in gqa_attention_base.h:209-211: `const size_t total_seqlen = static_cast<size_t>(seqlens_k[batch_index]) + 1;` While the upstream validation makes these safe, defense-in-depth would use `SafeInt<size_t>` (already used elsewhere in the same file for `packed_batch_stride`, `probs_matrix_bytes`, etc.) to guard the cast and subsequent arithmetic. This would catch any future regression if the validation is loosened. The same pattern appears at lines 443-444.
At gqa_attention_base.h:289: `total_seqlen` is a `size_t` derived from the validated `seqlens_k`. The multiplication `head_size * (sequence_length + total_seqlen) * sizeof(float)` doesn't use SafeInt, unlike the similar allocation at line 98, which does. Same issue at line 480.
At attention_helper.h:172: `past_chunk_length` is `past_seqlen * head_size`, with no SafeInt on the multiply. If `past_seqlen` were corrupted (which would require bypassing validation), this could overflow. The multiply with `sizeof(T)` is also unguarded.
Gaps in test coverage:

- No test for `seqlens_k[b] == present_kv_seqlen - 1`, the maximum valid boundary. The `BoundaryValidSeqlensK` test uses `seqlens_k = 0` with `present = 1`, which is the minimum valid boundary, not the maximum.
- No test for `seqlens_k[b] == present_kv_seqlen`, the boundary just past the valid range (off-by-one confirmation).
- No test with INT32_MAX or a very large `seqlens_k`, which would confirm the `seqlens_k_data[b] + 1` overflow is harmless.
- No test with `total_sequence_length <= 0`; the `total_sequence_length` tensor is not validated for non-positive values.
- No test with packed QKV; all tests use separate Q/K/V, and the packed path has different pointer arithmetic.
- No test with the MLFloat16 type; only float is tested, yet the kernel is registered for both.
- No test with `do_rotary = 1`; the rotary code path at line 174 also reads `seqlens_k` data directly (for position IDs), and this path is untested.
- No functional correctness test for valid inputs; all success tests use `SetOutputTolerance(1e6f)`, which effectively disables output checking. There should be at least one test verifying correct attention output.
Is this operator implemented in other EPs, notably CUDA?
Addressed all feedback in the latest two commits (5d6f96d, 11f5d83): the int64 cast, `total_sequence_length` validation, SafeInt wraps in gqa_attention_base.h and attention_helper.h, and additional boundary/INT32_MAX tests.

For the remaining test gaps (packed QKV, MLFloat16, do_rotary, functional correctness): these exercise different code paths unrelated to the `seqlens_k` validation, and each requires non-trivial test setup (different input layouts, type dispatch, rotary cache tensors, hand-computed expected outputs). Happy to add them in a follow-up if needed.

Re CUDA: GQA exists in CUDA, WebGPU, and JS, but none read `seqlens_k` on the host. CUDA passes it as a device pointer to flash attention kernels, so host-side bounds validation doesn't apply there. The shared `CheckInputs` helper now validates `total_sequence_length > 0`, which benefits all EPs.
One allocation was missed at gqa_attention_base.h:273:
// Current (unguarded):
// Should be:
Description
Validate seqlens_k tensor values in the CPU GroupQueryAttention operator before they are used as GEMM dimensions. Without this check, a crafted model can supply negative or oversized seqlens_k values that cause out-of-bounds reads in the K/V present cache buffers.
Changes
- Add per-batch validation in `Compute()` before any `seqlens_k` access:
  - `seqlens_k[b] >= 0` (prevents unsigned wraparound in `static_cast<size_t>`)
  - `seqlens_k[b] + 1 <= present_kv_seqlen` (prevents GEMM reading past the K/V buffer)
  - `seqlens_k[b] + 1 >= sequence_length` (prevents underflow in `past_seqlen = total_seqlen - sequence_length`)
- Fix shape validation in the helper (`&&` to `||`) so wrong-length tensors are correctly rejected

Motivation and Context
MSRC case 108962: A crafted model can set `seqlens_k` values that, when cast to `size_t` and used as the GEMM N dimension, cause heap OOB reads from the present K/V cache buffers.