Skip to content

Commit 459b8f5

Browse files
author
Github Executorch
committed
Fix TOB-EXECUTORCH-39, -42: validate tensor dimensions in XNNPACK compiler
Validate that dims array is non-null and num_dims matches the actual array size in defineTensor to prevent heap buffer overflows. Change flatbufferDimsToVector to return Result<> with null-check and per-dimension validation against a 16M limit to prevent unbounded memory allocation from malicious dimension values. Authored-with: Claude
1 parent 5e8a0df commit 459b8f5

1 file changed

Lines changed: 64 additions & 15 deletions

File tree

backends/xnnpack/runtime/XNNCompiler.cpp

Lines changed: 64 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -150,18 +150,33 @@ bool isQuantizedDataType(const xnn_datatype data_type) {
150150
}
151151
}
152152

153+
// Max dimension value allowed for a tensor. This is also used for validating
154+
// permutation values and padding values.
155+
constexpr uint32_t kMaxDimensionValue = 1u << 24; // 16M per dimension
153156
/**
154157
Converts dims from uint32 to size_t. Takes in a flatbuffer vector
155158
of uint32_t and returns a std::vector of size_t. XNNPACK takes in
156159
dims of size_t* but tensor shape is serialized in flatbuffer as
157-
int32_t. As a result, we need to static cast the shapes to size_t
160+
uint32_t. As a result, we need to static cast the shapes to size_t.
161+
Individual dimension values are validated to prevent unbounded memory
162+
allocation from malicious inputs.
158163
*/
159164
template <typename T = size_t>
160-
std::vector<T> flatbufferDimsToVector(
165+
Result<std::vector<T>> flatbufferDimsToVector(
161166
const flatbuffers::Vector<uint32_t>* fb_dims) {
167+
ET_CHECK_OR_RETURN_ERROR(
168+
fb_dims != nullptr,
169+
InvalidProgram,
170+
"flatbufferDimsToVector: dims vector is null");
162171
std::vector<T> dims_data;
163172
dims_data.reserve(fb_dims->size());
164173
for (auto fb_dim : *fb_dims) {
174+
ET_CHECK_OR_RETURN_ERROR(
175+
fb_dim <= kMaxDimensionValue,
176+
InvalidProgram,
177+
"Dimension value %u exceeds maximum allowed %u",
178+
fb_dim,
179+
kMaxDimensionValue);
165180
dims_data.push_back(static_cast<T>(fb_dim));
166181
}
167182
return dims_data;
@@ -285,13 +300,24 @@ Error defineTensor(
285300
}
286301

287302
ET_CHECK_OR_RETURN_ERROR(
288-
tensor_value != nullptr,
289-
Internal,
290-
"Deserialized Tensor is Null, this should never happen");
303+
tensor_value != nullptr && tensor_value->dims() != nullptr,
304+
InvalidProgram,
305+
"Deserialized tensor is null, or tensor dims is null");
306+
307+
ET_CHECK_OR_RETURN_ERROR(
308+
tensor_value->num_dims() == tensor_value->dims()->size(),
309+
InvalidProgram,
310+
"Tensor num_dims %u does not match dims array size %u",
311+
tensor_value->num_dims(),
312+
tensor_value->dims()->size());
291313

292314
// Get tensor dims, here we need to use a vector in order
293315
// to properly convert the uint32_t* to size_t*
294-
std::vector<size_t> dims_data = flatbufferDimsToVector(tensor_value->dims());
316+
auto dims_result = flatbufferDimsToVector(tensor_value->dims());
317+
if (!dims_result.ok()) {
318+
return dims_result.error();
319+
}
320+
std::vector<size_t> dims_data = std::move(dims_result.get());
295321

296322
// XNNPACK Id
297323
uint32_t id = XNN_INVALID_VALUE_ID;
@@ -966,7 +992,12 @@ Error defineStaticTransposeNode(
966992
auto graph_node = node->xnode_union_as_XNNStaticTranspose();
967993

968994
// Get tensor dims, we need to convert the uint32_t* to size_t*
969-
std::vector<size_t> dims_data = flatbufferDimsToVector(graph_node->perm());
995+
auto dims_result = flatbufferDimsToVector(graph_node->perm());
996+
if (!dims_result.ok()) {
997+
return dims_result.error();
998+
}
999+
std::vector<size_t> dims_data = std::move(dims_result.get());
1000+
9701001
xnn_status status = xnn_define_static_transpose(
9711002
subgraph_ptr,
9721003
graph_node->num_dims(),
@@ -1031,10 +1062,16 @@ Error defineStaticConstantPadNode(
10311062
const fb_xnnpack::XNNStaticConstantPad* graph_node =
10321063
node->xnode_union_as_XNNStaticConstantPad();
10331064

1034-
std::vector<size_t> pre_paddings_dims =
1035-
flatbufferDimsToVector(graph_node->pre_paddings());
1036-
std::vector<size_t> post_paddings_dims =
1037-
flatbufferDimsToVector(graph_node->post_paddings());
1065+
auto pre_result = flatbufferDimsToVector(graph_node->pre_paddings());
1066+
if (!pre_result.ok()) {
1067+
return pre_result.error();
1068+
}
1069+
std::vector<size_t> pre_paddings_dims = std::move(pre_result.get());
1070+
auto post_result = flatbufferDimsToVector(graph_node->post_paddings());
1071+
if (!post_result.ok()) {
1072+
return post_result.error();
1073+
}
1074+
std::vector<size_t> post_paddings_dims = std::move(post_result.get());
10381075

10391076
xnn_status status = xnn_define_static_constant_pad(
10401077
subgraph_ptr,
@@ -1111,8 +1148,12 @@ Error defineStaticReshapeNode(
11111148
auto graph_node = node->xnode_union_as_XNNStaticReshape();
11121149

11131150
// Get tensor dims, we need to convert the uint32_t* to size_t*
1114-
std::vector<size_t> dims_data =
1115-
flatbufferDimsToVector(graph_node->new_shape());
1151+
auto dims_result = flatbufferDimsToVector(graph_node->new_shape());
1152+
if (!dims_result.ok()) {
1153+
return dims_result.error();
1154+
}
1155+
std::vector<size_t> dims_data = std::move(dims_result.get());
1156+
11161157
xnn_status status = xnn_define_static_reshape(
11171158
subgraph_ptr,
11181159
graph_node->num_dims(),
@@ -1406,8 +1447,16 @@ Error defineStaticSliceNode(
14061447

14071448
auto graph_node = node->xnode_union_as_XNNStaticSlice();
14081449

1409-
std::vector<size_t> offsets = flatbufferDimsToVector(graph_node->offsets());
1410-
std::vector<size_t> sizes = flatbufferDimsToVector(graph_node->sizes());
1450+
auto offsets_result = flatbufferDimsToVector(graph_node->offsets());
1451+
if (!offsets_result.ok()) {
1452+
return offsets_result.error();
1453+
}
1454+
std::vector<size_t> offsets = std::move(offsets_result.get());
1455+
auto sizes_result = flatbufferDimsToVector(graph_node->sizes());
1456+
if (!sizes_result.ok()) {
1457+
return sizes_result.error();
1458+
}
1459+
std::vector<size_t> sizes = std::move(sizes_result.get());
14111460

14121461
xnn_status status = xnn_define_static_slice(
14131462
subgraph_ptr,

0 commit comments

Comments
 (0)