diff --git a/backends/xnnpack/runtime/XNNCompiler.cpp b/backends/xnnpack/runtime/XNNCompiler.cpp index 4881844ac6d..82353a6e45e 100644 --- a/backends/xnnpack/runtime/XNNCompiler.cpp +++ b/backends/xnnpack/runtime/XNNCompiler.cpp @@ -1816,16 +1816,19 @@ ET_NODISCARD Error XNNCompiler::compileModel( Result header = XNNHeader::Parse(buffer_pointer, num_bytes); const uint8_t* flatbuffer_data = nullptr; const uint8_t* constant_data = nullptr; + size_t flatbuffer_size = 0; CompileAllocator compile_allocator; // Header status can only either be Error::Ok or Error::NotFound if (header.ok()) { flatbuffer_data = reinterpret_cast(buffer_pointer) + header->flatbuffer_offset; + flatbuffer_size = header->flatbuffer_size; constant_data = reinterpret_cast(buffer_pointer) + header->constant_data_offset; } else if (header.error() == Error::NotFound) { flatbuffer_data = reinterpret_cast(buffer_pointer); + flatbuffer_size = num_bytes; } else { ET_LOG(Error, "XNNHeader may be corrupt"); return header.error(); @@ -1843,6 +1846,15 @@ ET_NODISCARD Error XNNCompiler::compileModel( "XNNPACK Delegate Serialization Format version identifier '%.4s' != expected XN00 or XN01'", flatbuffers::GetBufferIdentifier(flatbuffer_data)); + // Verify the FlatBuffer data integrity before accessing it. Without this, + // malformed data could cause out-of-bounds reads when traversing the + // FlatBuffer's internal offset tables. + flatbuffers::Verifier verifier(flatbuffer_data, flatbuffer_size); + ET_CHECK_OR_RETURN_ERROR( + verifier.VerifyBuffer(nullptr), + DelegateInvalidCompatibility, + "FlatBuffer verification failed; data may be truncated or corrupt"); + auto flatbuffer_graph = fb_xnnpack::GetXNNGraph(flatbuffer_data); // initialize xnnpack xnn_status status = xnn_initialize(/*allocator =*/nullptr); diff --git a/backends/xnnpack/runtime/XNNHeader.cpp b/backends/xnnpack/runtime/XNNHeader.cpp index 9397948c55d..59e74655565 100644 --- a/backends/xnnpack/runtime/XNNHeader.cpp +++ b/backends/xnnpack/runtime/XNNHeader.cpp @@ -8,6 +8,7 @@ #include +#include #include #include @@ -64,6 +65,48 @@ Result XNNHeader::Parse(const void* data, size_t size) { uint64_t constant_data_size = GetUInt64LE(header_data + XNNHeader::kConstantDataSizeOffset); + // Validate min flatbuffer size. + constexpr size_t kMinFlatbufferSize = + sizeof(uint32_t) + 4; // root offset + identifier + ET_CHECK_OR_RETURN_ERROR( + flatbuffer_size >= kMinFlatbufferSize, + InvalidArgument, + "flatbuffer_size %" PRIu32 " is too small (minimum %zu)", + flatbuffer_size, + kMinFlatbufferSize); + + // Validate that flatbuffer region does not overflow or exceed the buffer. + ET_CHECK_OR_RETURN_ERROR( + flatbuffer_offset <= size && flatbuffer_size <= size - flatbuffer_offset, + InvalidArgument, + "flatbuffer_offset: %" PRIu32 " and flatbuffer_size: %" PRIu32 + " are invalid for buffer of size: %zu", + flatbuffer_offset, + flatbuffer_size, + size); + // Validate that constant data region does not overflow or exceed the buffer. + ET_CHECK_OR_RETURN_ERROR( + constant_data_offset <= size && + constant_data_size <= size - constant_data_offset, + InvalidArgument, + "constant_data_offset: %" PRIu32 " and constant_data_size: %" PRIu64 + " are invalid for buffer of size: %zu", + constant_data_offset, + constant_data_size, + size); + + // Validate that constant data region does not overlap with flatbuffer region. + // flatbuffer should come before constant data. + ET_CHECK_OR_RETURN_ERROR( + constant_data_offset >= flatbuffer_offset && + constant_data_offset - flatbuffer_offset >= flatbuffer_size, + InvalidArgument, + "constant_data_offset: %" PRIu32 " and flatbuffer_offset: %" PRIu32 + " with flatbuffer_size: %" PRIu32 " are overlapping.", + constant_data_offset, + flatbuffer_offset, + flatbuffer_size); + return XNNHeader{ flatbuffer_offset, flatbuffer_size,