Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 49 additions & 8 deletions backends/xnnpack/runtime/XNNCompiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -317,13 +317,30 @@ Error defineTensor(
}

ET_CHECK_OR_RETURN_ERROR(
tensor_value != nullptr,
Internal,
"Deserialized Tensor is Null, this should never happen");
tensor_value != nullptr, InvalidProgram, "Deserialized tensor is null");

Comment thread
lucylq marked this conversation as resolved.
ET_CHECK_OR_RETURN_ERROR(
tensor_value->num_dims() == 0 || tensor_value->dims() != nullptr,
InvalidProgram,
"Tensor dims is null but num_dims is %u",
tensor_value->num_dims());

if (tensor_value->dims() != nullptr) {
ET_CHECK_OR_RETURN_ERROR(
tensor_value->num_dims() == tensor_value->dims()->size(),
InvalidProgram,
"Tensor num_dims %u does not match dims array size %u",
tensor_value->num_dims(),
tensor_value->dims()->size());
}
Comment on lines +322 to +335

// Get tensor dims, here we need to use a vector in order
// to properly convert the uint32_t* to size_t*
std::vector<size_t> dims_data = flatbufferDimsToVector(tensor_value->dims());
// to properly convert the uint32_t* to size_t*. For scalar tensors
// (num_dims == 0), dims() is permitted to be null per the check above.
std::vector<size_t> dims_data;
if (tensor_value->dims() != nullptr) {
dims_data = flatbufferDimsToVector(tensor_value->dims());
}

// XNNPACK Id
uint32_t id = XNN_INVALID_VALUE_ID;
Expand Down Expand Up @@ -1052,14 +1069,18 @@ Error defineStaticTransposeNode(
auto graph_node = node->xnode_union_as_XNNStaticTranspose();

// Get tensor dims, we need to convert the uint32_t* to size_t*
ET_CHECK_OR_RETURN_ERROR(
graph_node->perm() != nullptr,
InvalidProgram,
"StaticTranspose: perm is null");
std::vector<size_t> dims_data = flatbufferDimsToVector(graph_node->perm());

REMAP_ID(remapped_ids, graph_node->input_id(), st_input);
REMAP_ID(remapped_ids, graph_node->output_id(), st_output);

xnn_status status = xnn_define_static_transpose(
subgraph_ptr,
graph_node->num_dims(),
dims_data.size(),
dims_data.data(),
Comment thread
lucylq marked this conversation as resolved.
st_input,
st_output,
Expand Down Expand Up @@ -1123,6 +1144,11 @@ Error defineStaticConstantPadNode(
const fb_xnnpack::XNNStaticConstantPad* graph_node =
node->xnode_union_as_XNNStaticConstantPad();

ET_CHECK_OR_RETURN_ERROR(
graph_node->pre_paddings() != nullptr &&
graph_node->post_paddings() != nullptr,
InvalidProgram,
"StaticConstantPad: pre_paddings or post_paddings is null");
std::vector<size_t> pre_paddings_dims =
flatbufferDimsToVector(graph_node->pre_paddings());
std::vector<size_t> post_paddings_dims =
Expand Down Expand Up @@ -1211,6 +1237,10 @@ Error defineStaticReshapeNode(
auto graph_node = node->xnode_union_as_XNNStaticReshape();

// Get tensor dims, we need to convert the uint32_t* to size_t*
ET_CHECK_OR_RETURN_ERROR(
graph_node->new_shape() != nullptr,
InvalidProgram,
"StaticReshape: new_shape is null");
std::vector<size_t> dims_data =
flatbufferDimsToVector(graph_node->new_shape());

Expand All @@ -1219,7 +1249,7 @@ Error defineStaticReshapeNode(

xnn_status status = xnn_define_static_reshape(
subgraph_ptr,
graph_node->num_dims(),
dims_data.size(),
dims_data.data(),
Comment thread
lucylq marked this conversation as resolved.
sr_input,
sr_output,
Expand Down Expand Up @@ -1532,15 +1562,26 @@ Error defineStaticSliceNode(

auto graph_node = node->xnode_union_as_XNNStaticSlice();

ET_CHECK_OR_RETURN_ERROR(
graph_node->offsets() != nullptr && graph_node->sizes() != nullptr,
InvalidProgram,
"StaticSlice: offsets or sizes is null");
std::vector<size_t> offsets = flatbufferDimsToVector(graph_node->offsets());
std::vector<size_t> sizes = flatbufferDimsToVector(graph_node->sizes());

ET_CHECK_OR_RETURN_ERROR(
offsets.size() == sizes.size(),
InvalidProgram,
"StaticSlice: offsets size %zu does not match sizes size %zu",
offsets.size(),
sizes.size());

REMAP_ID(remapped_ids, graph_node->input_id(), ss_input);
REMAP_ID(remapped_ids, graph_node->output_id(), ss_output);

xnn_status status = xnn_define_static_slice(
subgraph_ptr,
graph_node->num_dims(),
offsets.size(),
offsets.data(),
Comment thread
lucylq marked this conversation as resolved.
sizes.data(),
ss_input,
Expand Down
Loading