
Commit 7c56fa8

vraspar and Copilot authored
Add seqlens_k bounds validation in GroupQueryAttention to prevent GEMM OOB (#28031)
### Description

Validate seqlens_k tensor values in the CPU GroupQueryAttention operator before they are used as GEMM dimensions. Without this check, a crafted model can supply negative or oversized seqlens_k values that cause out-of-bounds reads in the K/V present cache buffers.

Fixes https://portal.microsofticm.com/imp/v5/incidents/details/31000000559235/summary

### Changes

- **group_query_attention.cc**: Add a validation loop in `Compute()` before any seqlens_k access:
  - `seqlens_k[b] >= 0` (prevents unsigned wraparound in `static_cast<size_t>`)
  - `seqlens_k[b] + 1 <= present_kv_seqlen` (prevents GEMM reading past the K/V buffer)
  - For non-first-prompt: `seqlens_k[b] + 1 >= sequence_length` (prevents underflow in `past_seqlen = total_seqlen - sequence_length`)
- **group_query_attention_helper.h**: Fix seqlens_k shape validation (`&&` to `||`) so wrong-length tensors are correctly rejected
- **Tests**: 4 regression tests covering negative, oversized, multi-batch, and boundary-valid seqlens_k values

### Motivation and Context

MSRC case 108962: A crafted model can set seqlens_k values that, when cast to `size_t` and used as the GEMM N dimension, cause heap OOB reads from the present K/V cache buffers.

---------

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
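To make the failure mode concrete, here is a standalone sketch (hypothetical values, not ORT code) of the wraparound the first check guards against: a negative `seqlens_k` entry cast to `size_t` becomes an enormous `total_seqlen`, which downstream code would then use as a GEMM dimension and as a cache offset.

```cpp
// Illustration only: how a crafted negative seqlens_k value wraps around
// when cast to size_t (values are hypothetical).
#include <cstddef>
#include <cstdint>
#include <cstdio>

int main() {
  const int32_t seqlens_k_b = -5;  // crafted model input for batch entry b

  // Pre-fix pattern: the negative value wraps to ~1.8e19 on a 64-bit build.
  const size_t total_seqlen = static_cast<size_t>(seqlens_k_b) + 1;

  std::printf("total_seqlen = %zu\n", total_seqlen);
  // Anything that uses total_seqlen as a GEMM dimension or buffer extent from
  // here on reads far past the allocated present K/V cache.
  return 0;
}
```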
1 parent fb13eb3 commit 7c56fa8

5 files changed

Lines changed: 342 additions & 10 deletions


onnxruntime/contrib_ops/cpu/bert/attention_helper.h

Lines changed: 2 additions & 2 deletions
@@ -168,11 +168,11 @@ T* ConcatStateChunkGQA(const T* past,
   T* p = start;
   if (!past_present_share_buffer && past_chunk_length > 0) {
     const T* src_past = past + i * past_buff_chunk_length;
-    memcpy(p, src_past, past_chunk_length * sizeof(T));
+    memcpy(p, src_past, SafeInt<size_t>(past_chunk_length) * sizeof(T));
   }
   p += past_chunk_length;
 
-  memcpy(p, chunk, new_chunk_length * sizeof(T));
+  memcpy(p, chunk, SafeInt<size_t>(new_chunk_length) * sizeof(T));
   return start;
 }

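For reference, `SafeInt<size_t>` makes the byte-count multiplication fail loudly instead of wrapping silently. A minimal sketch of the difference, assuming the standalone SafeInt.hpp header (which throws `SafeIntException` on overflow by default; ORT plugs in its own error handler, but the checked-multiply idea is the same):

```cpp
// Minimal sketch, assuming the public SafeInt.hpp single header is available.
#include "SafeInt.hpp"
#include <cstddef>
#include <cstdio>

int main() {
  const size_t chunk_length = ~size_t{0} / 2 + 1;  // absurdly large chunk length

  const size_t wrapped = chunk_length * sizeof(float);  // plain multiply: wraps silently to 0
  std::printf("unchecked byte count: %zu\n", wrapped);

  try {
    const size_t checked = SafeInt<size_t>(chunk_length) * sizeof(float);  // overflow is detected
    std::printf("checked byte count: %zu\n", checked);
  } catch (const SafeIntException&) {
    std::printf("overflow detected before the memcpy size is used\n");
  }
  return 0;
}
```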
onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h

Lines changed: 7 additions & 7 deletions
@@ -206,9 +206,9 @@ class GQAAttentionBase {
       for (std::ptrdiff_t i = begin; i != end; ++i) {
         const size_t batch_index = i / num_heads_;
         const size_t head_index = i % num_heads_;
-        const size_t total_seqlen = static_cast<size_t>(seqlens_k[batch_index]) + 1;
+        const size_t total_seqlen = SafeInt<size_t>(seqlens_k[batch_index]) + 1;
         const size_t past_seqlen = is_prompt ? 0 : total_seqlen - sequence_length;  // Assume no padding sequence length
-        const size_t past_chunk_length = past_seqlen * head_size;
+        const size_t past_chunk_length = SafeInt<size_t>(past_seqlen) * head_size;
 
         const ptrdiff_t output_offset = SafeInt<ptrdiff_t>(i) * sequence_length * present_buffer_sequence_length;
         U* output = attention_probs + output_offset;
@@ -270,7 +270,7 @@ class GQAAttentionBase {
                         static_cast<int>(present_buffer_sequence_length),
                         MLFloat16(alpha).val, static_cast<uint16_t>(0) /*beta*/, nullptr);
         } else {
-          size_t bytes = head_size * (sequence_length + total_seqlen) * sizeof(float);
+          size_t bytes = SafeInt<size_t>(head_size) * (sequence_length + total_seqlen) * sizeof(float);
           auto q_k_fp32 = allocator->Alloc(bytes);
           BufferUniquePtr scratch_buffer(q_k_fp32, BufferDeleter(allocator));
 
@@ -291,7 +291,7 @@ class GQAAttentionBase {
         if constexpr (!std::is_same_v<U, T>) {
           static_assert(std::is_same_v<U, float> && std::is_same_v<T, MLFloat16>);
 
-          size_t bytes = attention_total_seqlen * sizeof(float);
+          size_t bytes = SafeInt<size_t>(attention_total_seqlen) * sizeof(float);
           attention_bias_thread_fp32 = static_cast<float*>(allocator->Alloc(bytes));
         }
       }
@@ -440,9 +440,9 @@ class GQAAttentionBase {
       for (std::ptrdiff_t i = begin; i != end; ++i) {
         const size_t batch_index = i / num_heads_;
         const size_t head_index = i % num_heads_;
-        const size_t total_seqlen = static_cast<size_t>(seqlens_k[batch_index]) + 1;
+        const size_t total_seqlen = SafeInt<size_t>(seqlens_k[batch_index]) + 1;
         const size_t past_seqlen = is_prompt ? 0 : total_seqlen - sequence_length;  // Assume no padding sequence length
-        const size_t past_chunk_length = past_seqlen * head_size;
+        const size_t past_chunk_length = SafeInt<size_t>(past_seqlen) * head_size;
 
         const T* v;
         if (packed_qkv) {
@@ -472,7 +472,7 @@ class GQAAttentionBase {
                         v, static_cast<int>(head_size), output_current, static_cast<int>(hidden_size),
                         MLFloat16(1.0f).val, static_cast<uint16_t>(0) /*beta*/, nullptr);
         } else {
-          size_t bytes = head_size * total_seqlen * sizeof(float);
+          size_t bytes = SafeInt<size_t>(head_size) * total_seqlen * sizeof(float);
           auto v_fp32 = allocator->Alloc(bytes);
           BufferUniquePtr scratch_buffer(v_fp32, BufferDeleter(allocator));
 
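The `past_seqlen = total_seqlen - sequence_length` line above is the underflow site that the new non-first-prompt check in group_query_attention.cc protects. A standalone illustration with hypothetical values:

```cpp
// Illustration only: why the decode path requires seqlens_k[b] + 1 >= sequence_length.
// With size_t arithmetic, a too-small total_seqlen makes past_seqlen wrap to an
// enormous value, which then scales past_chunk_length and every derived cache offset.
#include <cstddef>
#include <cstdint>
#include <cstdio>

int main() {
  const int32_t seqlens_k_b = 3;     // crafted value: total_seqlen = 4
  const size_t sequence_length = 8;  // but this step processes 8 new tokens
  const size_t head_size = 64;

  const size_t total_seqlen = static_cast<size_t>(seqlens_k_b) + 1;
  const size_t past_seqlen = total_seqlen - sequence_length;  // wraps: 2^64 - 4
  const size_t past_chunk_length = past_seqlen * head_size;   // garbage offset

  std::printf("past_seqlen = %zu\n", past_seqlen);
  std::printf("past_chunk_length = %zu\n", past_chunk_length);
  return 0;
}
```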
onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc

Lines changed: 17 additions & 0 deletions
@@ -82,6 +82,23 @@ Status GroupQueryAttention<T>::Compute(OpKernelContext* context) const {
   const int sequence_length = parameters.sequence_length;
   const int present_kv_seqlen = parameters.seqlen_present_kv_cache;
   int head_size = parameters.head_size;
+
+  // Validate seqlens_k values before they are used as GEMM dimensions to prevent OOB access.
+  {
+    const int32_t* seqlens_k_data = seqlens_k->Data<int32_t>();
+    for (int b = 0; b < batch_size; b++) {
+      if (seqlens_k_data[b] < 0 || seqlens_k_data[b] >= present_kv_seqlen) {
+        return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                               "seqlens_k[", b, "] = ", seqlens_k_data[b],
+                               " is out of range [0, ", present_kv_seqlen, ")");
+      }
+      if (!parameters.is_first_prompt && static_cast<int64_t>(seqlens_k_data[b]) + 1 < sequence_length) {
+        return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                               "seqlens_k[", b, "] = ", seqlens_k_data[b],
+                               " is too small for sequence_length ", sequence_length);
+      }
+    }
+  }
   int q_hidden_size = parameters.hidden_size;
   const bool packed_qkv = parameters.is_packed_qkv;
 
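A rough sketch (hypothetical buffer sizes, not ORT code) of what the upper-bound check prevents: an oversized `seqlens_k[b]` makes `total_seqlen` exceed the number of rows actually allocated in the present K/V cache, so a GEMM that walks `total_seqlen` rows would read past the buffer.

```cpp
// Illustration only: oversized seqlens_k versus the allocated present K cache.
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  const size_t head_size = 4;
  const size_t present_kv_seqlen = 8;  // rows per head actually allocated in the present K cache
  std::vector<float> present_k(present_kv_seqlen * head_size, 1.0f);

  const int32_t seqlens_k_b = 1000;  // crafted value, far larger than present_kv_seqlen
  const size_t total_seqlen = static_cast<size_t>(seqlens_k_b) + 1;

  // The attention GEMM would use total_seqlen as a dimension, i.e. it would touch
  // rows [0, total_seqlen) of present_k even though only present_kv_seqlen rows exist.
  std::printf("rows requested: %zu, rows allocated: %zu, rows read out of bounds: %zu\n",
              total_seqlen, present_kv_seqlen, total_seqlen - present_kv_seqlen);
  return 0;
}
```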
onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h

Lines changed: 5 additions & 1 deletion
@@ -262,7 +262,7 @@ Status CheckInputs(const T* query,
   }
 
   const auto& seqlens_k_dim = seqlens_k->Shape().GetDims();
-  if (seqlens_k_dim.size() != 1 && seqlens_k_dim[0] != batch_size) {
+  if (seqlens_k_dim.size() != 1 || seqlens_k_dim[0] != batch_size) {
     return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                            "seqlens_k must be shape (batch_size).");
   }
@@ -275,6 +275,10 @@ Status CheckInputs(const T* query,
   // When graph capture is enabled, total_seqlen is on GPU and cannot be read. Skip validation.
   const bool is_total_seqlen_on_cpu = (total_seqlen->Location().device.Type() == OrtDevice::CPU);
   int total_sequence_length = is_total_seqlen_on_cpu ? *((*total_seqlen).template Data<int32_t>()) : 0;
+  if (is_total_seqlen_on_cpu && total_sequence_length <= 0) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                           "total_sequence_length must be positive, got ", total_sequence_length, ".");
+  }
   int present_sequence_length = std::max(total_sequence_length, past_sequence_length);
 
   int rotary_dim = 0;

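On the `&&` to `||` fix: with the old condition, a 1-D `seqlens_k` of the wrong length was accepted because `size() != 1` is false, which makes the whole `&&` expression false. A standalone truth-table-style illustration (plain integers instead of ORT's TensorShape):

```cpp
// Illustration only: why the old && shape check accepted a wrong-length seqlens_k.
#include <cstdio>

int main() {
  const int rank = 1;              // seqlens_k is 1-D
  const long long dim0 = 3;        // but its length...
  const long long batch_size = 2;  // ...does not match batch_size

  const bool rejected_old = (rank != 1 && dim0 != batch_size);  // false: accepted (bug)
  const bool rejected_new = (rank != 1 || dim0 != batch_size);  // true: rejected (fix)

  std::printf("old check rejects: %d, new check rejects: %d\n", rejected_old, rejected_new);
  return 0;
}
```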