Skip to content

Commit 606da60

Browse files
author
Github Executorch
committed
Use safe_numel() in et
Pull Request resolved: #19075. Replace compute_numel() with safe_numel(). Authored with Claude. ghstack-source-id: 372480318. exported-using-ghexport. Differential Revision: [D102082911](https://our.internmc.facebook.com/intern/diff/D102082911/)
1 parent 30566e4 commit 606da60

5 files changed

Lines changed: 61 additions & 22 deletions

File tree

extension/tensor/tensor_ptr.cpp

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010

1111
#include <numeric>
1212

13+
#include <c10/util/safe_numerics.h>
14+
1315
#include <executorch/runtime/core/exec_aten/util/tensor_util.h>
1416

1517
namespace executorch {
@@ -147,11 +149,26 @@ TensorPtr make_tensor_ptr(
147149
std::vector<executorch::aten::StridesType> strides,
148150
executorch::aten::ScalarType type,
149151
executorch::aten::TensorShapeDynamism dynamism) {
152+
auto numel_result = executorch::aten::safe_numel(sizes.data(), sizes.size());
153+
ET_CHECK_MSG(
154+
numel_result.ok(),
155+
"safe_numel failed: %d",
156+
static_cast<int>(numel_result.error()));
157+
const ssize_t numel = numel_result.get();
158+
size_t nbytes;
150159
ET_CHECK_MSG(
151-
data.size() ==
152-
executorch::aten::compute_numel(sizes.data(), sizes.size()) *
153-
executorch::aten::elementSize(type),
154-
"Data size does not match tensor size.");
160+
!c10::mul_overflows(
161+
static_cast<size_t>(numel),
162+
executorch::aten::elementSize(type),
163+
&nbytes),
164+
"Overflow computing nbytes: numel=%zd element_size=%zu",
165+
numel,
166+
executorch::aten::elementSize(type));
167+
ET_CHECK_MSG(
168+
data.size() == nbytes,
169+
"Data size (%zu) does not match tensor size (%zu).",
170+
data.size(),
171+
nbytes);
155172
auto data_ptr = data.data();
156173
return make_tensor_ptr(
157174
std::move(sizes),
@@ -205,7 +222,13 @@ TensorPtr clone_tensor_ptr(
205222
runtime::canCast(tensor_type, type),
206223
"Cannot cast tensor type to desired type.");
207224
const auto tensor_numel = static_cast<size_t>(tensor.numel());
208-
std::vector<uint8_t> data(tensor_numel * aten::elementSize(type));
225+
size_t clone_nbytes;
226+
ET_CHECK_MSG(
227+
!c10::mul_overflows(tensor_numel, aten::elementSize(type), &clone_nbytes),
228+
"Overflow computing clone nbytes: numel=%zu element_size=%zu",
229+
tensor_numel,
230+
aten::elementSize(type));
231+
std::vector<uint8_t> data(clone_nbytes);
209232

210233
// Create a minimal context for error handling in ET_SWITCH
211234
struct {

extension/tensor/tensor_ptr.h

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -110,9 +110,13 @@ inline TensorPtr make_tensor_ptr(
110110
executorch::aten::ScalarType type = deduced_type,
111111
executorch::aten::TensorShapeDynamism dynamism =
112112
executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND) {
113+
auto numel_result = executorch::aten::safe_numel(sizes.data(), sizes.size());
113114
ET_CHECK_MSG(
114-
data.size() ==
115-
executorch::aten::compute_numel(sizes.data(), sizes.size()),
115+
numel_result.ok(),
116+
"safe_numel failed: %d",
117+
static_cast<int>(numel_result.error()));
118+
ET_CHECK_MSG(
119+
data.size() == static_cast<size_t>(numel_result.get()),
116120
"Data size does not match tensor size.");
117121
if (type != deduced_type) {
118122
ET_CHECK_MSG(
@@ -368,8 +372,13 @@ inline TensorPtr make_tensor_ptr(
368372
const auto same_rank = sizes.size() == static_cast<size_t>(tensor.dim());
369373
const auto same_shape = same_rank &&
370374
std::equal(sizes.begin(), sizes.end(), tensor.sizes().begin());
371-
const auto element_count =
372-
executorch::aten::compute_numel(sizes.data(), sizes.size());
375+
auto element_count_result =
376+
executorch::aten::safe_numel(sizes.data(), sizes.size());
377+
ET_CHECK_MSG(
378+
element_count_result.ok(),
379+
"safe_numel failed: %d",
380+
static_cast<int>(element_count_result.error()));
381+
const auto element_count = element_count_result.get();
373382
const auto parent_element_count = tensor.numel();
374383
ET_CHECK_MSG(
375384
element_count <= parent_element_count,

extension/tensor/tensor_ptr_maker.cpp

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -113,16 +113,21 @@ TensorPtr empty_strided(
113113
std::vector<executorch::aten::StridesType> strides,
114114
executorch::aten::ScalarType type,
115115
executorch::aten::TensorShapeDynamism dynamism) {
116-
const auto numel = static_cast<size_t>(
117-
executorch::aten::compute_numel(sizes.data(), sizes.size()));
118-
const auto elem_size =
119-
static_cast<size_t>(executorch::aten::elementSize(type));
120-
size_t nbytes = 0;
116+
auto numel_result = executorch::aten::safe_numel(sizes.data(), sizes.size());
121117
ET_CHECK_MSG(
122-
!c10::mul_overflows(numel, elem_size, &nbytes),
123-
"empty_strided size overflow: numel %zu * element size %zu",
118+
numel_result.ok(),
119+
"safe_numel failed: %d",
120+
static_cast<int>(numel_result.error()));
121+
const ssize_t numel = numel_result.get();
122+
size_t nbytes;
123+
ET_CHECK_MSG(
124+
!c10::mul_overflows(
125+
static_cast<size_t>(numel),
126+
executorch::aten::elementSize(type),
127+
&nbytes),
128+
"Overflow computing nbytes: numel=%zd element_size=%zu",
124129
numel,
125-
elem_size);
130+
executorch::aten::elementSize(type));
126131
std::vector<uint8_t> data(nbytes);
127132
return make_tensor_ptr(
128133
std::move(sizes),

extension/wasm/wasm_bindings.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -84,20 +84,22 @@ inline void js_array_push(val_array<T>& array, const T& value) {
8484
_(float, Float) \
8585
_(int64_t, Long)
8686

87-
inline ssize_t compute_expected_numel(
87+
inline ::executorch::runtime::Result<ssize_t> compute_expected_numel(
8888
const std::vector<torch::executor::Tensor::SizesType>& sizes) {
89-
return executorch::aten::compute_numel(sizes.data(), sizes.size());
89+
return executorch::aten::safe_numel(sizes.data(), sizes.size());
9090
}
9191

9292
template <typename T>
9393
inline void assert_valid_numel(
9494
const std::vector<T>& data,
9595
const std::vector<torch::executor::Tensor::SizesType>& sizes) {
9696
auto computed_numel = compute_expected_numel(sizes);
97+
THROW_IF_ERROR(
98+
computed_numel.error(), "Invalid tensor sizes: numel computation failed");
9799
THROW_IF_FALSE(
98-
data.size() >= computed_numel,
100+
data.size() >= static_cast<size_t>(computed_numel.get()),
99101
"Required %ld elements, given %ld",
100-
computed_numel,
102+
computed_numel.get(),
101103
data.size());
102104
}
103105

runtime/core/portable_type/tensor_impl.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ Error TensorImpl::internal_resize_contiguous(ArrayRef<SizesType> new_sizes) {
147147
// TODO(T175194371): Unbounded dynamic tensor resizing is not yet
148148
// supported: treat them as upper-bounded.
149149
case TensorShapeDynamism::DYNAMIC_UNBOUND: {
150-
const auto new_numel = compute_numel(new_sizes.data(), dim_);
150+
const auto new_numel = ET_UNWRAP(safe_numel(new_sizes.data(), dim_));
151151

152152
ET_CHECK_OR_RETURN_ERROR(
153153
static_cast<size_t>(new_numel) <= numel_bound_,

0 commit comments

Comments (0)