|
| 1 | +// Copyright (c) Microsoft Corporation. All rights reserved. |
| 2 | +// Licensed under the MIT License. |
| 3 | + |
| 4 | +#if !defined(ORT_NO_EXCEPTIONS) |
| 5 | + |
| 6 | +#include <exception> |
| 7 | +#include <limits> |
| 8 | + |
| 9 | +#include "gtest/gtest.h" |
| 10 | +#include "gmock/gmock.h" |
| 11 | + |
| 12 | +#include "core/framework/allocator.h" |
| 13 | +#include "core/providers/cpu/llm/attention_softmax.h" |
| 14 | + |
| 15 | +namespace onnxruntime { |
| 16 | +namespace test { |
| 17 | + |
| 18 | +// Regression test for integer overflow in FP16 softmax allocation. |
| 19 | +// ComputeAttentionSoftmaxInplace<MLFloat16> previously used int for N and D, so N*D could overflow int32. |
| 20 | +// The fix changed parameters to size_t and uses SafeInt for the multiplication. |
| 21 | +// |
| 22 | +// This test calls ComputeAttentionSoftmaxInplace<MLFloat16> directly with overflow-triggering dimensions |
| 23 | +// (N=46341, D=46341, where N*D > INT_MAX). |
| 24 | +// A custom allocator intercepts the Alloc call to verify the requested size is computed correctly with size_t |
| 25 | +// arithmetic, without actually allocating the ~8GB buffer. |
| 26 | +// |
| 27 | +// On 32-bit builds, SafeInt<size_t> will signal an overflow for the requested size. |
| 28 | +TEST(AttentionSoftmaxTest, Fp16OverflowAllocation) { |
| 29 | + // Custom exception thrown by the allocator to distinguish it from SafeInt overflow. |
| 30 | + struct AllocationIntercepted : std::exception { |
| 31 | + const char* what() const noexcept override { return "allocation intercepted"; } |
| 32 | + }; |
| 33 | + |
| 34 | + // Custom allocator that records the requested allocation size and throws to avoid actually allocating the |
| 35 | + // (very large) buffer. |
| 36 | + class OverflowCheckAllocator : public IAllocator { |
| 37 | + public: |
| 38 | + OverflowCheckAllocator() |
| 39 | + : IAllocator(OrtMemoryInfo(CPU, OrtDeviceAllocator)) {} |
| 40 | + void* Alloc(size_t size) override { |
| 41 | + last_alloc_size_ = size; |
| 42 | + throw AllocationIntercepted(); |
| 43 | + } |
| 44 | + void Free(void*) override {} |
| 45 | + size_t LastAllocSize() const { return last_alloc_size_; } |
| 46 | + |
| 47 | + private: |
| 48 | + size_t last_alloc_size_ = 0; |
| 49 | + }; |
| 50 | + |
| 51 | + constexpr size_t N = 46341; |
| 52 | + constexpr size_t D = 46341; |
| 53 | + |
| 54 | + // Verify at compile time that these dimensions would overflow int32. |
| 55 | + static_assert(int64_t{N} * int64_t{D} > int64_t{std::numeric_limits<int>::max()}, |
| 56 | + "Test dimensions must cause int32 overflow in N*D"); |
| 57 | + |
| 58 | + auto alloc = std::make_shared<OverflowCheckAllocator>(); |
| 59 | + MLFloat16 dummy_score{0.0f}; |
| 60 | + |
| 61 | + // The allocation size must reflect correct size_t arithmetic: N * D * sizeof(float). |
| 62 | + // With the old int parameters, N * D would overflow to a small/negative value, producing a wrong allocation size. |
| 63 | + constexpr uintmax_t expected_allocation_size = uintmax_t{N} * D * sizeof(float); |
| 64 | + |
| 65 | + if constexpr (expected_allocation_size <= uintmax_t{std::numeric_limits<size_t>::max()}) { |
| 66 | + // Allocation size fits in size_t. The function reaches Alloc, which records the requested size and throws |
| 67 | + // AllocationIntercepted. |
| 68 | + EXPECT_THROW(ComputeAttentionSoftmaxInplace<MLFloat16>(&dummy_score, N, D, nullptr, alloc), |
| 69 | + AllocationIntercepted); |
| 70 | + |
| 71 | + EXPECT_EQ(alloc->LastAllocSize(), static_cast<size_t>(expected_allocation_size)); |
| 72 | + } else { |
| 73 | + // Allocation size overflows size_t (i.e., in a 32-bit build), so SafeInt<size_t> will throw an exception. |
| 74 | + try { |
| 75 | + ComputeAttentionSoftmaxInplace<MLFloat16>(&dummy_score, N, D, nullptr, alloc); |
| 76 | + FAIL() << "Expected OnnxRuntimeException to be thrown"; |
| 77 | + } catch (const OnnxRuntimeException& e) { |
| 78 | + EXPECT_THAT(e.what(), testing::HasSubstr("Integer overflow")); |
| 79 | + } |
| 80 | + } |
| 81 | +} |
| 82 | + |
| 83 | +} // namespace test |
| 84 | +} // namespace onnxruntime |
| 85 | + |
| 86 | +#endif // !defined(ORT_NO_EXCEPTIONS) |
0 commit comments