Skip to content

Commit 2bf09e9

Browse files
edgchen1 and Copilot authored
Fix CPU Attention overflow issue (#27822)
### Description

Fix `int` overflow issue in `ComputeAttentionSoftmaxInplace<MLFloat16>()` by using `size_t` and `SafeInt` instead.

### Motivation and Context

Fix overflow issue.

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 32d32b0 commit 2bf09e9

4 files changed

Lines changed: 124 additions & 17 deletions

File tree

onnxruntime/core/providers/cpu/llm/attention.cc

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
#include "core/providers/cpu/llm/attention.h"
55
#include "core/providers/cpu/llm/attention_helper.h"
6+
#include "core/providers/cpu/llm/attention_softmax.h"
67

78
#include "core/common/common.h"
89
#include "core/common/safeint.h"
@@ -77,23 +78,6 @@ void make_copy<MLFloat16, bool>(MLFloat16* mask_data, const bool* mask_index, si
7778
}
7879
}
7980

80-
template <typename T>
81-
inline void ComputeAttentionSoftmaxInplace(T* score, int N, int D, ThreadPool* tp, AllocatorPtr) {
82-
MlasComputeSoftmax(score, score, N, D, false, false, 0.0f, tp);
83-
}
84-
85-
template <>
86-
inline void ComputeAttentionSoftmaxInplace<MLFloat16>(MLFloat16* score, int N, int D, ThreadPool* tp, AllocatorPtr allocator) {
87-
ORT_ENFORCE(tp == nullptr, "No parallelized version of softmax for float16.");
88-
// Mlas Lacks kernels for fp16 softmax, we convert into float32 and call the float32 version.
89-
void* allocated_ptr = allocator->Alloc(static_cast<size_t>(N * D * sizeof(float)));
90-
BufferUniquePtr float_buffer(allocated_ptr, BufferDeleter(allocator));
91-
float* ptr = reinterpret_cast<float*>(allocated_ptr);
92-
MlasConvertHalfToFloatBuffer(score, ptr, N * D);
93-
MlasComputeSoftmax(ptr, ptr, N, D, false, false, 0.0f, tp);
94-
MlasConvertFloatToHalfBuffer(ptr, score, N * D);
95-
}
96-
9781
template <typename T>
9882
inline void ComputeAttentionSoftcapInplace(T* scores, int sequence_length, T softcap) {
9983
MlasComputeSoftcap(scores, scores, sequence_length, softcap);
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/common/common.h"
#include "core/common/float16.h"
#include "core/common/safeint.h"
#include "core/framework/allocator.h"
#include "core/framework/buffer_deleter.h"
#include "core/mlas/inc/mlas.h"
#include "core/platform/threadpool.h"

namespace onnxruntime {

// Computes softmax in place over `score`, treated as N rows of D elements each.
// Dimensions are size_t (not int) so that N * D cannot overflow a 32-bit integer
// for large attention shapes. The AllocatorPtr parameter is unused in the generic
// case; it exists so the fp16 specialization below can share the same signature.
template <typename T>
inline void ComputeAttentionSoftmaxInplace(T* score, size_t N, size_t D,
                                           concurrency::ThreadPool* tp, AllocatorPtr) {
  MlasComputeSoftmax(score, score, N, D, false, false, 0.0f, tp);
}

// FP16 specialization: MLAS has no fp16 softmax kernel, so convert the scores to
// float32 in a temporary buffer, run the float32 softmax, and convert back.
// `allocator` supplies the temporary buffer; `tp` must be nullptr because there is
// no parallelized fp16 path.
template <>
inline void ComputeAttentionSoftmaxInplace<MLFloat16>(MLFloat16* score, size_t N, size_t D,
                                                      concurrency::ThreadPool* tp, AllocatorPtr allocator) {
  ORT_ENFORCE(tp == nullptr, "No parallelized version of softmax for float16.");
  // MLAS lacks kernels for fp16 softmax, so we convert to float32 and use the float32 version.
  // SafeInt<size_t> guards both N * D and the subsequent * sizeof(float) against size_t overflow
  // (e.g. on 32-bit builds); on overflow it throws instead of wrapping to a tiny allocation.
  auto num_elements = SafeInt<size_t>(N) * D;
  void* allocated_ptr = allocator->Alloc(num_elements * sizeof(float));
  // RAII: release the temporary buffer through the same allocator when this scope exits.
  BufferUniquePtr float_buffer(allocated_ptr, BufferDeleter(allocator));
  // static_cast is the correct (and sufficient) named cast from void* to an object pointer;
  // reinterpret_cast is not needed here.
  float* ptr = static_cast<float*>(allocated_ptr);
  MlasConvertHalfToFloatBuffer(score, ptr, num_elements);
  MlasComputeSoftmax(ptr, ptr, N, D, false, false, 0.0f, tp);
  MlasConvertFloatToHalfBuffer(ptr, score, num_elements);
}

}  // namespace onnxruntime

onnxruntime/test/providers/cpu/llm/attention_op_test.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
// Licensed under the MIT License.
33

44
#include <cassert>
5+
#include <limits>
56
#include "gtest/gtest.h"
67
#include "core/session/onnxruntime_cxx_api.h"
78
#include "test/common/tensor_op_test_utils.h"
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

// This whole file is exception-driven (the allocator throws to intercept Alloc),
// so it is compiled out when exceptions are disabled.
#if !defined(ORT_NO_EXCEPTIONS)

#include <exception>
#include <limits>

#include "gtest/gtest.h"
#include "gmock/gmock.h"

#include "core/framework/allocator.h"
#include "core/providers/cpu/llm/attention_softmax.h"

namespace onnxruntime {
namespace test {

// Regression test for integer overflow in FP16 softmax allocation.
// ComputeAttentionSoftmaxInplace<MLFloat16> previously used int for N and D, so N*D could overflow int32.
// The fix changed parameters to size_t and uses SafeInt for the multiplication.
//
// This test calls ComputeAttentionSoftmaxInplace<MLFloat16> directly with overflow-triggering dimensions
// (N=46341, D=46341, where N*D > INT_MAX).
// A custom allocator intercepts the Alloc call to verify the requested size is computed correctly with size_t
// arithmetic, without actually allocating the ~8GB buffer.
//
// On 32-bit builds, SafeInt<size_t> will signal an overflow for the requested size.
TEST(AttentionSoftmaxTest, Fp16OverflowAllocation) {
  // Custom exception thrown by the allocator to distinguish it from SafeInt overflow.
  struct AllocationIntercepted : std::exception {
    const char* what() const noexcept override { return "allocation intercepted"; }
  };

  // Custom allocator that records the requested allocation size and throws to avoid actually allocating the
  // (very large) buffer. The size is recorded *before* throwing so the test can inspect it afterwards.
  class OverflowCheckAllocator : public IAllocator {
   public:
    OverflowCheckAllocator()
        : IAllocator(OrtMemoryInfo(CPU, OrtDeviceAllocator)) {}
    void* Alloc(size_t size) override {
      last_alloc_size_ = size;
      throw AllocationIntercepted();
    }
    void Free(void*) override {}
    size_t LastAllocSize() const { return last_alloc_size_; }

   private:
    size_t last_alloc_size_ = 0;
  };

  // 46341^2 = 2147488281 > INT_MAX (2147483647): the smallest square exceeding int32 range.
  constexpr size_t N = 46341;
  constexpr size_t D = 46341;

  // Verify at compile time that these dimensions would overflow int32.
  static_assert(int64_t{N} * int64_t{D} > int64_t{std::numeric_limits<int>::max()},
                "Test dimensions must cause int32 overflow in N*D");

  auto alloc = std::make_shared<OverflowCheckAllocator>();
  // Only the pointer is passed through to Alloc before the throw; no real N*D buffer is ever touched.
  MLFloat16 dummy_score{0.0f};

  // The allocation size must reflect correct size_t arithmetic: N * D * sizeof(float).
  // With the old int parameters, N * D would overflow to a small/negative value, producing a wrong allocation size.
  // uintmax_t is wide enough to hold the product even on 32-bit builds where size_t is not.
  constexpr uintmax_t expected_allocation_size = uintmax_t{N} * D * sizeof(float);

  if constexpr (expected_allocation_size <= uintmax_t{std::numeric_limits<size_t>::max()}) {
    // Allocation size fits in size_t. The function reaches Alloc, which records the requested size and throws
    // AllocationIntercepted.
    EXPECT_THROW(ComputeAttentionSoftmaxInplace<MLFloat16>(&dummy_score, N, D, nullptr, alloc),
                 AllocationIntercepted);

    EXPECT_EQ(alloc->LastAllocSize(), static_cast<size_t>(expected_allocation_size));
  } else {
    // Allocation size overflows size_t (i.e., in a 32-bit build), so SafeInt<size_t> will throw an exception
    // before Alloc is ever reached.
    try {
      ComputeAttentionSoftmaxInplace<MLFloat16>(&dummy_score, N, D, nullptr, alloc);
      FAIL() << "Expected OnnxRuntimeException to be thrown";
    } catch (const OnnxRuntimeException& e) {
      EXPECT_THAT(e.what(), testing::HasSubstr("Integer overflow"));
    }
  }
}

}  // namespace test
}  // namespace onnxruntime

#endif  // !defined(ORT_NO_EXCEPTIONS)

0 commit comments

Comments (0)