@@ -154,14 +154,28 @@ bool isQuantizedDataType(const xnn_datatype data_type) {
154154Converts dims from uint32 to size_t. Takes in a flatbuffer vector
155155of uint32_t and returns a std::vector of size_t. XNNPACK takes in
156156dims of size_t* but tensor shape is serialized in flatbuffer as
157- int32_t. As a result, we need to static cast the shapes to size_t
157+ int32_t. As a result, we need to static cast the shapes to size_t.
158+ Individual dimension values are validated to prevent unbounded memory
159+ allocation from malicious inputs.
158160*/
161+ constexpr uint32_t kMaxDimensionValue = 1 << 24 ; // 16M per dimension
162+
159163template <typename T = size_t >
160- std::vector<T> flatbufferDimsToVector (
164+ Result< std::vector<T> > flatbufferDimsToVector (
161165 const flatbuffers::Vector<uint32_t >* fb_dims) {
166+ ET_CHECK_OR_RETURN_ERROR (
167+ fb_dims != nullptr ,
168+ InvalidProgram,
169+ " flatbufferDimsToVector: dims vector is null" );
162170 std::vector<T> dims_data;
163171 dims_data.reserve (fb_dims->size ());
164172 for (auto fb_dim : *fb_dims) {
173+ ET_CHECK_OR_RETURN_ERROR (
174+ fb_dim <= kMaxDimensionValue ,
175+ InvalidProgram,
176+ " Dimension value %u exceeds maximum allowed %u" ,
177+ fb_dim,
178+ kMaxDimensionValue );
165179 dims_data.push_back (static_cast <T>(fb_dim));
166180 }
167181 return dims_data;
@@ -285,13 +299,24 @@ Error defineTensor(
285299 }
286300
287301 ET_CHECK_OR_RETURN_ERROR (
288- tensor_value != nullptr ,
289- Internal,
290- " Deserialized Tensor is Null, this should never happen" );
302+ tensor_value != nullptr && tensor_value->dims () != nullptr ,
303+ InvalidProgram,
304+ " Deserialized tensor is null, or tensor dims is null" );
305+
306+ ET_CHECK_OR_RETURN_ERROR (
307+ tensor_value->num_dims () == tensor_value->dims ()->size (),
308+ InvalidProgram,
309+ " Tensor num_dims %u does not match dims array size %u" ,
310+ tensor_value->num_dims (),
311+ tensor_value->dims ()->size ());
291312
292313 // Get tensor dims, here we need to use a vector in order
293314 // to properly convert the uint32_t* to size_t*
294- std::vector<size_t > dims_data = flatbufferDimsToVector (tensor_value->dims ());
315+ auto dims_result = flatbufferDimsToVector (tensor_value->dims ());
316+ if (!dims_result.ok ()) {
317+ return dims_result.error ();
318+ }
319+ std::vector<size_t > dims_data = std::move (dims_result.get ());
295320
296321 // XNNPACK Id
297322 uint32_t id = XNN_INVALID_VALUE_ID;
@@ -966,7 +991,12 @@ Error defineStaticTransposeNode(
966991 auto graph_node = node->xnode_union_as_XNNStaticTranspose ();
967992
968993 // Get tensor dims, we need to convert the uint32_t* to size_t*
969- std::vector<size_t > dims_data = flatbufferDimsToVector (graph_node->perm ());
994+ auto dims_result = flatbufferDimsToVector (graph_node->perm ());
995+ if (!dims_result.ok ()) {
996+ return dims_result.error ();
997+ }
998+ std::vector<size_t > dims_data = std::move (dims_result.get ());
999+
9701000 xnn_status status = xnn_define_static_transpose (
9711001 subgraph_ptr,
9721002 graph_node->num_dims (),
@@ -1031,10 +1061,16 @@ Error defineStaticConstantPadNode(
10311061 const fb_xnnpack::XNNStaticConstantPad* graph_node =
10321062 node->xnode_union_as_XNNStaticConstantPad ();
10331063
1034- std::vector<size_t > pre_paddings_dims =
1035- flatbufferDimsToVector (graph_node->pre_paddings ());
1036- std::vector<size_t > post_paddings_dims =
1037- flatbufferDimsToVector (graph_node->post_paddings ());
1064+ auto pre_result = flatbufferDimsToVector (graph_node->pre_paddings ());
1065+ if (!pre_result.ok ()) {
1066+ return pre_result.error ();
1067+ }
1068+ std::vector<size_t > pre_paddings_dims = std::move (pre_result.get ());
1069+ auto post_result = flatbufferDimsToVector (graph_node->post_paddings ());
1070+ if (!post_result.ok ()) {
1071+ return post_result.error ();
1072+ }
1073+ std::vector<size_t > post_paddings_dims = std::move (post_result.get ());
10381074
10391075 xnn_status status = xnn_define_static_constant_pad (
10401076 subgraph_ptr,
@@ -1111,8 +1147,12 @@ Error defineStaticReshapeNode(
11111147 auto graph_node = node->xnode_union_as_XNNStaticReshape ();
11121148
11131149 // Get tensor dims, we need to convert the uint32_t* to size_t*
1114- std::vector<size_t > dims_data =
1115- flatbufferDimsToVector (graph_node->new_shape ());
1150+ auto dims_result = flatbufferDimsToVector (graph_node->new_shape ());
1151+ if (!dims_result.ok ()) {
1152+ return dims_result.error ();
1153+ }
1154+ std::vector<size_t > dims_data = std::move (dims_result.get ());
1155+
11161156 xnn_status status = xnn_define_static_reshape (
11171157 subgraph_ptr,
11181158 graph_node->num_dims (),
@@ -1406,8 +1446,16 @@ Error defineStaticSliceNode(
14061446
14071447 auto graph_node = node->xnode_union_as_XNNStaticSlice ();
14081448
1409- std::vector<size_t > offsets = flatbufferDimsToVector (graph_node->offsets ());
1410- std::vector<size_t > sizes = flatbufferDimsToVector (graph_node->sizes ());
1449+ auto offsets_result = flatbufferDimsToVector (graph_node->offsets ());
1450+ if (!offsets_result.ok ()) {
1451+ return offsets_result.error ();
1452+ }
1453+ std::vector<size_t > offsets = std::move (offsets_result.get ());
1454+ auto sizes_result = flatbufferDimsToVector (graph_node->sizes ());
1455+ if (!sizes_result.ok ()) {
1456+ return sizes_result.error ();
1457+ }
1458+ std::vector<size_t > sizes = std::move (sizes_result.get ());
14111459
14121460 xnn_status status = xnn_define_static_slice (
14131461 subgraph_ptr,
0 commit comments