diff --git a/extension/tensor/tensor_ptr.h b/extension/tensor/tensor_ptr.h index 47124bdeca6..08d6cc1254c 100644 --- a/extension/tensor/tensor_ptr.h +++ b/extension/tensor/tensor_ptr.h @@ -14,6 +14,7 @@ #include #include +#include #include #include #include @@ -117,7 +118,16 @@ inline TensorPtr make_tensor_ptr( ET_CHECK_MSG( runtime::canCast(deduced_type, type), "Cannot cast deduced type to specified type."); - std::vector casted_data(data.size() * aten::elementSize(type)); + size_t casted_bytes = 0; + ET_CHECK_MSG( + !c10::mul_overflows( + data.size(), + static_cast(aten::elementSize(type)), + &casted_bytes), + "casted_data size overflow: %zu elements * %zu bytes/element", + data.size(), + static_cast(aten::elementSize(type))); + std::vector casted_data(casted_bytes); // Create a minimal context for error handling in ET_SWITCH struct { diff --git a/extension/tensor/tensor_ptr_maker.cpp b/extension/tensor/tensor_ptr_maker.cpp index b71dfab8eeb..52a3e8f281c 100644 --- a/extension/tensor/tensor_ptr_maker.cpp +++ b/extension/tensor/tensor_ptr_maker.cpp @@ -10,6 +10,8 @@ #include +#include + namespace executorch { namespace extension { namespace { @@ -111,9 +113,17 @@ TensorPtr empty_strided( std::vector strides, executorch::aten::ScalarType type, executorch::aten::TensorShapeDynamism dynamism) { - std::vector data( - executorch::aten::compute_numel(sizes.data(), sizes.size()) * - executorch::aten::elementSize(type)); + const auto numel = static_cast( + executorch::aten::compute_numel(sizes.data(), sizes.size())); + const auto elem_size = + static_cast(executorch::aten::elementSize(type)); + size_t nbytes = 0; + ET_CHECK_MSG( + !c10::mul_overflows(numel, elem_size, &nbytes), + "empty_strided size overflow: numel %zu * element size %zu", + numel, + elem_size); + std::vector data(nbytes); return make_tensor_ptr( std::move(sizes), std::move(data), diff --git a/runtime/core/tensor_layout.cpp b/runtime/core/tensor_layout.cpp index d33f79f27c4..97abe0b5130 100644 --- a/runtime/core/tensor_layout.cpp +++ b/runtime/core/tensor_layout.cpp @@ -7,6 +7,7 @@ */ #include +#include #include #include #include @@ -19,15 +20,25 @@ namespace { Result calculate_nbytes( const Span& sizes, const executorch::aten::ScalarType& scalar_type) { - ssize_t n = 1; + size_t n = 1; for (const auto i : c10::irange(sizes.size())) { if (sizes[i] < 0) { return Error::InvalidArgument; } - n *= sizes[i]; + size_t next = 0; + if (c10::mul_overflows(n, static_cast(sizes[i]), &next)) { + return Error::InvalidArgument; + } + n = next; } // Use the full namespace to disambiguate from c10::elementSize. - return n * executorch::runtime::elementSize(scalar_type); + const size_t elem_size = + static_cast(executorch::runtime::elementSize(scalar_type)); + size_t total = 0; + if (c10::mul_overflows(n, elem_size, &total)) { + return Error::InvalidArgument; + } + return total; } } // namespace