Skip to content

Commit dd8fe9d

Browse files
author
Github Executorch
committed
Fix TOB-EXECUTORCH-39, -42: validate tensor dimensions in XNNPACK compiler
Validate that dims array is non-null and num_dims matches the actual array size in defineTensor to prevent heap buffer overflows. Change flatbufferDimsToVector to return Result<> with null-check and per-dimension validation against a 16M limit to prevent unbounded memory allocation from malicious dimension values. Authored-with: Claude
1 parent 5e8a0df commit dd8fe9d

1 file changed

Lines changed: 67 additions & 12 deletions

File tree

backends/xnnpack/runtime/XNNCompiler.cpp

Lines changed: 67 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -154,14 +154,30 @@ bool isQuantizedDataType(const xnn_datatype data_type) {
154154
Converts dims from uint32 to size_t. Takes in a flatbuffer vector
155155
of uint32_t and returns a std::vector of size_t. XNNPACK takes in
156156
dims of size_t* but tensor shape is serialized in flatbuffer as
157-
int32_t. As a result, we need to static cast the shapes to size_t
157+
int32_t. As a result, we need to static cast the shapes to size_t.
158+
Individual dimension values are validated to prevent unbounded memory
159+
allocation from malicious inputs.
158160
*/
161+
constexpr uint32_t kMaxDimensionValue = 1 << 24; // 16M per dimension
162+
159163
template <typename T = size_t>
160-
std::vector<T> flatbufferDimsToVector(
164+
Result<std::vector<T>> flatbufferDimsToVector(
161165
const flatbuffers::Vector<uint32_t>* fb_dims) {
166+
if (fb_dims == nullptr) {
167+
ET_LOG(Error, "flatbufferDimsToVector: dims vector is null");
168+
return Error::InvalidProgram;
169+
}
162170
std::vector<T> dims_data;
163171
dims_data.reserve(fb_dims->size());
164172
for (auto fb_dim : *fb_dims) {
173+
if (fb_dim > kMaxDimensionValue) {
174+
ET_LOG(
175+
Error,
176+
"Dimension value %u exceeds maximum allowed %u",
177+
fb_dim,
178+
kMaxDimensionValue);
179+
return Error::InvalidProgram;
180+
}
165181
dims_data.push_back(static_cast<T>(fb_dim));
166182
}
167183
return dims_data;
@@ -289,9 +305,25 @@ Error defineTensor(
289305
Internal,
290306
"Deserialized Tensor is Null, this should never happen");
291307

308+
ET_CHECK_OR_RETURN_ERROR(
309+
tensor_value->dims() != nullptr,
310+
InvalidProgram,
311+
"Tensor value has null dims array");
312+
313+
ET_CHECK_OR_RETURN_ERROR(
314+
tensor_value->num_dims() == tensor_value->dims()->size(),
315+
InvalidProgram,
316+
"Tensor num_dims %u does not match dims array size %u",
317+
tensor_value->num_dims(),
318+
tensor_value->dims()->size());
319+
292320
// Get tensor dims, here we need to use a vector in order
293321
// to properly convert the uint32_t* to size_t*
294-
std::vector<size_t> dims_data = flatbufferDimsToVector(tensor_value->dims());
322+
auto dims_result = flatbufferDimsToVector(tensor_value->dims());
323+
if (!dims_result.ok()) {
324+
return dims_result.error();
325+
}
326+
std::vector<size_t> dims_data = std::move(dims_result.get());
295327

296328
// XNNPACK Id
297329
uint32_t id = XNN_INVALID_VALUE_ID;
@@ -966,7 +998,12 @@ Error defineStaticTransposeNode(
966998
auto graph_node = node->xnode_union_as_XNNStaticTranspose();
967999

9681000
// Get tensor dims, we need to convert the uint32_t* to size_t*
969-
std::vector<size_t> dims_data = flatbufferDimsToVector(graph_node->perm());
1001+
auto dims_result = flatbufferDimsToVector(graph_node->perm());
1002+
if (!dims_result.ok()) {
1003+
return dims_result.error();
1004+
}
1005+
std::vector<size_t> dims_data = std::move(dims_result.get());
1006+
9701007
xnn_status status = xnn_define_static_transpose(
9711008
subgraph_ptr,
9721009
graph_node->num_dims(),
@@ -1031,10 +1068,16 @@ Error defineStaticConstantPadNode(
10311068
const fb_xnnpack::XNNStaticConstantPad* graph_node =
10321069
node->xnode_union_as_XNNStaticConstantPad();
10331070

1034-
std::vector<size_t> pre_paddings_dims =
1035-
flatbufferDimsToVector(graph_node->pre_paddings());
1036-
std::vector<size_t> post_paddings_dims =
1037-
flatbufferDimsToVector(graph_node->post_paddings());
1071+
auto pre_result = flatbufferDimsToVector(graph_node->pre_paddings());
1072+
if (!pre_result.ok()) {
1073+
return pre_result.error();
1074+
}
1075+
std::vector<size_t> pre_paddings_dims = std::move(pre_result.get());
1076+
auto post_result = flatbufferDimsToVector(graph_node->post_paddings());
1077+
if (!post_result.ok()) {
1078+
return post_result.error();
1079+
}
1080+
std::vector<size_t> post_paddings_dims = std::move(post_result.get());
10381081

10391082
xnn_status status = xnn_define_static_constant_pad(
10401083
subgraph_ptr,
@@ -1111,8 +1154,12 @@ Error defineStaticReshapeNode(
11111154
auto graph_node = node->xnode_union_as_XNNStaticReshape();
11121155

11131156
// Get tensor dims, we need to convert the uint32_t* to size_t*
1114-
std::vector<size_t> dims_data =
1115-
flatbufferDimsToVector(graph_node->new_shape());
1157+
auto dims_result = flatbufferDimsToVector(graph_node->new_shape());
1158+
if (!dims_result.ok()) {
1159+
return dims_result.error();
1160+
}
1161+
std::vector<size_t> dims_data = std::move(dims_result.get());
1162+
11161163
xnn_status status = xnn_define_static_reshape(
11171164
subgraph_ptr,
11181165
graph_node->num_dims(),
@@ -1406,8 +1453,16 @@ Error defineStaticSliceNode(
14061453

14071454
auto graph_node = node->xnode_union_as_XNNStaticSlice();
14081455

1409-
std::vector<size_t> offsets = flatbufferDimsToVector(graph_node->offsets());
1410-
std::vector<size_t> sizes = flatbufferDimsToVector(graph_node->sizes());
1456+
auto offsets_result = flatbufferDimsToVector(graph_node->offsets());
1457+
if (!offsets_result.ok()) {
1458+
return offsets_result.error();
1459+
}
1460+
std::vector<size_t> offsets = std::move(offsets_result.get());
1461+
auto sizes_result = flatbufferDimsToVector(graph_node->sizes());
1462+
if (!sizes_result.ok()) {
1463+
return sizes_result.error();
1464+
}
1465+
std::vector<size_t> sizes = std::move(sizes_result.get());
14111466

14121467
xnn_status status = xnn_define_static_slice(
14131468
subgraph_ptr,

0 commit comments

Comments
 (0)