diff --git a/backends/xnnpack/runtime/XNNCompiler.cpp b/backends/xnnpack/runtime/XNNCompiler.cpp index 352d7af5a14..103bdeb6b82 100644 --- a/backends/xnnpack/runtime/XNNCompiler.cpp +++ b/backends/xnnpack/runtime/XNNCompiler.cpp @@ -319,6 +319,15 @@ Error defineTensor( ET_CHECK_OR_RETURN_ERROR( tensor_value != nullptr, InvalidProgram, "Deserialized tensor is null"); + // Validate that tensor_value->flags() is a subset of the allowed flags. + constexpr uint32_t kAllowedFlagsMask = + XNN_VALUE_FLAG_EXTERNAL_INPUT | XNN_VALUE_FLAG_EXTERNAL_OUTPUT; + ET_CHECK_OR_RETURN_ERROR( + (tensor_value->flags() & ~kAllowedFlagsMask) == 0, + InvalidProgram, + "Tensor value has unsupported flag bits 0x%x", + tensor_value->flags()); + // Get tensor dims, here we need to use a vector in order to properly // convert the uint32_t* to size_t*. Scalar tensors (rank 0) are permitted // to have a null dims vector; in that case dims_data is empty.