Skip to content

Commit d7cf003

Browse files
author
Github Executorch
committed
Fix TOB-EXECUTORCH-39, -42: validate tensor dimensions in XNNPACK compiler
Validate that dims array is non-null and num_dims matches the actual array size in defineTensor to prevent heap buffer overflows. Change flatbufferDimsToVector to return Result<> with null-check and per-dimension validation against a 16M limit to prevent unbounded memory allocation from malicious dimension values. Authored-with: Claude
1 parent 5e8a0df commit d7cf003

1 file changed

Lines changed: 63 additions & 15 deletions

File tree

backends/xnnpack/runtime/XNNCompiler.cpp

Lines changed: 63 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -154,14 +154,28 @@ bool isQuantizedDataType(const xnn_datatype data_type) {
154154
Converts dims from uint32 to size_t. Takes in a flatbuffer vector
155155
of uint32_t and returns a std::vector of size_t. XNNPACK takes in
156156
dims of size_t* but tensor shape is serialized in flatbuffer as
157-
int32_t. As a result, we need to static cast the shapes to size_t
157+
uint32_t. As a result, we need to static cast the shapes to size_t.
158+
Individual dimension values are validated to prevent unbounded memory
159+
allocation from malicious inputs.
158160
*/
161+
constexpr uint32_t kMaxDimensionValue = 1 << 24; // 16M per dimension
162+
159163
template <typename T = size_t>
160-
std::vector<T> flatbufferDimsToVector(
164+
Result<std::vector<T>> flatbufferDimsToVector(
161165
const flatbuffers::Vector<uint32_t>* fb_dims) {
166+
ET_CHECK_OR_RETURN_ERROR(
167+
fb_dims != nullptr,
168+
InvalidProgram,
169+
"flatbufferDimsToVector: dims vector is null");
162170
std::vector<T> dims_data;
163171
dims_data.reserve(fb_dims->size());
164172
for (auto fb_dim : *fb_dims) {
173+
ET_CHECK_OR_RETURN_ERROR(
174+
fb_dim <= kMaxDimensionValue,
175+
InvalidProgram,
176+
"Dimension value %u exceeds maximum allowed %u",
177+
fb_dim,
178+
kMaxDimensionValue);
165179
dims_data.push_back(static_cast<T>(fb_dim));
166180
}
167181
return dims_data;
@@ -285,13 +299,24 @@ Error defineTensor(
285299
}
286300

287301
ET_CHECK_OR_RETURN_ERROR(
288-
tensor_value != nullptr,
289-
Internal,
290-
"Deserialized Tensor is Null, this should never happen");
302+
tensor_value != nullptr && tensor_value->dims() != nullptr,
303+
InvalidProgram,
304+
"Deserialized tensor is null, or tensor dims is null");
305+
306+
ET_CHECK_OR_RETURN_ERROR(
307+
tensor_value->num_dims() == tensor_value->dims()->size(),
308+
InvalidProgram,
309+
"Tensor num_dims %u does not match dims array size %u",
310+
tensor_value->num_dims(),
311+
tensor_value->dims()->size());
291312

292313
// Get tensor dims, here we need to use a vector in order
293314
// to properly convert the uint32_t* to size_t*
294-
std::vector<size_t> dims_data = flatbufferDimsToVector(tensor_value->dims());
315+
auto dims_result = flatbufferDimsToVector(tensor_value->dims());
316+
if (!dims_result.ok()) {
317+
return dims_result.error();
318+
}
319+
std::vector<size_t> dims_data = std::move(dims_result.get());
295320

296321
// XNNPACK Id
297322
uint32_t id = XNN_INVALID_VALUE_ID;
@@ -966,7 +991,12 @@ Error defineStaticTransposeNode(
966991
auto graph_node = node->xnode_union_as_XNNStaticTranspose();
967992

968993
// Get tensor dims, we need to convert the uint32_t* to size_t*
969-
std::vector<size_t> dims_data = flatbufferDimsToVector(graph_node->perm());
994+
auto dims_result = flatbufferDimsToVector(graph_node->perm());
995+
if (!dims_result.ok()) {
996+
return dims_result.error();
997+
}
998+
std::vector<size_t> dims_data = std::move(dims_result.get());
999+
9701000
xnn_status status = xnn_define_static_transpose(
9711001
subgraph_ptr,
9721002
graph_node->num_dims(),
@@ -1031,10 +1061,16 @@ Error defineStaticConstantPadNode(
10311061
const fb_xnnpack::XNNStaticConstantPad* graph_node =
10321062
node->xnode_union_as_XNNStaticConstantPad();
10331063

1034-
std::vector<size_t> pre_paddings_dims =
1035-
flatbufferDimsToVector(graph_node->pre_paddings());
1036-
std::vector<size_t> post_paddings_dims =
1037-
flatbufferDimsToVector(graph_node->post_paddings());
1064+
auto pre_result = flatbufferDimsToVector(graph_node->pre_paddings());
1065+
if (!pre_result.ok()) {
1066+
return pre_result.error();
1067+
}
1068+
std::vector<size_t> pre_paddings_dims = std::move(pre_result.get());
1069+
auto post_result = flatbufferDimsToVector(graph_node->post_paddings());
1070+
if (!post_result.ok()) {
1071+
return post_result.error();
1072+
}
1073+
std::vector<size_t> post_paddings_dims = std::move(post_result.get());
10381074

10391075
xnn_status status = xnn_define_static_constant_pad(
10401076
subgraph_ptr,
@@ -1111,8 +1147,12 @@ Error defineStaticReshapeNode(
11111147
auto graph_node = node->xnode_union_as_XNNStaticReshape();
11121148

11131149
// Get tensor dims, we need to convert the uint32_t* to size_t*
1114-
std::vector<size_t> dims_data =
1115-
flatbufferDimsToVector(graph_node->new_shape());
1150+
auto dims_result = flatbufferDimsToVector(graph_node->new_shape());
1151+
if (!dims_result.ok()) {
1152+
return dims_result.error();
1153+
}
1154+
std::vector<size_t> dims_data = std::move(dims_result.get());
1155+
11161156
xnn_status status = xnn_define_static_reshape(
11171157
subgraph_ptr,
11181158
graph_node->num_dims(),
@@ -1406,8 +1446,16 @@ Error defineStaticSliceNode(
14061446

14071447
auto graph_node = node->xnode_union_as_XNNStaticSlice();
14081448

1409-
std::vector<size_t> offsets = flatbufferDimsToVector(graph_node->offsets());
1410-
std::vector<size_t> sizes = flatbufferDimsToVector(graph_node->sizes());
1449+
auto offsets_result = flatbufferDimsToVector(graph_node->offsets());
1450+
if (!offsets_result.ok()) {
1451+
return offsets_result.error();
1452+
}
1453+
std::vector<size_t> offsets = std::move(offsets_result.get());
1454+
auto sizes_result = flatbufferDimsToVector(graph_node->sizes());
1455+
if (!sizes_result.ok()) {
1456+
return sizes_result.error();
1457+
}
1458+
std::vector<size_t> sizes = std::move(sizes_result.get());
14111459

14121460
xnn_status status = xnn_define_static_slice(
14131461
subgraph_ptr,

0 commit comments

Comments
 (0)