Skip to content

Commit ca37d7d

Browse files
author
Github Executorch
committed
Fix TOB-EXECUTORCH-41: validate buffer_idx bounds in getConstantDataPtr
Add bounds checking on buffer_idx in both constant_buffer and constant_data code paths of getConstantDataPtr to prevent out-of-bounds vector access from malicious flatbuffer inputs. Authored-with: Claude
1 parent 21d9c64 commit ca37d7d

1 file changed

Lines changed: 52 additions & 14 deletions

File tree

backends/xnnpack/runtime/XNNCompiler.cpp

Lines changed: 52 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ Obtaining the constant data pointer can either be from within the flatbuffer
173173
payload (deprecated) or via offsets to the constant_data_ptr. If no constant
174174
data associated with the tensor value, then returns nullptr.
175175
*/
176-
const uint8_t* getConstantDataPtr(
176+
Result<const uint8_t*> getConstantDataPtr(
177177
uint32_t buffer_idx,
178178
GraphPtr flatbuffer_graph,
179179
const uint8_t* constant_data_ptr,
@@ -184,13 +184,38 @@ const uint8_t* getConstantDataPtr(
184184
if (!constant_data_ptr) {
185185
// TODO(T172265611): Remove constant_buffer in flatbuffer path after BC
186186
// window
187-
const auto& constant_buffer = *flatbuffer_graph->constant_buffer();
188-
return constant_buffer[buffer_idx]->storage()->data();
187+
auto* cb = flatbuffer_graph->constant_buffer();
188+
if (cb == nullptr || buffer_idx >= cb->size()) {
189+
ET_LOG(
190+
Error,
191+
"Invalid buffer_idx %u for constant_buffer of size %u",
192+
buffer_idx,
193+
cb ? cb->size() : 0);
194+
return Error::InvalidProgram;
195+
}
196+
auto* buffer_entry = (*cb)[buffer_idx];
197+
if (buffer_entry == nullptr || buffer_entry->storage() == nullptr) {
198+
ET_LOG(
199+
Error, "Null constant_buffer entry at buffer_idx %u", buffer_idx);
200+
return Error::InvalidProgram;
201+
}
202+
return buffer_entry->storage()->data();
189203
} else {
190-
ConstantDataOffsetPtr constant_data_offset =
191-
flatbuffer_graph->constant_data()->Get(buffer_idx);
204+
auto* cd = flatbuffer_graph->constant_data();
205+
if (cd == nullptr || buffer_idx >= cd->size()) {
206+
ET_LOG(
207+
Error,
208+
"Invalid buffer_idx %u for constant_data of size %u",
209+
buffer_idx,
210+
cd ? cd->size() : 0);
211+
return Error::InvalidProgram;
212+
}
213+
ConstantDataOffsetPtr constant_data_offset = cd->Get(buffer_idx);
214+
if (constant_data_offset == nullptr) {
215+
ET_LOG(Error, "Null constant_data entry at buffer_idx %u", buffer_idx);
216+
return Error::InvalidProgram;
217+
}
192218
uint64_t offset = constant_data_offset->offset();
193-
194219
bool has_named_key = flatbuffers::IsFieldPresent(
195220
constant_data_offset, fb_xnnpack::ConstantDataOffset::VT_NAMED_KEY);
196221
// If there is no tensor name
@@ -203,7 +228,7 @@ const uint8_t* getConstantDataPtr(
203228
weights_cache->load_unpacked_data(data_name);
204229
if (!data_ptr.ok()) {
205230
ET_LOG(Error, "Failed to load weights from cache");
206-
return nullptr;
231+
return Error::InvalidProgram;
207232
}
208233
return data_ptr.get();
209234
#else
@@ -215,7 +240,7 @@ const uint8_t* getConstantDataPtr(
215240
"Failed to get constant data for key %s from named_data_map. Error code: %u",
216241
data_name.c_str(),
217242
static_cast<uint32_t>(buffer.error()));
218-
return nullptr;
243+
return Error::InvalidProgram;
219244
}
220245
const uint8_t* data_ptr =
221246
static_cast<const uint8_t*>(buffer.get().data());
@@ -229,7 +254,7 @@ const uint8_t* getConstantDataPtr(
229254
return nullptr;
230255
}
231256

232-
const uint8_t* getConstantDataPtr(
257+
Result<const uint8_t*> getConstantDataPtr(
233258
const fb_xnnpack::XNNTensorValue* tensor_value,
234259
GraphPtr flatbuffer_graph,
235260
const uint8_t* constant_data_ptr,
@@ -298,13 +323,17 @@ Error defineTensor(
298323

299324
// Get Pointer to constant data from flatbuffer, if its non-constant
300325
// it is a nullptr
301-
const uint8_t* buffer_ptr = getConstantDataPtr(
326+
auto buffer_result = getConstantDataPtr(
302327
tensor_value,
303328
flatbuffer_graph,
304329
constant_data_ptr,
305330
named_data_map,
306331
freeable_buffers,
307332
weights_cache);
333+
if (!buffer_result.ok()) {
334+
return buffer_result.error();
335+
}
336+
const uint8_t* buffer_ptr = buffer_result.get();
308337

309338
xnn_status status;
310339
// The type we might have to convert to
@@ -449,13 +478,17 @@ Error defineTensor(
449478
const float* scale = qparams->scale()->data();
450479

451480
if (qparams->scale_buffer_idx() != 0) {
452-
scale = reinterpret_cast<const float*>(getConstantDataPtr(
481+
auto scale_result = getConstantDataPtr(
453482
qparams->scale_buffer_idx(),
454483
flatbuffer_graph,
455484
constant_data_ptr,
456485
named_data_map,
457486
freeable_buffers,
458-
weights_cache));
487+
weights_cache);
488+
if (!scale_result.ok()) {
489+
return scale_result.error();
490+
}
491+
scale = reinterpret_cast<const float*>(scale_result.get());
459492
ET_CHECK_OR_RETURN_ERROR(
460493
scale != nullptr, Internal, "Failed to load scale data.");
461494
}
@@ -491,13 +524,18 @@ Error defineTensor(
491524
// Block scales are preferably serialized as bf16 but can also be
492525
// serialized as fp32 for backwards compatability.
493526
if (qparams->scale_buffer_idx() != 0) {
494-
scale_data = reinterpret_cast<const uint16_t*>(getConstantDataPtr(
527+
auto scale_data_result = getConstantDataPtr(
495528
qparams->scale_buffer_idx(),
496529
flatbuffer_graph,
497530
constant_data_ptr,
498531
named_data_map,
499532
freeable_buffers,
500-
weights_cache));
533+
weights_cache);
534+
if (!scale_data_result.ok()) {
535+
return scale_data_result.error();
536+
}
537+
scale_data =
538+
reinterpret_cast<const uint16_t*>(scale_data_result.get());
501539
ET_CHECK_OR_RETURN_ERROR(
502540
scale_data != nullptr, Internal, "Failed to load scale data.");
503541
scale_numel = qparams->num_scales();

0 commit comments

Comments
 (0)