|
10 | 10 |
|
11 | 11 | #include <numeric> |
12 | 12 |
|
| 13 | +#include <c10/util/safe_numerics.h> |
| 14 | + |
13 | 15 | #include <executorch/runtime/core/exec_aten/util/tensor_util.h> |
14 | 16 |
|
15 | 17 | namespace executorch { |
@@ -147,11 +149,26 @@ TensorPtr make_tensor_ptr( |
147 | 149 | std::vector<executorch::aten::StridesType> strides, |
148 | 150 | executorch::aten::ScalarType type, |
149 | 151 | executorch::aten::TensorShapeDynamism dynamism) { |
| 152 | + auto numel_result = executorch::aten::safe_numel(sizes.data(), sizes.size()); |
| 153 | + ET_CHECK_MSG( |
| 154 | + numel_result.ok(), |
| 155 | + "safe_numel failed: %d", |
| 156 | + static_cast<int>(numel_result.error())); |
| 157 | + const ssize_t numel = numel_result.get(); |
| 158 | + size_t nbytes; |
150 | 159 | ET_CHECK_MSG( |
151 | | - data.size() == |
152 | | - executorch::aten::compute_numel(sizes.data(), sizes.size()) * |
153 | | - executorch::aten::elementSize(type), |
154 | | - "Data size does not match tensor size."); |
| 160 | + !c10::mul_overflows( |
| 161 | + static_cast<size_t>(numel), |
| 162 | + executorch::aten::elementSize(type), |
| 163 | + &nbytes), |
| 164 | + "Overflow computing nbytes: numel=%zd element_size=%zu", |
| 165 | + numel, |
| 166 | + executorch::aten::elementSize(type)); |
| 167 | + ET_CHECK_MSG( |
| 168 | + data.size() == nbytes, |
| 169 | + "Data size (%zu) does not match tensor size (%zu).", |
| 170 | + data.size(), |
| 171 | + nbytes); |
155 | 172 | auto data_ptr = data.data(); |
156 | 173 | return make_tensor_ptr( |
157 | 174 | std::move(sizes), |
@@ -205,7 +222,13 @@ TensorPtr clone_tensor_ptr( |
205 | 222 | runtime::canCast(tensor_type, type), |
206 | 223 | "Cannot cast tensor type to desired type."); |
207 | 224 | const auto tensor_numel = static_cast<size_t>(tensor.numel()); |
208 | | - std::vector<uint8_t> data(tensor_numel * aten::elementSize(type)); |
| 225 | + size_t clone_nbytes; |
| 226 | + ET_CHECK_MSG( |
| 227 | + !c10::mul_overflows(tensor_numel, aten::elementSize(type), &clone_nbytes), |
| 228 | + "Overflow computing clone nbytes: numel=%zu element_size=%zu", |
| 229 | + tensor_numel, |
| 230 | + aten::elementSize(type)); |
| 231 | + std::vector<uint8_t> data(clone_nbytes); |
209 | 232 |
|
210 | 233 | // Create a minimal context for error handling in ET_SWITCH |
211 | 234 | struct { |
|
0 commit comments