Skip to content

Commit a879332

Browse files
committed
Fix integer overflows in tensor byte-size computations (TOB-EXECUTORCH-19)
Three tensor-byte-size multiplications had no overflow check, letting a malicious PTE trigger wrap-to-small size_t values while kernels iterate on the un-wrapped element count, producing heap buffer overflows. Fixed here: - extension/tensor/tensor_ptr.h: data.size() * elementSize(type) in make_tensor_ptr cast path. - extension/tensor/tensor_ptr_maker.cpp: compute_numel(...) * elementSize(type) in empty_strided. - runtime/core/tensor_layout.cpp: dim-product loop and final * elementSize(scalar_type) in calculate_nbytes; now returns Error::InvalidArgument on overflow since the function already returns Result<size_t>. All guards use c10::mul_overflows, matching the existing pattern in MethodMeta::calculate_nbytes, the data loaders, and PlatformMemoryAllocator. runtime/core/portable_type/tensor_impl.cpp is intentionally left alone in this branch; guarding the nbytes() / compute_numel multiplications there breaks internal callers and will be handled separately. Authored with Claude.
1 parent 217ad45 commit a879332

3 files changed

Lines changed: 38 additions & 7 deletions

File tree

extension/tensor/tensor_ptr.h

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include <vector>
1515

1616
#include <c10/macros/Macros.h>
17+
#include <c10/util/safe_numerics.h>
1718
#include <executorch/runtime/core/error.h>
1819
#include <executorch/runtime/core/exec_aten/exec_aten.h>
1920
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
@@ -117,7 +118,16 @@ inline TensorPtr make_tensor_ptr(
117118
ET_CHECK_MSG(
118119
runtime::canCast(deduced_type, type),
119120
"Cannot cast deduced type to specified type.");
120-
std::vector<uint8_t> casted_data(data.size() * aten::elementSize(type));
121+
size_t casted_bytes = 0;
122+
ET_CHECK_MSG(
123+
!c10::mul_overflows(
124+
data.size(),
125+
static_cast<size_t>(aten::elementSize(type)),
126+
&casted_bytes),
127+
"casted_data size overflow: %zu elements * %zu bytes/element",
128+
data.size(),
129+
static_cast<size_t>(aten::elementSize(type)));
130+
std::vector<uint8_t> casted_data(casted_bytes);
121131

122132
// Create a minimal context for error handling in ET_SWITCH
123133
struct {

extension/tensor/tensor_ptr_maker.cpp

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010

1111
#include <random>
1212

13+
#include <c10/util/safe_numerics.h>
14+
1315
namespace executorch {
1416
namespace extension {
1517
namespace {
@@ -111,9 +113,17 @@ TensorPtr empty_strided(
111113
std::vector<executorch::aten::StridesType> strides,
112114
executorch::aten::ScalarType type,
113115
executorch::aten::TensorShapeDynamism dynamism) {
114-
std::vector<uint8_t> data(
115-
executorch::aten::compute_numel(sizes.data(), sizes.size()) *
116-
executorch::aten::elementSize(type));
116+
const auto numel = static_cast<size_t>(
117+
executorch::aten::compute_numel(sizes.data(), sizes.size()));
118+
const auto elem_size =
119+
static_cast<size_t>(executorch::aten::elementSize(type));
120+
size_t nbytes = 0;
121+
ET_CHECK_MSG(
122+
!c10::mul_overflows(numel, elem_size, &nbytes),
123+
"empty_strided size overflow: numel %zu * element size %zu",
124+
numel,
125+
elem_size);
126+
std::vector<uint8_t> data(nbytes);
117127
return make_tensor_ptr(
118128
std::move(sizes),
119129
std::move(data),

runtime/core/tensor_layout.cpp

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
*/
88

99
#include <c10/util/irange.h>
10+
#include <c10/util/safe_numerics.h>
1011
#include <executorch/runtime/core/exec_aten/exec_aten.h>
1112
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
1213
#include <executorch/runtime/core/span.h>
@@ -19,15 +20,25 @@ namespace {
1920
Result<size_t> calculate_nbytes(
2021
const Span<const int32_t>& sizes,
2122
const executorch::aten::ScalarType& scalar_type) {
22-
ssize_t n = 1;
23+
size_t n = 1;
2324
for (const auto i : c10::irange(sizes.size())) {
2425
if (sizes[i] < 0) {
2526
return Error::InvalidArgument;
2627
}
27-
n *= sizes[i];
28+
size_t next = 0;
29+
if (c10::mul_overflows(n, static_cast<size_t>(sizes[i]), &next)) {
30+
return Error::InvalidArgument;
31+
}
32+
n = next;
2833
}
2934
// Use the full namespace to disambiguate from c10::elementSize.
30-
return n * executorch::runtime::elementSize(scalar_type);
35+
const size_t elem_size =
36+
static_cast<size_t>(executorch::runtime::elementSize(scalar_type));
37+
size_t total = 0;
38+
if (c10::mul_overflows(n, elem_size, &total)) {
39+
return Error::InvalidArgument;
40+
}
41+
return total;
3142
}
3243
} // namespace
3344

0 commit comments

Comments
 (0)