Skip to content

Commit 2d53535

Browse files
lucylqGithub Executorch
andauthored
Xnnpack flatbuffer: check non-null and num_dims matches array size in defineTensor (pytorch#18896)
Validate that dims array is non-null and num_dims matches the actual array size in defineTensor to prevent heap buffer overflows. Authored-with: Claude Co-authored-by: Github Executorch <github_executorch@arm.com>
1 parent deaf73b commit 2d53535

1 file changed

Lines changed: 49 additions & 8 deletions

File tree

backends/xnnpack/runtime/XNNCompiler.cpp

Lines changed: 49 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -317,13 +317,30 @@ Error defineTensor(
317317
}
318318

319319
ET_CHECK_OR_RETURN_ERROR(
320-
tensor_value != nullptr,
321-
Internal,
322-
"Deserialized Tensor is Null, this should never happen");
320+
tensor_value != nullptr, InvalidProgram, "Deserialized tensor is null");
321+
322+
ET_CHECK_OR_RETURN_ERROR(
323+
tensor_value->num_dims() == 0 || tensor_value->dims() != nullptr,
324+
InvalidProgram,
325+
"Tensor dims is null but num_dims is %u",
326+
tensor_value->num_dims());
327+
328+
if (tensor_value->dims() != nullptr) {
329+
ET_CHECK_OR_RETURN_ERROR(
330+
tensor_value->num_dims() == tensor_value->dims()->size(),
331+
InvalidProgram,
332+
"Tensor num_dims %u does not match dims array size %u",
333+
tensor_value->num_dims(),
334+
tensor_value->dims()->size());
335+
}
323336

324337
// Get tensor dims, here we need to use a vector in order
325-
// to properly convert the uint32_t* to size_t*
326-
std::vector<size_t> dims_data = flatbufferDimsToVector(tensor_value->dims());
338+
// to properly convert the uint32_t* to size_t*. For scalar tensors
339+
// (num_dims == 0), dims() is permitted to be null per the check above.
340+
std::vector<size_t> dims_data;
341+
if (tensor_value->dims() != nullptr) {
342+
dims_data = flatbufferDimsToVector(tensor_value->dims());
343+
}
327344

328345
// XNNPACK Id
329346
uint32_t id = XNN_INVALID_VALUE_ID;
@@ -1052,14 +1069,18 @@ Error defineStaticTransposeNode(
10521069
auto graph_node = node->xnode_union_as_XNNStaticTranspose();
10531070

10541071
// Get tensor dims, we need to convert the uint32_t* to size_t*
1072+
ET_CHECK_OR_RETURN_ERROR(
1073+
graph_node->perm() != nullptr,
1074+
InvalidProgram,
1075+
"StaticTranspose: perm is null");
10551076
std::vector<size_t> dims_data = flatbufferDimsToVector(graph_node->perm());
10561077

10571078
REMAP_ID(remapped_ids, graph_node->input_id(), st_input);
10581079
REMAP_ID(remapped_ids, graph_node->output_id(), st_output);
10591080

10601081
xnn_status status = xnn_define_static_transpose(
10611082
subgraph_ptr,
1062-
graph_node->num_dims(),
1083+
dims_data.size(),
10631084
dims_data.data(),
10641085
st_input,
10651086
st_output,
@@ -1123,6 +1144,11 @@ Error defineStaticConstantPadNode(
11231144
const fb_xnnpack::XNNStaticConstantPad* graph_node =
11241145
node->xnode_union_as_XNNStaticConstantPad();
11251146

1147+
ET_CHECK_OR_RETURN_ERROR(
1148+
graph_node->pre_paddings() != nullptr &&
1149+
graph_node->post_paddings() != nullptr,
1150+
InvalidProgram,
1151+
"StaticConstantPad: pre_paddings or post_paddings is null");
11261152
std::vector<size_t> pre_paddings_dims =
11271153
flatbufferDimsToVector(graph_node->pre_paddings());
11281154
std::vector<size_t> post_paddings_dims =
@@ -1211,6 +1237,10 @@ Error defineStaticReshapeNode(
12111237
auto graph_node = node->xnode_union_as_XNNStaticReshape();
12121238

12131239
// Get tensor dims, we need to convert the uint32_t* to size_t*
1240+
ET_CHECK_OR_RETURN_ERROR(
1241+
graph_node->new_shape() != nullptr,
1242+
InvalidProgram,
1243+
"StaticReshape: new_shape is null");
12141244
std::vector<size_t> dims_data =
12151245
flatbufferDimsToVector(graph_node->new_shape());
12161246

@@ -1219,7 +1249,7 @@ Error defineStaticReshapeNode(
12191249

12201250
xnn_status status = xnn_define_static_reshape(
12211251
subgraph_ptr,
1222-
graph_node->num_dims(),
1252+
dims_data.size(),
12231253
dims_data.data(),
12241254
sr_input,
12251255
sr_output,
@@ -1532,15 +1562,26 @@ Error defineStaticSliceNode(
15321562

15331563
auto graph_node = node->xnode_union_as_XNNStaticSlice();
15341564

1565+
ET_CHECK_OR_RETURN_ERROR(
1566+
graph_node->offsets() != nullptr && graph_node->sizes() != nullptr,
1567+
InvalidProgram,
1568+
"StaticSlice: offsets or sizes is null");
15351569
std::vector<size_t> offsets = flatbufferDimsToVector(graph_node->offsets());
15361570
std::vector<size_t> sizes = flatbufferDimsToVector(graph_node->sizes());
15371571

1572+
ET_CHECK_OR_RETURN_ERROR(
1573+
offsets.size() == sizes.size(),
1574+
InvalidProgram,
1575+
"StaticSlice: offsets size %zu does not match sizes size %zu",
1576+
offsets.size(),
1577+
sizes.size());
1578+
15381579
REMAP_ID(remapped_ids, graph_node->input_id(), ss_input);
15391580
REMAP_ID(remapped_ids, graph_node->output_id(), ss_output);
15401581

15411582
xnn_status status = xnn_define_static_slice(
15421583
subgraph_ptr,
1543-
graph_node->num_dims(),
1584+
offsets.size(),
15441585
offsets.data(),
15451586
sizes.data(),
15461587
ss_input,

0 commit comments

Comments
 (0)