4 changes: 2 additions & 2 deletions onnxruntime/contrib_ops/cpu/bert/attention_helper.h
@@ -168,11 +168,11 @@ T* ConcatStateChunkGQA(const T* past,
   T* p = start;
   if (!past_present_share_buffer && past_chunk_length > 0) {
     const T* src_past = past + i * past_buff_chunk_length;
-    memcpy(p, src_past, past_chunk_length * sizeof(T));
+    memcpy(p, src_past, SafeInt<size_t>(past_chunk_length) * sizeof(T));
   }
   p += past_chunk_length;
 
-  memcpy(p, chunk, new_chunk_length * sizeof(T));
+  memcpy(p, chunk, SafeInt<size_t>(new_chunk_length) * sizeof(T));
   return start;
 }
 
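The pattern repeated in this file and in gqa_attention_base.h below is to route each byte-count multiplication through SafeInt<size_t>, so the product is computed in checked 64-bit arithmetic rather than in whatever narrower type the operands happen to have. A minimal standalone sketch of the failure mode, with made-up dimension values and a plain cast standing in for SafeInt:

```cpp
// Standalone illustration (not ONNX Runtime code) of the class of bug the
// SafeInt<size_t> changes defend against. Dimension values are hypothetical.
#include <cstddef>
#include <cstdint>
#include <cstdio>

int main() {
  const int head_size = 256;
  const int total_seqlen = 20'000'000;  // deliberately huge for illustration

  // When both factors are 32-bit, their product is computed in 32 bits and
  // wraps before it is widened to size_t, so the byte count is silently too
  // small. (Shown with uint32_t so the wrap-around is well defined; with
  // signed int the overflow would be undefined behavior.)
  const size_t wrapped =
      static_cast<uint32_t>(head_size) * static_cast<uint32_t>(total_seqlen) * sizeof(float);

  // Widening the first operand up front keeps the whole product in 64 bits.
  // SafeInt<size_t>(head_size) behaves like this cast, and additionally fails
  // loudly if even the 64-bit result would overflow or a negative value is
  // converted to unsigned.
  const size_t widened = static_cast<size_t>(head_size) * total_seqlen * sizeof(float);

  std::printf("wrapped: %zu bytes\nwidened: %zu bytes\n", wrapped, widened);
  return 0;
}
```

The difference is that SafeInt does not merely widen: on an out-of-range result it throws instead of silently wrapping, which is what makes it suitable for sizing buffers handed to memcpy and Alloc.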
14 changes: 7 additions & 7 deletions onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h
@@ -206,9 +206,9 @@ class GQAAttentionBase {
       for (std::ptrdiff_t i = begin; i != end; ++i) {
         const size_t batch_index = i / num_heads_;
         const size_t head_index = i % num_heads_;
-        const size_t total_seqlen = static_cast<size_t>(seqlens_k[batch_index]) + 1;
+        const size_t total_seqlen = SafeInt<size_t>(seqlens_k[batch_index]) + 1;
         const size_t past_seqlen = is_prompt ? 0 : total_seqlen - sequence_length;  // Assume no padding sequence length
-        const size_t past_chunk_length = past_seqlen * head_size;
+        const size_t past_chunk_length = SafeInt<size_t>(past_seqlen) * head_size;
 
         const ptrdiff_t output_offset = SafeInt<ptrdiff_t>(i) * sequence_length * present_buffer_sequence_length;
         U* output = attention_probs + output_offset;
@@ -270,7 +270,7 @@ class GQAAttentionBase {
                               static_cast<int>(present_buffer_sequence_length),
                               MLFloat16(alpha).val, static_cast<uint16_t>(0) /*beta*/, nullptr);
         } else {
-          size_t bytes = head_size * (sequence_length + total_seqlen) * sizeof(float);
+          size_t bytes = SafeInt<size_t>(head_size) * (sequence_length + total_seqlen) * sizeof(float);
           auto q_k_fp32 = allocator->Alloc(bytes);
           BufferUniquePtr scratch_buffer(q_k_fp32, BufferDeleter(allocator));
 
@@ -291,7 +291,7 @@ class GQAAttentionBase {
         if constexpr (!std::is_same_v<U, T>) {
           static_assert(std::is_same_v<U, float> && std::is_same_v<T, MLFloat16>);
 
-          size_t bytes = attention_total_seqlen * sizeof(float);
+          size_t bytes = SafeInt<size_t>(attention_total_seqlen) * sizeof(float);
           attention_bias_thread_fp32 = static_cast<float*>(allocator->Alloc(bytes));
         }
       }
@@ -440,9 +440,9 @@ class GQAAttentionBase {
       for (std::ptrdiff_t i = begin; i != end; ++i) {
         const size_t batch_index = i / num_heads_;
         const size_t head_index = i % num_heads_;
-        const size_t total_seqlen = static_cast<size_t>(seqlens_k[batch_index]) + 1;
+        const size_t total_seqlen = SafeInt<size_t>(seqlens_k[batch_index]) + 1;
         const size_t past_seqlen = is_prompt ? 0 : total_seqlen - sequence_length;  // Assume no padding sequence length
-        const size_t past_chunk_length = past_seqlen * head_size;
+        const size_t past_chunk_length = SafeInt<size_t>(past_seqlen) * head_size;
 
         const T* v;
         if (packed_qkv) {
@@ -472,7 +472,7 @@ class GQAAttentionBase {
                               v, static_cast<int>(head_size), output_current, static_cast<int>(hidden_size),
                               MLFloat16(1.0f).val, static_cast<uint16_t>(0) /*beta*/, nullptr);
         } else {
-          size_t bytes = head_size * total_seqlen * sizeof(float);
+          size_t bytes = SafeInt<size_t>(head_size) * total_seqlen * sizeof(float);
           auto v_fp32 = allocator->Alloc(bytes);
           BufferUniquePtr scratch_buffer(v_fp32, BufferDeleter(allocator));
 
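Judging from the buffer names (q_k_fp32, v_fp32) and the size expressions, the else branches are fallback paths that convert a head's fp16 Q/K (and later V) to fp32 before the GEMM, so the scratch buffer has to hold (sequence_length + total_seqlen) rows of head_size floats. A back-of-the-envelope check with hypothetical dimensions:

```cpp
// Quick sanity check of the scratch-buffer size expression, with hypothetical
// dimensions. Not ONNX Runtime code.
#include <cstddef>
#include <cstdio>

int main() {
  const size_t head_size = 128;
  const size_t sequence_length = 1;   // single decode step
  const size_t total_seqlen = 32768;  // hypothetical KV length

  // fp32 copy of one head's Q (sequence_length x head_size) followed by its
  // K (total_seqlen x head_size): the same expression as in the diff above.
  const size_t bytes = head_size * (sequence_length + total_seqlen) * sizeof(float);
  std::printf("%zu bytes (~%.1f MiB) per head\n", bytes, bytes / (1024.0 * 1024.0));
  return 0;
}
```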
17 changes: 17 additions & 0 deletions onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc
@@ -82,6 +82,23 @@ Status GroupQueryAttention<T>::Compute(OpKernelContext* context) const {
   const int sequence_length = parameters.sequence_length;
   const int present_kv_seqlen = parameters.seqlen_present_kv_cache;
   int head_size = parameters.head_size;
+
+  // Validate seqlens_k values before they are used as GEMM dimensions to prevent OOB access.
+  {
+    const int32_t* seqlens_k_data = seqlens_k->Data<int32_t>();
+    for (int b = 0; b < batch_size; b++) {
+      if (seqlens_k_data[b] < 0 || seqlens_k_data[b] >= present_kv_seqlen) {
+        return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                               "seqlens_k[", b, "] = ", seqlens_k_data[b],
+                               " is out of range [0, ", present_kv_seqlen, ")");
+      }
+      if (!parameters.is_first_prompt && static_cast<int64_t>(seqlens_k_data[b]) + 1 < sequence_length) {
+        return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                               "seqlens_k[", b, "] = ", seqlens_k_data[b],
+                               " is too small for sequence_length ", sequence_length);
+      }
+    }
+  }
   int q_hidden_size = parameters.hidden_size;
   const bool packed_qkv = parameters.is_packed_qkv;
 
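Since seqlens_k[b] stores total_sequence_length - 1 for batch entry b (the kernels above compute total_seqlen = seqlens_k[batch_index] + 1), the new checks amount to requiring 0 <= seqlens_k[b] < present_kv_seqlen and, outside the first prompt, seqlens_k[b] + 1 >= sequence_length. A standalone restatement of that predicate, illustrative only (the PR keeps the checks inline in Compute):

```cpp
// Illustrative restatement of the new seqlens_k bounds as a pure predicate.
#include <cstdint>

// seqlens_k[b] is total_sequence_length - 1 for batch entry b, so validity means:
//   0 <= seqlens_k[b]                  (no negative lengths)
//   seqlens_k[b] < present_kv_seqlen   (the total length fits the KV cache)
// and, outside the first prompt, the total length must cover the new tokens:
//   seqlens_k[b] + 1 >= sequence_length
bool SeqlenEntryIsValid(int32_t seqlen_k, int present_kv_seqlen,
                        int sequence_length, bool is_first_prompt) {
  if (seqlen_k < 0 || seqlen_k >= present_kv_seqlen) {
    return false;
  }
  if (!is_first_prompt &&
      static_cast<int64_t>(seqlen_k) + 1 < sequence_length) {
    return false;
  }
  return true;
}
```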
@@ -262,7 +262,7 @@ Status CheckInputs(const T* query,
   }
 
   const auto& seqlens_k_dim = seqlens_k->Shape().GetDims();
-  if (seqlens_k_dim.size() != 1 && seqlens_k_dim[0] != batch_size) {
+  if (seqlens_k_dim.size() != 1 || seqlens_k_dim[0] != batch_size) {
     return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                            "seqlens_k must be shape (batch_size).");
   }
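The operator flip above is the substantive fix in this hunk: with &&, seqlens_k was rejected only when it was both not rank-1 and had a mismatched first dimension, so a rank-1 tensor of the wrong length (or a higher-rank tensor whose first dimension happened to equal batch_size) passed this particular check. With ||, either violation rejects the input. A small illustration with hypothetical shapes:

```cpp
// Hypothetical seqlens_k shapes run through the old and new predicates.
#include <cstdio>
#include <vector>

int main() {
  const long batch_size = 4;
  // Each entry is a candidate seqlens_k shape: valid, wrong length, wrong rank.
  const std::vector<std::vector<long>> shapes = {{4}, {7}, {4, 5}};

  for (const auto& dims : shapes) {
    const bool rejected_old = dims.size() != 1 && dims[0] != batch_size;  // '&&' (buggy)
    const bool rejected_new = dims.size() != 1 || dims[0] != batch_size;  // '||' (fixed)
    std::printf("rank=%zu dim0=%ld  old: %s  new: %s\n",
                dims.size(), dims[0],
                rejected_old ? "rejected" : "accepted",
                rejected_new ? "rejected" : "accepted");
  }
  return 0;
}
```

With these inputs, {4} is accepted by both predicates, while {7} and {4, 5} slip through the old one and are rejected by the new one.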
@@ -275,6 +275,10 @@
   // When graph capture is enabled, total_seqlen is on GPU and cannot be read. Skip validation.
   const bool is_total_seqlen_on_cpu = (total_seqlen->Location().device.Type() == OrtDevice::CPU);
   int total_sequence_length = is_total_seqlen_on_cpu ? *((*total_seqlen).template Data<int32_t>()) : 0;
+  if (is_total_seqlen_on_cpu && total_sequence_length <= 0) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                           "total_sequence_length must be positive, got ", total_sequence_length, ".");
+  }
   int present_sequence_length = std::max(total_sequence_length, past_sequence_length);
 
   int rotary_dim = 0;