Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 28 additions & 5 deletions extension/tensor/tensor_ptr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

#include <numeric>

#include <c10/util/safe_numerics.h>

#include <executorch/runtime/core/exec_aten/util/tensor_util.h>

namespace executorch {
Expand Down Expand Up @@ -147,11 +149,26 @@ TensorPtr make_tensor_ptr(
std::vector<executorch::aten::StridesType> strides,
executorch::aten::ScalarType type,
executorch::aten::TensorShapeDynamism dynamism) {
auto numel_result = executorch::aten::safe_numel(sizes.data(), sizes.size());
ET_CHECK_MSG(
numel_result.ok(),
"safe_numel failed: %d",
static_cast<int>(numel_result.error()));
const ssize_t numel = numel_result.get();
size_t nbytes;
ET_CHECK_MSG(
data.size() ==
executorch::aten::compute_numel(sizes.data(), sizes.size()) *
executorch::aten::elementSize(type),
"Data size does not match tensor size.");
!c10::mul_overflows(
static_cast<size_t>(numel),
executorch::aten::elementSize(type),
&nbytes),
"Overflow computing nbytes: numel=%zd element_size=%zu",
numel,
executorch::aten::elementSize(type));
ET_CHECK_MSG(
data.size() == nbytes,
"Data size (%zu) does not match tensor size (%zu).",
data.size(),
nbytes);
auto data_ptr = data.data();
return make_tensor_ptr(
std::move(sizes),
Expand Down Expand Up @@ -205,7 +222,13 @@ TensorPtr clone_tensor_ptr(
runtime::canCast(tensor_type, type),
"Cannot cast tensor type to desired type.");
const auto tensor_numel = static_cast<size_t>(tensor.numel());
std::vector<uint8_t> data(tensor_numel * aten::elementSize(type));
size_t clone_nbytes;
ET_CHECK_MSG(
!c10::mul_overflows(tensor_numel, aten::elementSize(type), &clone_nbytes),
"Overflow computing clone nbytes: numel=%zu element_size=%zu",
tensor_numel,
aten::elementSize(type));
std::vector<uint8_t> data(clone_nbytes);

// Create a minimal context for error handling in ET_SWITCH
struct {
Expand Down
17 changes: 13 additions & 4 deletions extension/tensor/tensor_ptr.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,13 @@ inline TensorPtr make_tensor_ptr(
executorch::aten::ScalarType type = deduced_type,
executorch::aten::TensorShapeDynamism dynamism =
executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND) {
auto numel_result = executorch::aten::safe_numel(sizes.data(), sizes.size());
ET_CHECK_MSG(
data.size() ==
executorch::aten::compute_numel(sizes.data(), sizes.size()),
numel_result.ok(),
"safe_numel failed: %d",
static_cast<int>(numel_result.error()));
ET_CHECK_MSG(
data.size() == static_cast<size_t>(numel_result.get()),
"Data size does not match tensor size.");
if (type != deduced_type) {
ET_CHECK_MSG(
Expand Down Expand Up @@ -368,8 +372,13 @@ inline TensorPtr make_tensor_ptr(
const auto same_rank = sizes.size() == static_cast<size_t>(tensor.dim());
const auto same_shape = same_rank &&
std::equal(sizes.begin(), sizes.end(), tensor.sizes().begin());
const auto element_count =
executorch::aten::compute_numel(sizes.data(), sizes.size());
auto element_count_result =
executorch::aten::safe_numel(sizes.data(), sizes.size());
ET_CHECK_MSG(
element_count_result.ok(),
"safe_numel failed: %d",
static_cast<int>(element_count_result.error()));
const auto element_count = element_count_result.get();
const auto parent_element_count = tensor.numel();
ET_CHECK_MSG(
element_count <= parent_element_count,
Expand Down
21 changes: 13 additions & 8 deletions extension/tensor/tensor_ptr_maker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,16 +113,21 @@ TensorPtr empty_strided(
std::vector<executorch::aten::StridesType> strides,
executorch::aten::ScalarType type,
executorch::aten::TensorShapeDynamism dynamism) {
const auto numel = static_cast<size_t>(
executorch::aten::compute_numel(sizes.data(), sizes.size()));
const auto elem_size =
static_cast<size_t>(executorch::aten::elementSize(type));
size_t nbytes = 0;
auto numel_result = executorch::aten::safe_numel(sizes.data(), sizes.size());
ET_CHECK_MSG(
!c10::mul_overflows(numel, elem_size, &nbytes),
"empty_strided size overflow: numel %zu * element size %zu",
numel_result.ok(),
"safe_numel failed: %d",
static_cast<int>(numel_result.error()));
const ssize_t numel = numel_result.get();
size_t nbytes;
ET_CHECK_MSG(
!c10::mul_overflows(
static_cast<size_t>(numel),
executorch::aten::elementSize(type),
&nbytes),
"Overflow computing nbytes: numel=%zd element_size=%zu",
numel,
elem_size);
executorch::aten::elementSize(type));
std::vector<uint8_t> data(nbytes);
return make_tensor_ptr(
std::move(sizes),
Expand Down
10 changes: 6 additions & 4 deletions extension/wasm/wasm_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,20 +84,22 @@ inline void js_array_push(val_array<T>& array, const T& value) {
_(float, Float) \
_(int64_t, Long)

inline ssize_t compute_expected_numel(
inline ::executorch::runtime::Result<ssize_t> compute_expected_numel(
const std::vector<torch::executor::Tensor::SizesType>& sizes) {
return executorch::aten::compute_numel(sizes.data(), sizes.size());
return executorch::aten::safe_numel(sizes.data(), sizes.size());
}

template <typename T>
inline void assert_valid_numel(
const std::vector<T>& data,
const std::vector<torch::executor::Tensor::SizesType>& sizes) {
auto computed_numel = compute_expected_numel(sizes);
THROW_IF_ERROR(
computed_numel.error(), "Invalid tensor sizes: numel computation failed");
THROW_IF_FALSE(
data.size() >= computed_numel,
data.size() >= static_cast<size_t>(computed_numel.get()),
"Required %ld elements, given %ld",
computed_numel,
computed_numel.get(),
data.size());
}

Expand Down
6 changes: 5 additions & 1 deletion runtime/core/portable_type/tensor_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,11 @@ Error TensorImpl::internal_resize_contiguous(ArrayRef<SizesType> new_sizes) {
// TODO(T175194371): Unbounded dynamic tensor resizing is not yet
// supported: treat them as upper-bounded.
case TensorShapeDynamism::DYNAMIC_UNBOUND: {
const auto new_numel = compute_numel(new_sizes.data(), dim_);
auto new_numel_result = safe_numel(new_sizes.data(), dim_);
if (!new_numel_result.ok()) {
return new_numel_result.error();
}
const auto new_numel = new_numel_result.get();

ET_CHECK_OR_RETURN_ERROR(
static_cast<size_t>(new_numel) <= numel_bound_,
Expand Down
Loading