33
44#include " graph_flatbuffers_utils.h"
55
6+ #include < limits>
7+
68#include " core/common/flatbuffers.h"
79
810#include " core/common/narrow.h"
911#include " core/flatbuffers/flatbuffers_utils.h"
1012#include " core/flatbuffers/schema/ort.fbs.h"
13+ #include " core/framework/allocator.h"
1114#include " core/framework/tensorprotoutils.h"
1215#include " core/framework/tensor_external_data_info.h"
1316#include " core/graph/graph.h"
@@ -215,13 +218,50 @@ Status SaveAttributeOrtFormat(flatbuffers::FlatBufferBuilder& builder,
215218 * to accommodate fbs::Tensors with external data.
216219 *
217220 * @param tensor flatbuffer representation of a tensor.
218- * @return size_t size in bytes of the tensor's data.
221+ * @param size_in_bytes Output size in bytes of the tensor's data.
222+ * @return Status indicating success or providing error information.
219223 */
220- size_t GetSizeInBytesFromFbsTensor (const fbs::Tensor& tensor) {
221- auto fbs_dims = tensor.dims ();
224+ Status GetSizeInBytesFromFbsTensor (const fbs::Tensor& tensor, size_t & size_in_bytes) {
225+ const auto * tensor_name = tensor.name ();
226+ const auto * tensor_name_str = tensor_name ? tensor_name->c_str () : " <unnamed>" ;
227+ const auto * tensor_data_type_str = fbs::EnumNameTensorDataType (tensor.data_type ());
228+ if (tensor_data_type_str[0 ] == ' \0 ' ) {
229+ tensor_data_type_str = " <unknown>" ;
230+ }
231+
232+ const auto * fbs_dims = tensor.dims ();
233+ if (nullptr == fbs_dims) {
234+ return ORT_MAKE_STATUS (ONNXRUNTIME, INVALID_ARGUMENT,
235+ " Missing dimensions for tensor '" , tensor_name_str,
236+ " ' with data type '" , tensor_data_type_str,
237+ " '. Invalid ORT format model." );
238+ }
239+
240+ size_t num_elements = 1 ;
241+ for (int64_t dim : *fbs_dims) {
242+ if (dim < 0 ) {
243+ return ORT_MAKE_STATUS (ONNXRUNTIME, INVALID_ARGUMENT,
244+ " Invalid negative dimension " , dim,
245+ " for tensor '" , tensor_name_str,
246+ " ' with data type '" , tensor_data_type_str,
247+ " '. Invalid ORT format model." );
248+ }
222249
223- auto num_elements = std::accumulate (fbs_dims->cbegin (), fbs_dims->cend (), SafeInt<size_t >(1 ),
224- std::multiplies<>());
250+ if (static_cast <uint64_t >(dim) > static_cast <uint64_t >(std::numeric_limits<size_t >::max ())) {
251+ return ORT_MAKE_STATUS (ONNXRUNTIME, INVALID_ARGUMENT,
252+ " Dimension " , dim,
253+ " does not fit in size_t for tensor '" , tensor_name_str,
254+ " ' with data type '" , tensor_data_type_str,
255+ " '. Invalid ORT format model." );
256+ }
257+
258+ if (!IAllocator::CalcMemSizeForArray (num_elements, static_cast <size_t >(dim), &num_elements)) {
259+ return ORT_MAKE_STATUS (ONNXRUNTIME, INVALID_ARGUMENT,
260+ " Tensor element count overflows size_t for tensor '" , tensor_name_str,
261+ " ' with data type '" , tensor_data_type_str,
262+ " '. Invalid ORT format model." );
263+ }
264+ }
225265
226266 size_t byte_size_of_one_element;
227267
@@ -280,11 +320,24 @@ size_t GetSizeInBytesFromFbsTensor(const fbs::Tensor& tensor) {
280320 break ;
281321#endif
282322 case fbs::TensorDataType::STRING:
283- ORT_THROW (" String data type is not supported for on-device training" , tensor.name ());
323+ return ORT_MAKE_STATUS (ONNXRUNTIME, INVALID_ARGUMENT,
324+ " String data type is not supported for tensor '" , tensor_name_str,
325+ " ' in on-device training." );
284326 default :
285- ORT_THROW (" Unsupported tensor data type for tensor " , tensor.name ());
327+ return ORT_MAKE_STATUS (ONNXRUNTIME, INVALID_ARGUMENT,
328+ " Unsupported tensor data type '" , tensor_data_type_str,
329+ " ' for tensor '" , tensor_name_str,
330+ " '. Invalid ORT format model." );
286331 }
287- return num_elements * byte_size_of_one_element;
332+
333+ if (!IAllocator::CalcMemSizeForArray (num_elements, byte_size_of_one_element, &size_in_bytes)) {
334+ return ORT_MAKE_STATUS (ONNXRUNTIME, INVALID_ARGUMENT,
335+ " Tensor byte size overflows size_t for tensor '" , tensor_name_str,
336+ " ' with data type '" , tensor_data_type_str,
337+ " '. Invalid ORT format model." );
338+ }
339+
340+ return Status::OK ();
288341}
289342
290343Status LoadInitializerOrtFormat (const fbs::Tensor& fbs_tensor, TensorProto& initializer,
@@ -306,7 +359,14 @@ Status LoadInitializerOrtFormat(const fbs::Tensor& fbs_tensor, TensorProto& init
306359 ORT_RETURN_IF (nullptr == fbs_str_data, " Missing string data for initializer. Invalid ORT format model." );
307360 auto mutable_str_data = initializer.mutable_string_data ();
308361 mutable_str_data->Reserve (fbs_str_data->size ());
309- for (const auto * fbs_str : *fbs_str_data) {
362+ const auto * raw_string_offsets = reinterpret_cast <const uint8_t *>(fbs_str_data->Data ());
363+ for (flatbuffers::uoffset_t i = 0 ; i < fbs_str_data->size (); ++i) {
364+ const auto entry_offset =
365+ flatbuffers::ReadScalar<flatbuffers::uoffset_t >(raw_string_offsets + i * sizeof (flatbuffers::uoffset_t ));
366+ ORT_RETURN_IF (entry_offset == 0 , " Null string data entry for initializer. Invalid ORT format model." );
367+
368+ const auto * fbs_str = fbs_str_data->Get (i);
369+ ORT_RETURN_IF (nullptr == fbs_str, " Null string data entry for initializer. Invalid ORT format model." );
310370 mutable_str_data->Add (fbs_str->str ());
311371 }
312372 } else {
@@ -338,7 +398,8 @@ Status LoadInitializerOrtFormat(const fbs::Tensor& fbs_tensor, TensorProto& init
338398
339399 // FUTURE: This could be setup similarly to can_use_flatbuffer_for_initializers above if the external data file
340400 // is memory mapped and guaranteed to remain valid. This would avoid the copy.
341- auto num_bytes = GetSizeInBytesFromFbsTensor (fbs_tensor);
401+ size_t num_bytes = 0 ;
402+ ORT_RETURN_IF_ERROR (GetSizeInBytesFromFbsTensor (fbs_tensor, num_bytes));
342403
343404 // pre-allocate so we can write directly to the string buffer
344405 std::string& raw_data = *initializer.mutable_raw_data ();
@@ -542,7 +603,8 @@ struct UnpackTensorWithType {
542603 // no external data. should have had raw data.
543604 ORT_RETURN_IF (fbs_tensor_external_data_offset < 0 , " Missing raw data for initializer. Invalid ORT format model." );
544605
545- const size_t raw_data_len = fbs::utils::GetSizeInBytesFromFbsTensor (fbs_tensor);
606+ size_t raw_data_len = 0 ;
607+ ORT_RETURN_IF_ERROR (fbs::utils::GetSizeInBytesFromFbsTensor (fbs_tensor, raw_data_len));
546608
547609 auto raw_buf = std::make_unique<uint8_t []>(raw_data_len);
548610 gsl::span<uint8_t > raw_buf_span (raw_buf.get (), raw_data_len);
0 commit comments