@@ -285,13 +285,30 @@ Error defineTensor(
285285 }
286286
287287 ET_CHECK_OR_RETURN_ERROR (
288- tensor_value != nullptr ,
289- Internal,
290- " Deserialized Tensor is Null, this should never happen" );
288+ tensor_value != nullptr , InvalidProgram, " Deserialized tensor is null" );
289+
290+ ET_CHECK_OR_RETURN_ERROR (
291+ tensor_value->num_dims () == 0 || tensor_value->dims () != nullptr ,
292+ InvalidProgram,
293+ " Tensor dims is null but num_dims is %u" ,
294+ tensor_value->num_dims ());
295+
296+ if (tensor_value->dims () != nullptr ) {
297+ ET_CHECK_OR_RETURN_ERROR (
298+ tensor_value->num_dims () == tensor_value->dims ()->size (),
299+ InvalidProgram,
300+ " Tensor num_dims %u does not match dims array size %u" ,
301+ tensor_value->num_dims (),
302+ tensor_value->dims ()->size ());
303+ }
291304
292305 // Get tensor dims, here we need to use a vector in order
293- // to properly convert the uint32_t* to size_t*
294- std::vector<size_t > dims_data = flatbufferDimsToVector (tensor_value->dims ());
306+ // to properly convert the uint32_t* to size_t*. For scalar tensors
307+ // (num_dims == 0), dims() is permitted to be null per the check above.
308+ std::vector<size_t > dims_data;
309+ if (tensor_value->dims () != nullptr ) {
310+ dims_data = flatbufferDimsToVector (tensor_value->dims ());
311+ }
295312
296313 // XNNPACK Id
297314 uint32_t id = XNN_INVALID_VALUE_ID;
@@ -966,10 +983,14 @@ Error defineStaticTransposeNode(
966983 auto graph_node = node->xnode_union_as_XNNStaticTranspose ();
967984
968985 // Get tensor dims, we need to convert the uint32_t* to size_t*
986+ ET_CHECK_OR_RETURN_ERROR (
987+ graph_node->perm () != nullptr ,
988+ InvalidProgram,
989+ " StaticTranspose: perm is null" );
969990 std::vector<size_t > dims_data = flatbufferDimsToVector (graph_node->perm ());
970991 xnn_status status = xnn_define_static_transpose (
971992 subgraph_ptr,
972- graph_node-> num_dims (),
993+ dims_data. size (),
973994 dims_data.data (),
974995 remapped_ids.at (graph_node->input_id ()),
975996 remapped_ids.at (graph_node->output_id ()),
@@ -1031,6 +1052,11 @@ Error defineStaticConstantPadNode(
10311052 const fb_xnnpack::XNNStaticConstantPad* graph_node =
10321053 node->xnode_union_as_XNNStaticConstantPad ();
10331054
1055+ ET_CHECK_OR_RETURN_ERROR (
1056+ graph_node->pre_paddings () != nullptr &&
1057+ graph_node->post_paddings () != nullptr ,
1058+ InvalidProgram,
1059+ " StaticConstantPad: pre_paddings or post_paddings is null" );
10341060 std::vector<size_t > pre_paddings_dims =
10351061 flatbufferDimsToVector (graph_node->pre_paddings ());
10361062 std::vector<size_t > post_paddings_dims =
@@ -1111,11 +1137,15 @@ Error defineStaticReshapeNode(
11111137 auto graph_node = node->xnode_union_as_XNNStaticReshape ();
11121138
11131139 // Get tensor dims, we need to convert the uint32_t* to size_t*
1140+ ET_CHECK_OR_RETURN_ERROR (
1141+ graph_node->new_shape () != nullptr ,
1142+ InvalidProgram,
1143+ " StaticReshape: new_shape is null" );
11141144 std::vector<size_t > dims_data =
11151145 flatbufferDimsToVector (graph_node->new_shape ());
11161146 xnn_status status = xnn_define_static_reshape (
11171147 subgraph_ptr,
1118- graph_node-> num_dims (),
1148+ dims_data. size (),
11191149 dims_data.data (),
11201150 remapped_ids.at (graph_node->input_id ()),
11211151 remapped_ids.at (graph_node->output_id ()),
@@ -1406,12 +1436,23 @@ Error defineStaticSliceNode(
14061436
14071437 auto graph_node = node->xnode_union_as_XNNStaticSlice ();
14081438
1439+ ET_CHECK_OR_RETURN_ERROR (
1440+ graph_node->offsets () != nullptr && graph_node->sizes () != nullptr ,
1441+ InvalidProgram,
1442+ " StaticSlice: offsets or sizes is null" );
14091443 std::vector<size_t > offsets = flatbufferDimsToVector (graph_node->offsets ());
14101444 std::vector<size_t > sizes = flatbufferDimsToVector (graph_node->sizes ());
14111445
1446+ ET_CHECK_OR_RETURN_ERROR (
1447+ offsets.size () == sizes.size (),
1448+ InvalidProgram,
1449+ " StaticSlice: offsets size %zu does not match sizes size %zu" ,
1450+ offsets.size (),
1451+ sizes.size ());
1452+
14121453 xnn_status status = xnn_define_static_slice (
14131454 subgraph_ptr,
1414- graph_node-> num_dims (),
1455+ offsets. size (),
14151456 offsets.data (),
14161457 sizes.data (),
14171458 remapped_ids.at (graph_node->input_id ()),
0 commit comments