diff --git a/runtime/executor/platform_memory_allocator.h b/runtime/executor/platform_memory_allocator.h index 5951f116d3d..601a4c19c85 100644 --- a/runtime/executor/platform_memory_allocator.h +++ b/runtime/executor/platform_memory_allocator.h @@ -12,6 +12,7 @@ #include #include +#include #include #include #include @@ -46,8 +47,20 @@ class PlatformMemoryAllocator final : public MemoryAllocator { return nullptr; } - // Allocate enough memory for the node, the data and the alignment bump. - size_t alloc_size = sizeof(AllocationNode) + size + alignment; + // Check for overflow before computing total allocation size. + // Allocate enough for the node, data, and alignment bump (at most + // alignment - 1 extra bytes to align the data pointer). + size_t alloc_size = 0; + if (c10::add_overflows(sizeof(AllocationNode), size, &alloc_size) || + c10::add_overflows(alloc_size, alignment - 1, &alloc_size)) { + ET_LOG( + Error, + "Allocation size overflow: size %zu, alignment %zu", + size, + alignment); + return nullptr; + } + void* node_memory = runtime::pal_allocate(alloc_size); // If allocation failed, log message and return nullptr.