Skip to content

Commit e8f2fca

Browse files
author
Github Executorch
committed
Use compute_numel_overflow() in tensor_ptr builders
make_tensor_ptr() and related extension builders return TensorPtr (a shared_ptr) and cannot propagate Error, so the existing compute_numel()*elementSize() size checks silently wrapped on overflow and let an undersized buffer back an oversized tensor (heap OOB). Migrate these callers to compute_numel_overflow() so overflow aborts, and add explicit c10::mul_overflows() guards around the subsequent numel*elementSize multiplication. Touched: make_tensor_ptr(vector<uint8_t>), clone_tensor_ptr, empty_strided, and the two tensor_ptr.h template overloads. Authored with Claude. Differential Revision: [D102082922](https://our.internmc.facebook.com/intern/diff/D102082922/) [ghstack-poisoned]
1 parent 1fc3ee6 commit e8f2fca

3 files changed

Lines changed: 41 additions & 9 deletions

File tree

extension/tensor/tensor_ptr.cpp

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010

1111
#include <numeric>
1212

13+
#include <c10/util/safe_numerics.h>
14+
1315
#include <executorch/runtime/core/exec_aten/util/tensor_util.h>
1416

1517
namespace executorch {
@@ -147,11 +149,22 @@ TensorPtr make_tensor_ptr(
147149
std::vector<executorch::aten::StridesType> strides,
148150
executorch::aten::ScalarType type,
149151
executorch::aten::TensorShapeDynamism dynamism) {
152+
const ssize_t numel =
153+
executorch::aten::compute_numel_overflow(sizes.data(), sizes.size());
154+
size_t nbytes;
155+
ET_CHECK_MSG(
156+
!c10::mul_overflows(
157+
static_cast<size_t>(numel),
158+
executorch::aten::elementSize(type),
159+
&nbytes),
160+
"Overflow computing nbytes: numel=%zd element_size=%zu",
161+
numel,
162+
executorch::aten::elementSize(type));
150163
ET_CHECK_MSG(
151-
data.size() ==
152-
executorch::aten::compute_numel(sizes.data(), sizes.size()) *
153-
executorch::aten::elementSize(type),
154-
"Data size does not match tensor size.");
164+
data.size() == nbytes,
165+
"Data size (%zu) does not match tensor size (%zu).",
166+
data.size(),
167+
nbytes);
155168
auto data_ptr = data.data();
156169
return make_tensor_ptr(
157170
std::move(sizes),
@@ -205,7 +218,14 @@ TensorPtr clone_tensor_ptr(
205218
runtime::canCast(tensor_type, type),
206219
"Cannot cast tensor type to desired type.");
207220
const auto tensor_numel = static_cast<size_t>(tensor.numel());
208-
std::vector<uint8_t> data(tensor_numel * aten::elementSize(type));
221+
size_t clone_nbytes;
222+
ET_CHECK_MSG(
223+
!c10::mul_overflows(
224+
tensor_numel, aten::elementSize(type), &clone_nbytes),
225+
"Overflow computing clone nbytes: numel=%zu element_size=%zu",
226+
tensor_numel,
227+
aten::elementSize(type));
228+
std::vector<uint8_t> data(clone_nbytes);
209229

210230
// Create a minimal context for error handling in ET_SWITCH
211231
struct {

extension/tensor/tensor_ptr.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,8 @@ inline TensorPtr make_tensor_ptr(
111111
executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND) {
112112
ET_CHECK_MSG(
113113
data.size() ==
114-
executorch::aten::compute_numel(sizes.data(), sizes.size()),
114+
static_cast<size_t>(executorch::aten::compute_numel_overflow(
115+
sizes.data(), sizes.size())),
115116
"Data size does not match tensor size.");
116117
if (type != deduced_type) {
117118
ET_CHECK_MSG(
@@ -359,7 +360,7 @@ inline TensorPtr make_tensor_ptr(
359360
const auto same_shape = same_rank &&
360361
std::equal(sizes.begin(), sizes.end(), tensor.sizes().begin());
361362
const auto element_count =
362-
executorch::aten::compute_numel(sizes.data(), sizes.size());
363+
executorch::aten::compute_numel_overflow(sizes.data(), sizes.size());
363364
const auto parent_element_count = tensor.numel();
364365
ET_CHECK_MSG(
365366
element_count <= parent_element_count,

extension/tensor/tensor_ptr_maker.cpp

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010

1111
#include <random>
1212

13+
#include <c10/util/safe_numerics.h>
14+
1315
namespace executorch {
1416
namespace extension {
1517
namespace {
@@ -111,9 +113,18 @@ TensorPtr empty_strided(
111113
std::vector<executorch::aten::StridesType> strides,
112114
executorch::aten::ScalarType type,
113115
executorch::aten::TensorShapeDynamism dynamism) {
114-
std::vector<uint8_t> data(
115-
executorch::aten::compute_numel(sizes.data(), sizes.size()) *
116+
const ssize_t numel =
117+
executorch::aten::compute_numel_overflow(sizes.data(), sizes.size());
118+
size_t nbytes;
119+
ET_CHECK_MSG(
120+
!c10::mul_overflows(
121+
static_cast<size_t>(numel),
122+
executorch::aten::elementSize(type),
123+
&nbytes),
124+
"Overflow computing nbytes: numel=%zd element_size=%zu",
125+
numel,
116126
executorch::aten::elementSize(type));
127+
std::vector<uint8_t> data(nbytes);
117128
return make_tensor_ptr(
118129
std::move(sizes),
119130
std::move(data),

0 commit comments

Comments
 (0)