@@ -181,7 +181,8 @@ Result<const uint8_t*> getConstantDataPtr(
181181 const uint8_t * constant_data_ptr,
182182 const NamedDataMap* named_data_map,
183183 std::vector<FreeableBuffer>& freeable_buffers,
184- XNNWeightsCache* weights_cache) {
184+ XNNWeightsCache* weights_cache,
185+ bool use_weight_cache) {
185186 if (buffer_idx) {
186187 if (!constant_data_ptr) {
187188 // TODO(T172265611): Remove constant_buffer in flatbuffer path after BC
@@ -230,30 +231,30 @@ Result<const uint8_t*> getConstantDataPtr(
230231 InvalidProgram,
231232 " Named key is null" );
232233 const std::string& data_name = constant_data_offset->named_key ()->str ();
233- #ifdef ENABLE_XNNPACK_WEIGHTS_CACHE
234- Result<const uint8_t *> data_ptr =
235- weights_cache->load_unpacked_data (data_name);
236- if (!data_ptr.ok ()) {
237- ET_LOG (Error, " Failed to load weights from cache" );
238- return data_ptr.error ();
239- }
240- return data_ptr.get ();
241- #else
242- Result<FreeableBuffer> buffer =
243- named_data_map->get_data (data_name.c_str ());
244- if (!buffer.ok ()) {
245- ET_LOG (
246- Error,
247- " Failed to get constant data for key %s from named_data_map. Error code: %u" ,
248- data_name.c_str (),
249- static_cast <uint32_t >(buffer.error ()));
250- return buffer.error ();
234+ if (use_weight_cache) {
235+ Result<const uint8_t *> data_ptr =
236+ weights_cache->load_unpacked_data (data_name);
237+ if (!data_ptr.ok ()) {
238+ ET_LOG (Error, " Failed to load weights from cache" );
239+ return data_ptr.error ();
240+ }
241+ return data_ptr.get ();
242+ } else {
243+ Result<FreeableBuffer> buffer =
244+ named_data_map->get_data (data_name.c_str ());
245+ if (!buffer.ok ()) {
246+ ET_LOG (
247+ Error,
248+ " Failed to get constant data for key %s from named_data_map. Error code: %u" ,
249+ data_name.c_str (),
250+ static_cast <uint32_t >(buffer.error ()));
251+ return buffer.error ();
252+ }
253+ const uint8_t * data_ptr =
254+ static_cast <const uint8_t *>(buffer.get ().data ());
255+ freeable_buffers.push_back (std::move (buffer.get ()));
256+ return data_ptr;
251257 }
252- const uint8_t * data_ptr =
253- static_cast <const uint8_t *>(buffer.get ().data ());
254- freeable_buffers.push_back (std::move (buffer.get ()));
255- return data_ptr;
256- #endif
257258 }
258259 }
259260 }
@@ -267,14 +268,16 @@ Result<const uint8_t*> getConstantDataPtr(
267268 const uint8_t * constant_data_ptr,
268269 const NamedDataMap* named_data_map,
269270 std::vector<FreeableBuffer>& freeable_buffers,
270- XNNWeightsCache* weights_cache) {
271+ XNNWeightsCache* weights_cache,
272+ bool use_weight_cache) {
271273 return getConstantDataPtr (
272274 tensor_value->constant_buffer_idx (),
273275 flatbuffer_graph,
274276 constant_data_ptr,
275277 named_data_map,
276278 freeable_buffers,
277- weights_cache);
279+ weights_cache,
280+ use_weight_cache);
278281}
279282
280283/* *
@@ -293,7 +296,8 @@ Error defineTensor(
293296 CompileAllocator& allocator,
294297 const NamedDataMap* named_data_map,
295298 std::vector<FreeableBuffer>& freeable_buffers,
296- XNNWeightsCache* weights_cache) {
299+ XNNWeightsCache* weights_cache,
300+ bool use_weight_cache) {
297301 const fb_xnnpack::XNNTensorValue* tensor_value = nullptr ;
298302 const fb_xnnpack::XNNQuantizedTensorValue* qtensor_value = nullptr ;
299303
@@ -347,7 +351,8 @@ Error defineTensor(
347351 constant_data_ptr,
348352 named_data_map,
349353 freeable_buffers,
350- weights_cache);
354+ weights_cache,
355+ use_weight_cache);
351356 if (!buffer_result.ok ()) {
352357 return buffer_result.error ();
353358 }
@@ -502,7 +507,8 @@ Error defineTensor(
502507 constant_data_ptr,
503508 named_data_map,
504509 freeable_buffers,
505- weights_cache);
510+ weights_cache,
511+ use_weight_cache);
506512 if (!scale_result.ok ()) {
507513 return scale_result.error ();
508514 }
@@ -548,7 +554,8 @@ Error defineTensor(
548554 constant_data_ptr,
549555 named_data_map,
550556 freeable_buffers,
551- weights_cache);
557+ weights_cache,
558+ use_weight_cache);
552559 if (!scale_data_result.ok ()) {
553560 return scale_data_result.error ();
554561 }
@@ -1976,7 +1983,8 @@ ET_NODISCARD Error XNNCompiler::compileModel(
19761983 XNNExecutor* executor,
19771984 XNNWeightsCache* weights_cache,
19781985 xnn_workspace_t workspace,
1979- const NamedDataMap* named_data_map) {
1986+ const NamedDataMap* named_data_map,
1987+ bool use_weight_cache) {
19801988 Result<XNNHeader> header = XNNHeader::Parse (buffer_pointer, num_bytes);
19811989 const uint8_t * flatbuffer_data = nullptr ;
19821990 const uint8_t * constant_data = nullptr ;
@@ -2086,7 +2094,8 @@ ET_NODISCARD Error XNNCompiler::compileModel(
20862094 compile_allocator,
20872095 named_data_map,
20882096 unpacked_buffers,
2089- weights_cache);
2097+ weights_cache,
2098+ use_weight_cache);
20902099
20912100 if (err != Error::Ok) {
20922101 return err;
@@ -2108,19 +2117,16 @@ ET_NODISCARD Error XNNCompiler::compileModel(
21082117
21092118 xnn_runtime_t runtime_ptr = nullptr ;
21102119
2111- // XNNWeightsCache if weights cache is not enabled, then XNNWeightsCache
2112- // just manages the unpacked weights until the runtime is created.
2113- #ifdef ENABLE_XNNPACK_WEIGHTS_CACHE
2114- ET_CHECK_OR_RETURN_ERROR (
2115- unpacked_buffers.size () == 0 ,
2116- Internal,
2117- " Weight Cache is enabled, which means unpacked buffers should be owned by the cache" );
2118- xnn_weights_cache_t weights_cache_ptr =
2119- weights_cache->get_num_unpacked_data () > 0 ? weights_cache->get ()
2120- : nullptr ;
2121- #else
21222120 xnn_weights_cache_t weights_cache_ptr = nullptr ;
2123- #endif
2121+ if (use_weight_cache) {
2122+ ET_CHECK_OR_RETURN_ERROR (
2123+ unpacked_buffers.size () == 0 ,
2124+ Internal,
2125+ " Weight Cache is enabled, which means unpacked buffers should be owned by the cache" );
2126+ weights_cache_ptr = weights_cache->get_num_unpacked_data () > 0
2127+ ? weights_cache->get ()
2128+ : nullptr ;
2129+ }
21242130
21252131 // NOLINTBEGIN(facebook-hte-NullableDereference) - weights cache is allowed to
21262132 // be null
@@ -2139,25 +2145,25 @@ ET_NODISCARD Error XNNCompiler::compileModel(
21392145 " XNN Runtime creation failed with code: %s" ,
21402146 xnn_status_to_string (status));
21412147
2142- #ifdef ENABLE_XNNPACK_WEIGHTS_CACHE
2143- auto packed_weights_names = weights_cache->finalize_for_runtime ();
2144- ET_CHECK_OR_RETURN_ERROR (
2145- packed_weights_names.ok (),
2146- Internal,
2147- " Failed to finalize weights cache after creating the xnn runtime" )
2148- #else
2149- for (auto & buffer : unpacked_buffers) {
2150- buffer.Free ();
2148+ std::vector<std::string> packed_weights_names;
2149+ if (use_weight_cache) {
2150+ auto packed_weights_names_result = weights_cache->finalize_for_runtime ();
2151+ ET_CHECK_OR_RETURN_ERROR (
2152+ packed_weights_names_result.ok (),
2153+ Internal,
2154+ " Failed to finalize weights cache after creating the xnn runtime" );
2155+ packed_weights_names = std::move (packed_weights_names_result.get ());
2156+ } else {
2157+ for (auto & buffer : unpacked_buffers) {
2158+ buffer.Free ();
2159+ }
21512160 }
2152- Result<std::vector<std::string>> packed_weights_names =
2153- std::vector<std::string>();
2154- #endif
21552161
21562162 err = executor->initialize ( // NOLINT: runtime_ptr is non-null
21572163 runtime_ptr,
21582164 std::move (input_ids),
21592165 std::move (output_ids),
2160- std::move (packed_weights_names. get () ));
2166+ std::move (packed_weights_names));
21612167
21622168 return err;
21632169};
0 commit comments