Skip to content

Commit e85653b

Browse files
author
Github Executorch
committed
Fix TOB-EXECUTORCH-41: validate buffer_idx bounds in getConstantDataPtr
Add bounds checking on buffer_idx in both constant_buffer and constant_data code paths of getConstantDataPtr to prevent out-of-bounds vector access from malicious flatbuffer inputs. Authored-with: Claude
1 parent 5e8a0df commit e85653b

1 file changed

Lines changed: 53 additions & 14 deletions

File tree

backends/xnnpack/runtime/XNNCompiler.cpp

Lines changed: 53 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ Obtaining the constant data pointer can either be from within the flatbuffer
173173
payload (deprecated) or via offsets to the constant_data_ptr. If no constant
174174
data associated with the tensor value, then returns nullptr.
175175
*/
176-
const uint8_t* getConstantDataPtr(
176+
Result<const uint8_t*> getConstantDataPtr(
177177
uint32_t buffer_idx,
178178
GraphPtr flatbuffer_graph,
179179
const uint8_t* constant_data_ptr,
@@ -184,13 +184,39 @@ const uint8_t* getConstantDataPtr(
184184
if (!constant_data_ptr) {
185185
// TODO(T172265611): Remove constant_buffer in flatbuffer path after BC
186186
// window
187-
const auto& constant_buffer = *flatbuffer_graph->constant_buffer();
188-
return constant_buffer[buffer_idx]->storage()->data();
187+
auto* cb = flatbuffer_graph->constant_buffer();
188+
ET_CHECK_OR_RETURN_ERROR(
189+
cb != nullptr, InvalidProgram, "constant_buffer is null");
190+
ET_CHECK_OR_RETURN_ERROR(
191+
buffer_idx < cb->size(),
192+
InvalidProgram,
193+
"buffer_idx %u out of bounds for constant_buffer of size %u",
194+
buffer_idx,
195+
cb->size());
196+
auto* buffer_entry = (*cb)[buffer_idx];
197+
ET_CHECK_OR_RETURN_ERROR(
198+
buffer_entry != nullptr && buffer_entry->storage() != nullptr,
199+
InvalidProgram,
200+
"Null constant_buffer entry at buffer_idx %u",
201+
buffer_idx);
202+
return buffer_entry->storage()->data();
189203
} else {
190-
ConstantDataOffsetPtr constant_data_offset =
191-
flatbuffer_graph->constant_data()->Get(buffer_idx);
204+
auto* cd = flatbuffer_graph->constant_data();
205+
ET_CHECK_OR_RETURN_ERROR(
206+
cd != nullptr, InvalidProgram, "constant_data is null");
207+
ET_CHECK_OR_RETURN_ERROR(
208+
buffer_idx < cd->size(),
209+
InvalidProgram,
210+
"buffer_idx %u out of bounds for constant_data of size %u",
211+
buffer_idx,
212+
cd->size());
213+
ConstantDataOffsetPtr constant_data_offset = cd->Get(buffer_idx);
214+
ET_CHECK_OR_RETURN_ERROR(
215+
constant_data_offset != nullptr,
216+
InvalidProgram,
217+
"Null constant_data entry at buffer_idx %u",
218+
buffer_idx);
192219
uint64_t offset = constant_data_offset->offset();
193-
194220
bool has_named_key = flatbuffers::IsFieldPresent(
195221
constant_data_offset, fb_xnnpack::ConstantDataOffset::VT_NAMED_KEY);
196222
// If there is no tensor name
@@ -203,7 +229,7 @@ const uint8_t* getConstantDataPtr(
203229
weights_cache->load_unpacked_data(data_name);
204230
if (!data_ptr.ok()) {
205231
ET_LOG(Error, "Failed to load weights from cache");
206-
return nullptr;
232+
return Error::InvalidProgram;
207233
}
208234
return data_ptr.get();
209235
#else
@@ -215,7 +241,7 @@ const uint8_t* getConstantDataPtr(
215241
"Failed to get constant data for key %s from named_data_map. Error code: %u",
216242
data_name.c_str(),
217243
static_cast<uint32_t>(buffer.error()));
218-
return nullptr;
244+
return Error::InvalidProgram;
219245
}
220246
const uint8_t* data_ptr =
221247
static_cast<const uint8_t*>(buffer.get().data());
@@ -229,7 +255,7 @@ const uint8_t* getConstantDataPtr(
229255
return nullptr;
230256
}
231257

232-
const uint8_t* getConstantDataPtr(
258+
Result<const uint8_t*> getConstantDataPtr(
233259
const fb_xnnpack::XNNTensorValue* tensor_value,
234260
GraphPtr flatbuffer_graph,
235261
const uint8_t* constant_data_ptr,
@@ -298,13 +324,17 @@ Error defineTensor(
298324

299325
// Get Pointer to constant data from flatbuffer, if its non-constant
300326
// it is a nullptr
301-
const uint8_t* buffer_ptr = getConstantDataPtr(
327+
auto buffer_result = getConstantDataPtr(
302328
tensor_value,
303329
flatbuffer_graph,
304330
constant_data_ptr,
305331
named_data_map,
306332
freeable_buffers,
307333
weights_cache);
334+
if (!buffer_result.ok()) {
335+
return buffer_result.error();
336+
}
337+
const uint8_t* buffer_ptr = buffer_result.get();
308338

309339
xnn_status status;
310340
// The type we might have to convert to
@@ -449,13 +479,17 @@ Error defineTensor(
449479
const float* scale = qparams->scale()->data();
450480

451481
if (qparams->scale_buffer_idx() != 0) {
452-
scale = reinterpret_cast<const float*>(getConstantDataPtr(
482+
auto scale_result = getConstantDataPtr(
453483
qparams->scale_buffer_idx(),
454484
flatbuffer_graph,
455485
constant_data_ptr,
456486
named_data_map,
457487
freeable_buffers,
458-
weights_cache));
488+
weights_cache);
489+
if (!scale_result.ok()) {
490+
return scale_result.error();
491+
}
492+
scale = reinterpret_cast<const float*>(scale_result.get());
459493
ET_CHECK_OR_RETURN_ERROR(
460494
scale != nullptr, Internal, "Failed to load scale data.");
461495
}
@@ -491,13 +525,18 @@ Error defineTensor(
491525
// Block scales are preferably serialized as bf16 but can also be
492526
// serialized as fp32 for backwards compatability.
493527
if (qparams->scale_buffer_idx() != 0) {
494-
scale_data = reinterpret_cast<const uint16_t*>(getConstantDataPtr(
528+
auto scale_data_result = getConstantDataPtr(
495529
qparams->scale_buffer_idx(),
496530
flatbuffer_graph,
497531
constant_data_ptr,
498532
named_data_map,
499533
freeable_buffers,
500-
weights_cache));
534+
weights_cache);
535+
if (!scale_data_result.ok()) {
536+
return scale_data_result.error();
537+
}
538+
scale_data =
539+
reinterpret_cast<const uint16_t*>(scale_data_result.get());
501540
ET_CHECK_OR_RETURN_ERROR(
502541
scale_data != nullptr, Internal, "Failed to load scale data.");
503542
scale_numel = qparams->num_scales();

0 commit comments

Comments
 (0)