@@ -173,7 +173,7 @@ Obtaining the constant data pointer can either be from within the flatbuffer
173173payload (deprecated) or via offsets to the constant_data_ptr. If no constant
174174data associated with the tensor value, then returns nullptr.
175175*/
176- const uint8_t * getConstantDataPtr (
176+ Result< const uint8_t *> getConstantDataPtr (
177177 uint32_t buffer_idx,
178178 GraphPtr flatbuffer_graph,
179179 const uint8_t * constant_data_ptr,
@@ -184,13 +184,39 @@ const uint8_t* getConstantDataPtr(
184184 if (!constant_data_ptr) {
185185 // TODO(T172265611): Remove constant_buffer in flatbuffer path after BC
186186 // window
187- const auto & constant_buffer = *flatbuffer_graph->constant_buffer ();
188- return constant_buffer[buffer_idx]->storage ()->data ();
187+ auto * cb = flatbuffer_graph->constant_buffer ();
188+ ET_CHECK_OR_RETURN_ERROR (
189+ cb != nullptr , InvalidProgram, " constant_buffer is null" );
190+ ET_CHECK_OR_RETURN_ERROR (
191+ buffer_idx < cb->size (),
192+ InvalidProgram,
193+ " buffer_idx %u out of bounds for constant_buffer of size %u" ,
194+ buffer_idx,
195+ cb->size ());
196+ auto * buffer_entry = (*cb)[buffer_idx];
197+ ET_CHECK_OR_RETURN_ERROR (
198+ buffer_entry != nullptr && buffer_entry->storage () != nullptr ,
199+ InvalidProgram,
200+ " Null constant_buffer entry at buffer_idx %u" ,
201+ buffer_idx);
202+ return buffer_entry->storage ()->data ();
189203 } else {
190- ConstantDataOffsetPtr constant_data_offset =
191- flatbuffer_graph->constant_data ()->Get (buffer_idx);
204+ auto * cd = flatbuffer_graph->constant_data ();
205+ ET_CHECK_OR_RETURN_ERROR (
206+ cd != nullptr , InvalidProgram, " constant_data is null" );
207+ ET_CHECK_OR_RETURN_ERROR (
208+ buffer_idx < cd->size (),
209+ InvalidProgram,
210+ " buffer_idx %u out of bounds for constant_data of size %u" ,
211+ buffer_idx,
212+ cd->size ());
213+ ConstantDataOffsetPtr constant_data_offset = cd->Get (buffer_idx);
214+ ET_CHECK_OR_RETURN_ERROR (
215+ constant_data_offset != nullptr ,
216+ InvalidProgram,
217+ " Null constant_data entry at buffer_idx %u" ,
218+ buffer_idx);
192219 uint64_t offset = constant_data_offset->offset ();
193-
194220 bool has_named_key = flatbuffers::IsFieldPresent (
195221 constant_data_offset, fb_xnnpack::ConstantDataOffset::VT_NAMED_KEY);
196222 // If there is no tensor name
@@ -203,7 +229,7 @@ const uint8_t* getConstantDataPtr(
203229 weights_cache->load_unpacked_data (data_name);
204230 if (!data_ptr.ok ()) {
205231 ET_LOG (Error, " Failed to load weights from cache" );
206- return nullptr ;
232+ return Error::InvalidProgram ;
207233 }
208234 return data_ptr.get ();
209235#else
@@ -215,7 +241,7 @@ const uint8_t* getConstantDataPtr(
215241 " Failed to get constant data for key %s from named_data_map. Error code: %u" ,
216242 data_name.c_str (),
217243 static_cast <uint32_t >(buffer.error ()));
218- return nullptr ;
244+ return Error::InvalidProgram ;
219245 }
220246 const uint8_t * data_ptr =
221247 static_cast <const uint8_t *>(buffer.get ().data ());
@@ -229,7 +255,7 @@ const uint8_t* getConstantDataPtr(
229255 return nullptr ;
230256}
231257
232- const uint8_t * getConstantDataPtr (
258+ Result< const uint8_t *> getConstantDataPtr (
233259 const fb_xnnpack::XNNTensorValue* tensor_value,
234260 GraphPtr flatbuffer_graph,
235261 const uint8_t * constant_data_ptr,
@@ -298,13 +324,17 @@ Error defineTensor(
298324
299325 // Get Pointer to constant data from flatbuffer, if its non-constant
300326 // it is a nullptr
301- const uint8_t * buffer_ptr = getConstantDataPtr (
327+ auto buffer_result = getConstantDataPtr (
302328 tensor_value,
303329 flatbuffer_graph,
304330 constant_data_ptr,
305331 named_data_map,
306332 freeable_buffers,
307333 weights_cache);
334+ if (!buffer_result.ok ()) {
335+ return buffer_result.error ();
336+ }
337+ const uint8_t * buffer_ptr = buffer_result.get ();
308338
309339 xnn_status status;
310340 // The type we might have to convert to
@@ -449,13 +479,17 @@ Error defineTensor(
449479 const float * scale = qparams->scale ()->data ();
450480
451481 if (qparams->scale_buffer_idx () != 0 ) {
452- scale = reinterpret_cast < const float *>( getConstantDataPtr (
482+ auto scale_result = getConstantDataPtr (
453483 qparams->scale_buffer_idx (),
454484 flatbuffer_graph,
455485 constant_data_ptr,
456486 named_data_map,
457487 freeable_buffers,
458- weights_cache));
488+ weights_cache);
489+ if (!scale_result.ok ()) {
490+ return scale_result.error ();
491+ }
492+ scale = reinterpret_cast <const float *>(scale_result.get ());
459493 ET_CHECK_OR_RETURN_ERROR (
460494 scale != nullptr , Internal, " Failed to load scale data." );
461495 }
@@ -491,13 +525,18 @@ Error defineTensor(
491525 // Block scales are preferably serialized as bf16 but can also be
492526 // serialized as fp32 for backwards compatability.
493527 if (qparams->scale_buffer_idx () != 0 ) {
494- scale_data = reinterpret_cast < const uint16_t *>( getConstantDataPtr (
528+ auto scale_data_result = getConstantDataPtr (
495529 qparams->scale_buffer_idx (),
496530 flatbuffer_graph,
497531 constant_data_ptr,
498532 named_data_map,
499533 freeable_buffers,
500- weights_cache));
534+ weights_cache);
535+ if (!scale_data_result.ok ()) {
536+ return scale_data_result.error ();
537+ }
538+ scale_data =
539+ reinterpret_cast <const uint16_t *>(scale_data_result.get ());
501540 ET_CHECK_OR_RETURN_ERROR (
502541 scale_data != nullptr , Internal, " Failed to load scale data." );
503542 scale_numel = qparams->num_scales ();
0 commit comments