Skip to content

Commit fcbc108

Browse files
Hakan Boyrazfacebook-github-bot
authored andcommitted
Gate weights cache on runtime option instead of compile-time macro (#19603)
Summary: Replaces the compile-time `#ifdef ENABLE_XNNPACK_WEIGHTS_CACHE` gate in XNNCompiler.cpp with a runtime boolean plumbed from `XnnpackBackendOptions::resolve_weight_cache(context)` through `XNNPACKBackend::init` to `XNNCompiler::compileModel`. This fixes a silent-disable bug: previously, runtime opt-in via `set_option(weight_cache_option_key, true)` was silently a no-op unless the build also set `-c executorch.xnnpack_weights_cache=1`, because the cache pointer handed to `xnn_create_runtime_v4` was hardcoded to nullptr when the macro was undefined. Multimethod LoRA models re-packed the entire backbone for every method load, costing hundreds of MB of resident memory. The runtime path now keys all three cache-relevant code regions (unpacked-data load, cache pointer handoff to xnn_create_runtime_v4, and finalize_for_runtime) on `bool use_weight_cache` resolved per-init from the BackendInitContext. The `Result<vector<string>>` declaration in compileModel was reshaped to plain `vector<string>` since `Result<>` is non-assignable, which is required for the new runtime branch. Reviewed By: GregoryComer Differential Revision: D105123995
1 parent 174d3ad commit fcbc108

3 files changed

Lines changed: 67 additions & 59 deletions

File tree

backends/xnnpack/runtime/XNNCompiler.cpp

Lines changed: 63 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,8 @@ Result<const uint8_t*> getConstantDataPtr(
181181
const uint8_t* constant_data_ptr,
182182
const NamedDataMap* named_data_map,
183183
std::vector<FreeableBuffer>& freeable_buffers,
184-
XNNWeightsCache* weights_cache) {
184+
XNNWeightsCache* weights_cache,
185+
bool use_weight_cache) {
185186
if (buffer_idx) {
186187
if (!constant_data_ptr) {
187188
// TODO(T172265611): Remove constant_buffer in flatbuffer path after BC
@@ -230,30 +231,30 @@ Result<const uint8_t*> getConstantDataPtr(
230231
InvalidProgram,
231232
"Named key is null");
232233
const std::string& data_name = constant_data_offset->named_key()->str();
233-
#ifdef ENABLE_XNNPACK_WEIGHTS_CACHE
234-
Result<const uint8_t*> data_ptr =
235-
weights_cache->load_unpacked_data(data_name);
236-
if (!data_ptr.ok()) {
237-
ET_LOG(Error, "Failed to load weights from cache");
238-
return data_ptr.error();
239-
}
240-
return data_ptr.get();
241-
#else
242-
Result<FreeableBuffer> buffer =
243-
named_data_map->get_data(data_name.c_str());
244-
if (!buffer.ok()) {
245-
ET_LOG(
246-
Error,
247-
"Failed to get constant data for key %s from named_data_map. Error code: %u",
248-
data_name.c_str(),
249-
static_cast<uint32_t>(buffer.error()));
250-
return buffer.error();
234+
if (use_weight_cache) {
235+
Result<const uint8_t*> data_ptr =
236+
weights_cache->load_unpacked_data(data_name);
237+
if (!data_ptr.ok()) {
238+
ET_LOG(Error, "Failed to load weights from cache");
239+
return data_ptr.error();
240+
}
241+
return data_ptr.get();
242+
} else {
243+
Result<FreeableBuffer> buffer =
244+
named_data_map->get_data(data_name.c_str());
245+
if (!buffer.ok()) {
246+
ET_LOG(
247+
Error,
248+
"Failed to get constant data for key %s from named_data_map. Error code: %u",
249+
data_name.c_str(),
250+
static_cast<uint32_t>(buffer.error()));
251+
return buffer.error();
252+
}
253+
const uint8_t* data_ptr =
254+
static_cast<const uint8_t*>(buffer.get().data());
255+
freeable_buffers.push_back(std::move(buffer.get()));
256+
return data_ptr;
251257
}
252-
const uint8_t* data_ptr =
253-
static_cast<const uint8_t*>(buffer.get().data());
254-
freeable_buffers.push_back(std::move(buffer.get()));
255-
return data_ptr;
256-
#endif
257258
}
258259
}
259260
}
@@ -267,14 +268,16 @@ Result<const uint8_t*> getConstantDataPtr(
267268
const uint8_t* constant_data_ptr,
268269
const NamedDataMap* named_data_map,
269270
std::vector<FreeableBuffer>& freeable_buffers,
270-
XNNWeightsCache* weights_cache) {
271+
XNNWeightsCache* weights_cache,
272+
bool use_weight_cache) {
271273
return getConstantDataPtr(
272274
tensor_value->constant_buffer_idx(),
273275
flatbuffer_graph,
274276
constant_data_ptr,
275277
named_data_map,
276278
freeable_buffers,
277-
weights_cache);
279+
weights_cache,
280+
use_weight_cache);
278281
}
279282

280283
/**
@@ -293,7 +296,8 @@ Error defineTensor(
293296
CompileAllocator& allocator,
294297
const NamedDataMap* named_data_map,
295298
std::vector<FreeableBuffer>& freeable_buffers,
296-
XNNWeightsCache* weights_cache) {
299+
XNNWeightsCache* weights_cache,
300+
bool use_weight_cache) {
297301
const fb_xnnpack::XNNTensorValue* tensor_value = nullptr;
298302
const fb_xnnpack::XNNQuantizedTensorValue* qtensor_value = nullptr;
299303

@@ -347,7 +351,8 @@ Error defineTensor(
347351
constant_data_ptr,
348352
named_data_map,
349353
freeable_buffers,
350-
weights_cache);
354+
weights_cache,
355+
use_weight_cache);
351356
if (!buffer_result.ok()) {
352357
return buffer_result.error();
353358
}
@@ -502,7 +507,8 @@ Error defineTensor(
502507
constant_data_ptr,
503508
named_data_map,
504509
freeable_buffers,
505-
weights_cache);
510+
weights_cache,
511+
use_weight_cache);
506512
if (!scale_result.ok()) {
507513
return scale_result.error();
508514
}
@@ -548,7 +554,8 @@ Error defineTensor(
548554
constant_data_ptr,
549555
named_data_map,
550556
freeable_buffers,
551-
weights_cache);
557+
weights_cache,
558+
use_weight_cache);
552559
if (!scale_data_result.ok()) {
553560
return scale_data_result.error();
554561
}
@@ -1976,7 +1983,8 @@ ET_NODISCARD Error XNNCompiler::compileModel(
19761983
XNNExecutor* executor,
19771984
XNNWeightsCache* weights_cache,
19781985
xnn_workspace_t workspace,
1979-
const NamedDataMap* named_data_map) {
1986+
const NamedDataMap* named_data_map,
1987+
bool use_weight_cache) {
19801988
Result<XNNHeader> header = XNNHeader::Parse(buffer_pointer, num_bytes);
19811989
const uint8_t* flatbuffer_data = nullptr;
19821990
const uint8_t* constant_data = nullptr;
@@ -2086,7 +2094,8 @@ ET_NODISCARD Error XNNCompiler::compileModel(
20862094
compile_allocator,
20872095
named_data_map,
20882096
unpacked_buffers,
2089-
weights_cache);
2097+
weights_cache,
2098+
use_weight_cache);
20902099

20912100
if (err != Error::Ok) {
20922101
return err;
@@ -2108,19 +2117,16 @@ ET_NODISCARD Error XNNCompiler::compileModel(
21082117

21092118
xnn_runtime_t runtime_ptr = nullptr;
21102119

2111-
// XNNWeightsCache if weights cache is not enabled, then XNNWeightsCache
2112-
// just manages the unpacked weights until the runtime is created.
2113-
#ifdef ENABLE_XNNPACK_WEIGHTS_CACHE
2114-
ET_CHECK_OR_RETURN_ERROR(
2115-
unpacked_buffers.size() == 0,
2116-
Internal,
2117-
"Weight Cache is enabled, which means unpacked buffers should be owned by the cache");
2118-
xnn_weights_cache_t weights_cache_ptr =
2119-
weights_cache->get_num_unpacked_data() > 0 ? weights_cache->get()
2120-
: nullptr;
2121-
#else
21222120
xnn_weights_cache_t weights_cache_ptr = nullptr;
2123-
#endif
2121+
if (use_weight_cache) {
2122+
ET_CHECK_OR_RETURN_ERROR(
2123+
unpacked_buffers.size() == 0,
2124+
Internal,
2125+
"Weight Cache is enabled, which means unpacked buffers should be owned by the cache");
2126+
weights_cache_ptr = weights_cache->get_num_unpacked_data() > 0
2127+
? weights_cache->get()
2128+
: nullptr;
2129+
}
21242130

21252131
// NOLINTBEGIN(facebook-hte-NullableDereference) - weights cache is allowed to
21262132
// be null
@@ -2139,25 +2145,25 @@ ET_NODISCARD Error XNNCompiler::compileModel(
21392145
"XNN Runtime creation failed with code: %s",
21402146
xnn_status_to_string(status));
21412147

2142-
#ifdef ENABLE_XNNPACK_WEIGHTS_CACHE
2143-
auto packed_weights_names = weights_cache->finalize_for_runtime();
2144-
ET_CHECK_OR_RETURN_ERROR(
2145-
packed_weights_names.ok(),
2146-
Internal,
2147-
"Failed to finalize weights cache after creating the xnn runtime")
2148-
#else
2149-
for (auto& buffer : unpacked_buffers) {
2150-
buffer.Free();
2148+
std::vector<std::string> packed_weights_names;
2149+
if (use_weight_cache) {
2150+
auto packed_weights_names_result = weights_cache->finalize_for_runtime();
2151+
ET_CHECK_OR_RETURN_ERROR(
2152+
packed_weights_names_result.ok(),
2153+
Internal,
2154+
"Failed to finalize weights cache after creating the xnn runtime");
2155+
packed_weights_names = std::move(packed_weights_names_result.get());
2156+
} else {
2157+
for (auto& buffer : unpacked_buffers) {
2158+
buffer.Free();
2159+
}
21512160
}
2152-
Result<std::vector<std::string>> packed_weights_names =
2153-
std::vector<std::string>();
2154-
#endif
21552161

21562162
err = executor->initialize( // NOLINT: runtime_ptr is non-null
21572163
runtime_ptr,
21582164
std::move(input_ids),
21592165
std::move(output_ids),
2160-
std::move(packed_weights_names.get()));
2166+
std::move(packed_weights_names));
21612167

21622168
return err;
21632169
};

backends/xnnpack/runtime/XNNCompiler.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@ class XNNCompiler {
2929
XNNExecutor* executor,
3030
XNNWeightsCache* weights_cache,
3131
xnn_workspace_t workspace,
32-
const NamedDataMap* named_data_map);
32+
const NamedDataMap* named_data_map,
33+
bool use_weight_cache);
3334
};
3435

3536
} // namespace delegate

backends/xnnpack/runtime/XNNPACKBackend.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,8 @@ class XnnpackBackend final
110110
executor,
111111
weights_cache_.get(),
112112
workspace_ptr,
113-
named_data_map);
113+
named_data_map,
114+
use_weight_cache);
114115
// This backend does not need its processed data after compiling the model.
115116
processed->Free();
116117

0 commit comments

Comments
 (0)