@@ -150,18 +150,33 @@ bool isQuantizedDataType(const xnn_datatype data_type) {
150150 }
151151}
152152
153+ // Max dimension value allowed for a tensor. This is also used for validating
154+ // permutation values and padding values.
155+ constexpr uint32_t kMaxDimensionValue = 1u << 24 ; // 16M per dimension
153156/* *
154157Converts dims from uint32 to size_t. Takes in a flatbuffer vector
155158of uint32_t and returns a std::vector of size_t. XNNPACK takes in
156159dims of size_t* but tensor shape is serialized in flatbuffer as
157- int32_t. As a result, we need to static cast the shapes to size_t
160+ int32_t. As a result, we need to static cast the shapes to size_t.
161+ Individual dimension values are validated to prevent unbounded memory
162+ allocation from malicious inputs.
158163*/
159164template <typename T = size_t >
160- std::vector<T> flatbufferDimsToVector (
165+ Result< std::vector<T> > flatbufferDimsToVector (
161166 const flatbuffers::Vector<uint32_t >* fb_dims) {
167+ ET_CHECK_OR_RETURN_ERROR (
168+ fb_dims != nullptr ,
169+ InvalidProgram,
170+ " flatbufferDimsToVector: dims vector is null" );
162171 std::vector<T> dims_data;
163172 dims_data.reserve (fb_dims->size ());
164173 for (auto fb_dim : *fb_dims) {
174+ ET_CHECK_OR_RETURN_ERROR (
175+ fb_dim <= kMaxDimensionValue ,
176+ InvalidProgram,
177+ " Dimension value %u exceeds maximum allowed %u" ,
178+ fb_dim,
179+ kMaxDimensionValue );
165180 dims_data.push_back (static_cast <T>(fb_dim));
166181 }
167182 return dims_data;
@@ -285,13 +300,24 @@ Error defineTensor(
285300 }
286301
287302 ET_CHECK_OR_RETURN_ERROR (
288- tensor_value != nullptr ,
289- Internal,
290- " Deserialized Tensor is Null, this should never happen" );
303+ tensor_value != nullptr && tensor_value->dims () != nullptr ,
304+ InvalidProgram,
305+ " Deserialized tensor is null, or tensor dims is null" );
306+
307+ ET_CHECK_OR_RETURN_ERROR (
308+ tensor_value->num_dims () == tensor_value->dims ()->size (),
309+ InvalidProgram,
310+ " Tensor num_dims %u does not match dims array size %u" ,
311+ tensor_value->num_dims (),
312+ tensor_value->dims ()->size ());
291313
292314 // Get tensor dims, here we need to use a vector in order
293315 // to properly convert the uint32_t* to size_t*
294- std::vector<size_t > dims_data = flatbufferDimsToVector (tensor_value->dims ());
316+ auto dims_result = flatbufferDimsToVector (tensor_value->dims ());
317+ if (!dims_result.ok ()) {
318+ return dims_result.error ();
319+ }
320+ std::vector<size_t > dims_data = std::move (dims_result.get ());
295321
296322 // XNNPACK Id
297323 uint32_t id = XNN_INVALID_VALUE_ID;
@@ -966,7 +992,12 @@ Error defineStaticTransposeNode(
966992 auto graph_node = node->xnode_union_as_XNNStaticTranspose ();
967993
968994 // Get tensor dims, we need to convert the uint32_t* to size_t*
969- std::vector<size_t > dims_data = flatbufferDimsToVector (graph_node->perm ());
995+ auto dims_result = flatbufferDimsToVector (graph_node->perm ());
996+ if (!dims_result.ok ()) {
997+ return dims_result.error ();
998+ }
999+ std::vector<size_t > dims_data = std::move (dims_result.get ());
1000+
9701001 xnn_status status = xnn_define_static_transpose (
9711002 subgraph_ptr,
9721003 graph_node->num_dims (),
@@ -1031,10 +1062,16 @@ Error defineStaticConstantPadNode(
10311062 const fb_xnnpack::XNNStaticConstantPad* graph_node =
10321063 node->xnode_union_as_XNNStaticConstantPad ();
10331064
1034- std::vector<size_t > pre_paddings_dims =
1035- flatbufferDimsToVector (graph_node->pre_paddings ());
1036- std::vector<size_t > post_paddings_dims =
1037- flatbufferDimsToVector (graph_node->post_paddings ());
1065+ auto pre_result = flatbufferDimsToVector (graph_node->pre_paddings ());
1066+ if (!pre_result.ok ()) {
1067+ return pre_result.error ();
1068+ }
1069+ std::vector<size_t > pre_paddings_dims = std::move (pre_result.get ());
1070+ auto post_result = flatbufferDimsToVector (graph_node->post_paddings ());
1071+ if (!post_result.ok ()) {
1072+ return post_result.error ();
1073+ }
1074+ std::vector<size_t > post_paddings_dims = std::move (post_result.get ());
10381075
10391076 xnn_status status = xnn_define_static_constant_pad (
10401077 subgraph_ptr,
@@ -1111,8 +1148,12 @@ Error defineStaticReshapeNode(
11111148 auto graph_node = node->xnode_union_as_XNNStaticReshape ();
11121149
11131150 // Get tensor dims, we need to convert the uint32_t* to size_t*
1114- std::vector<size_t > dims_data =
1115- flatbufferDimsToVector (graph_node->new_shape ());
1151+ auto dims_result = flatbufferDimsToVector (graph_node->new_shape ());
1152+ if (!dims_result.ok ()) {
1153+ return dims_result.error ();
1154+ }
1155+ std::vector<size_t > dims_data = std::move (dims_result.get ());
1156+
11161157 xnn_status status = xnn_define_static_reshape (
11171158 subgraph_ptr,
11181159 graph_node->num_dims (),
@@ -1406,8 +1447,16 @@ Error defineStaticSliceNode(
14061447
14071448 auto graph_node = node->xnode_union_as_XNNStaticSlice ();
14081449
1409- std::vector<size_t > offsets = flatbufferDimsToVector (graph_node->offsets ());
1410- std::vector<size_t > sizes = flatbufferDimsToVector (graph_node->sizes ());
1450+ auto offsets_result = flatbufferDimsToVector (graph_node->offsets ());
1451+ if (!offsets_result.ok ()) {
1452+ return offsets_result.error ();
1453+ }
1454+ std::vector<size_t > offsets = std::move (offsets_result.get ());
1455+ auto sizes_result = flatbufferDimsToVector (graph_node->sizes ());
1456+ if (!sizes_result.ok ()) {
1457+ return sizes_result.error ();
1458+ }
1459+ std::vector<size_t > sizes = std::move (sizes_result.get ());
14111460
14121461 xnn_status status = xnn_define_static_slice (
14131462 subgraph_ptr,
0 commit comments