diff --git a/runtime/core/portable_type/tensor_impl.cpp b/runtime/core/portable_type/tensor_impl.cpp
index 17243fca0fd..a7e59595d7e 100644
--- a/runtime/core/portable_type/tensor_impl.cpp
+++ b/runtime/core/portable_type/tensor_impl.cpp
@@ -9,9 +9,11 @@
 #include <executorch/runtime/core/portable_type/tensor_impl.h>
 
 #include <cstdint>
+#include <c10/util/safe_numerics.h>
 #include <cstring>
 
 #include <executorch/runtime/core/exec_aten/util/dim_order_util.h>
+#include <executorch/runtime/platform/assert.h>
 #include <executorch/runtime/core/portable_type/qint_types.h>
 #include <executorch/runtime/core/portable_type/scalar_type.h>
 
@@ -38,7 +40,14 @@ ssize_t compute_numel(const TensorImpl::SizesType* sizes, ssize_t dim) {
         "Size must be non-negative, got %zd at dimension %zd",
         static_cast<ssize_t>(sizes[i]),
         i);
-    numel *= sizes[i];
+    ssize_t next_numel;
+    ET_CHECK_MSG(
+        !c10::mul_overflows(numel, static_cast<ssize_t>(sizes[i]), &next_numel),
+        "Overflow computing numel: %zd * %zd would overflow ssize_t at dimension %zd",
+        numel,
+        static_cast<ssize_t>(sizes[i]),
+        i);
+    numel = next_numel;
   }
   return numel;
 }
diff --git a/runtime/executor/tensor_parser_portable.cpp b/runtime/executor/tensor_parser_portable.cpp
index 2fc9a2dc140..7fdd01e8f62 100644
--- a/runtime/executor/tensor_parser_portable.cpp
+++ b/runtime/executor/tensor_parser_portable.cpp
@@ -8,6 +8,8 @@
 #include <executorch/runtime/executor/tensor_parser.h>
 
+#include <c10/util/safe_numerics.h>
+
 #include <executorch/runtime/core/evalue.h>
 #include <executorch/runtime/core/exec_aten/exec_aten.h>
 #include <executorch/runtime/core/exec_aten/util/dim_order_util.h>
 
@@ -118,10 +120,11 @@ Result<Tensor> parseTensor(
     dim_order = const_cast<executorch::aten::DimOrderType*>(serialized_dim_order);
   }
 
-  // Validate sizes before using them in case the PTE data is bad. We can't
-  // detect bad positive values, but we can reject negative values, which would
-  // otherwise panic in the TensorImpl ctor. dim_order_to_stride() will validate
-  // dim_order.
+  // Validate sizes before using them in case the PTE data is bad. Reject
+  // negative values and check that the product of all dimensions doesn't
+  // overflow ssize_t, which would otherwise abort in the TensorImpl ctor.
+  // dim_order_to_stride() will validate dim_order.
+  ssize_t numel = 1;
   for (flatbuffers::uoffset_t i = 0; i < dim; i++) {
     ET_CHECK_OR_RETURN_ERROR(
         sizes[i] >= 0,
@@ -129,6 +132,13 @@ Result<Tensor> parseTensor(
         "Negative size[%zu] %" PRId32,
         static_cast<size_t>(i),
         sizes[i]);
+    ssize_t next_numel;
+    ET_CHECK_OR_RETURN_ERROR(
+        !c10::mul_overflows(numel, static_cast<ssize_t>(sizes[i]), &next_numel),
+        InvalidProgram,
+        "Overflow computing numel at dim %zu",
+        static_cast<size_t>(i));
+    numel = next_numel;
   }
 
   // We will remove strides from schema.