@@ -173,7 +173,7 @@ Obtaining the constant data pointer can either be from within the flatbuffer
173173payload (deprecated) or via offsets to the constant_data_ptr. If no constant
174174data associated with the tensor value, then returns nullptr.
175175*/
176- const uint8_t * getConstantDataPtr (
176+ Result< const uint8_t *> getConstantDataPtr (
177177 uint32_t buffer_idx,
178178 GraphPtr flatbuffer_graph,
179179 const uint8_t * constant_data_ptr,
@@ -184,13 +184,38 @@ const uint8_t* getConstantDataPtr(
184184 if (!constant_data_ptr) {
185185 // TODO(T172265611): Remove constant_buffer in flatbuffer path after BC
186186 // window
187- const auto & constant_buffer = *flatbuffer_graph->constant_buffer ();
188- return constant_buffer[buffer_idx]->storage ()->data ();
187+ auto * cb = flatbuffer_graph->constant_buffer ();
188+ if (cb == nullptr || buffer_idx >= cb->size ()) {
189+ ET_LOG (
190+ Error,
191+ " Invalid buffer_idx %u for constant_buffer of size %u" ,
192+ buffer_idx,
193+ cb ? cb->size () : 0 );
194+ return Error::InvalidProgram;
195+ }
196+ auto * buffer_entry = (*cb)[buffer_idx];
197+ if (buffer_entry == nullptr || buffer_entry->storage () == nullptr ) {
198+ ET_LOG (
199+ Error, " Null constant_buffer entry at buffer_idx %u" , buffer_idx);
200+ return Error::InvalidProgram;
201+ }
202+ return buffer_entry->storage ()->data ();
189203 } else {
190- ConstantDataOffsetPtr constant_data_offset =
191- flatbuffer_graph->constant_data ()->Get (buffer_idx);
204+ auto * cd = flatbuffer_graph->constant_data ();
205+ if (cd == nullptr || buffer_idx >= cd->size ()) {
206+ ET_LOG (
207+ Error,
208+ " Invalid buffer_idx %u for constant_data of size %u" ,
209+ buffer_idx,
210+ cd ? cd->size () : 0 );
211+ return Error::InvalidProgram;
212+ }
213+ ConstantDataOffsetPtr constant_data_offset = cd->Get (buffer_idx);
214+ if (constant_data_offset == nullptr ) {
215+ ET_LOG (Error, " Null constant_data entry at buffer_idx %u" , buffer_idx);
216+ return Error::InvalidProgram;
217+ }
192218 uint64_t offset = constant_data_offset->offset ();
193-
194219 bool has_named_key = flatbuffers::IsFieldPresent (
195220 constant_data_offset, fb_xnnpack::ConstantDataOffset::VT_NAMED_KEY);
196221 // If there is no tensor name
@@ -203,7 +228,7 @@ const uint8_t* getConstantDataPtr(
203228 weights_cache->load_unpacked_data (data_name);
204229 if (!data_ptr.ok ()) {
205230 ET_LOG (Error, " Failed to load weights from cache" );
206- return nullptr ;
231+ return Error::InvalidProgram ;
207232 }
208233 return data_ptr.get ();
209234#else
@@ -215,7 +240,7 @@ const uint8_t* getConstantDataPtr(
215240 " Failed to get constant data for key %s from named_data_map. Error code: %u" ,
216241 data_name.c_str (),
217242 static_cast <uint32_t >(buffer.error ()));
218- return nullptr ;
243+ return Error::InvalidProgram ;
219244 }
220245 const uint8_t * data_ptr =
221246 static_cast <const uint8_t *>(buffer.get ().data ());
@@ -229,7 +254,7 @@ const uint8_t* getConstantDataPtr(
229254 return nullptr ;
230255}
231256
232- const uint8_t * getConstantDataPtr (
257+ Result< const uint8_t *> getConstantDataPtr (
233258 const fb_xnnpack::XNNTensorValue* tensor_value,
234259 GraphPtr flatbuffer_graph,
235260 const uint8_t * constant_data_ptr,
@@ -298,13 +323,17 @@ Error defineTensor(
298323
299324 // Get Pointer to constant data from flatbuffer, if its non-constant
300325 // it is a nullptr
301- const uint8_t * buffer_ptr = getConstantDataPtr (
326+ auto buffer_result = getConstantDataPtr (
302327 tensor_value,
303328 flatbuffer_graph,
304329 constant_data_ptr,
305330 named_data_map,
306331 freeable_buffers,
307332 weights_cache);
333+ if (!buffer_result.ok ()) {
334+ return buffer_result.error ();
335+ }
336+ const uint8_t * buffer_ptr = buffer_result.get ();
308337
309338 xnn_status status;
310339 // The type we might have to convert to
@@ -449,13 +478,17 @@ Error defineTensor(
449478 const float * scale = qparams->scale ()->data ();
450479
451480 if (qparams->scale_buffer_idx () != 0 ) {
452- scale = reinterpret_cast < const float *>( getConstantDataPtr (
481+ auto scale_result = getConstantDataPtr (
453482 qparams->scale_buffer_idx (),
454483 flatbuffer_graph,
455484 constant_data_ptr,
456485 named_data_map,
457486 freeable_buffers,
458- weights_cache));
487+ weights_cache);
488+ if (!scale_result.ok ()) {
489+ return scale_result.error ();
490+ }
491+ scale = reinterpret_cast <const float *>(scale_result.get ());
459492 ET_CHECK_OR_RETURN_ERROR (
460493 scale != nullptr , Internal, " Failed to load scale data." );
461494 }
@@ -491,13 +524,18 @@ Error defineTensor(
491524 // Block scales are preferably serialized as bf16 but can also be
492525 // serialized as fp32 for backwards compatability.
493526 if (qparams->scale_buffer_idx () != 0 ) {
494- scale_data = reinterpret_cast < const uint16_t *>( getConstantDataPtr (
527+ auto scale_data_result = getConstantDataPtr (
495528 qparams->scale_buffer_idx (),
496529 flatbuffer_graph,
497530 constant_data_ptr,
498531 named_data_map,
499532 freeable_buffers,
500- weights_cache));
533+ weights_cache);
534+ if (!scale_data_result.ok ()) {
535+ return scale_data_result.error ();
536+ }
537+ scale_data =
538+ reinterpret_cast <const uint16_t *>(scale_data_result.get ());
501539 ET_CHECK_OR_RETURN_ERROR (
502540 scale_data != nullptr , Internal, " Failed to load scale data." );
503541 scale_numel = qparams->num_scales ();
0 commit comments