forked from pytorch/executorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathQnnExecuTorchBackend.cpp
More file actions
348 lines (321 loc) · 13 KB
/
QnnExecuTorchBackend.cpp
File metadata and controls
348 lines (321 loc) · 13 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
/*
* Copyright (c) Qualcomm Innovation Center, Inc.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <executorch/backends/qualcomm/aot/wrappers/TensorWrapper.h>
#include <executorch/backends/qualcomm/qc_compiler_spec_generated.h>
#include <executorch/backends/qualcomm/runtime/QnnBackendOptions.h>
#include <executorch/backends/qualcomm/runtime/QnnExecuTorchBackend.h>
#include <executorch/backends/qualcomm/runtime/QnnManager.h>
#include <executorch/backends/qualcomm/runtime/backends/QnnCustomProtocol.h>
#include <executorch/runtime/backend/interface.h>
#include <executorch/runtime/backend/options.h>
namespace executorch {
namespace backends {
namespace qnn {
using namespace qnn_delegate;
using executorch::runtime::ArrayRef;
using executorch::runtime::BackendExecutionContext;
using executorch::runtime::BackendInitContext;
using executorch::runtime::CompileSpec;
using executorch::runtime::DelegateHandle;
using executorch::runtime::EValue;
using executorch::runtime::FreeableBuffer;
using executorch::runtime::MemoryAllocator;
using executorch::runtime::Result;
using executorch::runtime::Span;
// ========== Public method implementations =========================
// Compile-spec key whose value buffer holds the serialized QnnExecuTorchOptions.
constexpr char QNN_COMPILE_SPEC[] = "qnn_compile_spec";
// Builds (or reuses) the QnnManager delegate handle for one method.
//
// `processed` is first parsed with QnnContextCustomProtocol. If that fails,
// the raw buffer is handed to QnnManager as-is (the "Dlc" path). The
// "qnn_compile_spec" compile spec supplies the QnnExecuTorchOptions. If a
// delegate with the same context signature was already created, that cached
// handle is returned instead of building a new one.
//
// Returns:
//   - MemoryAllocationFailed if the runtime allocator cannot hold a QnnManager;
//   - InvalidArgument if no "qnn_compile_spec" entry was provided (and nothing
//     is cached for this signature);
//   - Internal if any QnnManager initialization step fails.
Result<DelegateHandle*> QnnExecuTorchBackend::init(
    BackendInitContext& context,
    FreeableBuffer* processed,
    ArrayRef<CompileSpec> compile_specs) const {
  // Convert the processed buffer into a QNN ExecuTorch context blob.
  QnnExecuTorchContextBinary qnn_context_blob;
  const qnn_delegate::QnnExecuTorchOptions* qnn_executorch_options = nullptr;
  auto [status, signature, ctx_size, ctx_bin] =
      QnnContextCustomProtocol().DeserializeContextCustomBuffer(
          const_cast<void*>(processed->data()));
  if (status == Error::Ok) {
    QNN_EXECUTORCH_LOG_INFO(
        "Deserializing processed data using QnnContextCustomProtocol");
    // After this stage, qnn_context_blob.nbytes & qnn_context_blob.buffer will
    // only store qnn_context_binary.
    qnn_context_blob.nbytes = ctx_size;
    qnn_context_blob.buffer = ctx_bin;
  } else {
    // This buffer will be verified again in QnnBackendCache.
    QNN_EXECUTORCH_LOG_INFO("Deserializing processed data using Dlc");
    qnn_context_blob.buffer = const_cast<void*>(processed->data());
    qnn_context_blob.nbytes = processed->size();
  }

  // Convert CompileSpec entries to QNN ExecuTorch options; unknown keys are
  // logged and ignored.
  for (auto& compile_spec : compile_specs) {
    if (std::strcmp(compile_spec.key, QNN_COMPILE_SPEC) == 0) {
      qnn_executorch_options =
          GetQnnExecuTorchOptions(compile_spec.value.buffer);
    } else {
      QNN_EXECUTORCH_LOG_WARN("unknown argument: %s", compile_spec.key);
    }
  }

  // TODO: this is a temporal solution for multi-graph support, will be
  // removed once framework starts to accept runtime configuration
  // ---
  // Check if the current context binary has already been initialized and
  // return the cached handle to reduce memory footprint. This is done BEFORE
  // allocating a QnnManager so that a cache hit does not leave behind a
  // constructed manager that nothing ever tears down. The map is written
  // under mutex_ (add/erase_cached_delegate), so it is read under it too.
  {
    std::lock_guard<std::mutex> guard(mutex_);
    auto iter = delegate_map_.find(signature);
    if (iter != delegate_map_.end()) {
      QNN_EXECUTORCH_LOG_INFO(
          "Use cached delegate handle for current method: %s",
          context.get_method_name());
      return iter->second;
    }
  }

  // The QnnManager is constructed from these options; NOTE(review): it is
  // assumed to dereference them, so a missing spec is rejected here rather
  // than passing a null pointer.
  ET_CHECK_OR_RETURN_ERROR(
      qnn_executorch_options != nullptr,
      InvalidArgument,
      "Missing required compile spec: %s",
      QNN_COMPILE_SPEC);

  // Create QnnManager
  MemoryAllocator* runtime_allocator = context.get_runtime_allocator();
  QnnManager* qnn_manager = runtime_allocator->allocateInstance<QnnManager>();
  if (qnn_manager == nullptr) {
    return Error::MemoryAllocationFailed;
  }
  // NOTE: Since we use placement new and since this type is not trivially
  // destructible, it must be torn down manually. NOTE(review): destroy()
  // calls QnnManager::Destroy() rather than ~QnnManager(); confirm that is
  // sufficient. The early returns below also leave this manager
  // un-torn-down on failure.
  new (qnn_manager) QnnManager(qnn_executorch_options, qnn_context_blob);

  ET_CHECK_OR_RETURN_ERROR(
      qnn_manager->InitBackend() == Error::Ok,
      Internal,
      "Fail to initialize Qnn Manager");
  ET_CHECK_OR_RETURN_ERROR(
      qnn_manager->InitContext() == Error::Ok,
      Internal,
      "Fail to initialize Qnn Manager");
  if (qnn_manager->IsOnlinePrepare()) {
    ET_CHECK_OR_RETURN_ERROR(
        qnn_manager->CompileDlc() == Error::Ok,
        Internal,
        "Fail to compile binary in Dlc format");
  } else {
    for (const std::string& graph_name : qnn_manager->GetGraphNames()) {
      ET_CHECK_OR_RETURN_ERROR(
          qnn_manager->AllocateTensor(graph_name) == Error::Ok,
          Internal,
          "Fail to allocate tensor");
    }
  }
  add_cached_delegate(signature, qnn_manager);
#ifndef __hexagon__
  // This backend does not need its processed data after Init.
  processed->Free();
#endif
  return qnn_manager;
}
// Runs the QNN graph for the current method. Caller-provided tensors are
// bound to the graph's tensor wrappers, then the graph is executed and
// profiled.
//
// Entries of `args` are consumed in order:
//   1. one per graph input whose name does not contain "mutbuf_";
//   2. then one per graph output whose name starts with "output_" and does
//      not contain "mutbuf_".
//
// Returns:
//   - Internal if the handle is unknown, or if QNN execution or profiling fails;
//   - InvalidArgument if the context has no method name, or `args` has fewer
//     entries than the graph needs.
Error QnnExecuTorchBackend::execute(
    BackendExecutionContext& context,
    DelegateHandle* handle,
    Span<EValue*> args) const {
  {
    // delegate_map_rev_ is modified under mutex_ (add/erase_cached_delegate),
    // so it must also be read under mutex_.
    std::lock_guard<std::mutex> guard(mutex_);
    ET_CHECK_OR_RETURN_ERROR(
        delegate_map_rev_.count(handle) != 0,
        Internal,
        "DelegateHandle has been deleted");
  }
  // Constructing std::string from a null pointer is undefined behavior.
  const char* method_name_cstr = context.get_method_name();
  ET_CHECK_OR_RETURN_ERROR(
      method_name_cstr != nullptr,
      InvalidArgument,
      "Execution context does not provide a method name");

  QnnManager* qnn_manager = static_cast<QnnManager*>(handle);
  std::string method_name = method_name_cstr;
  std::vector<std::shared_ptr<TensorWrapper>> input_tensors =
      qnn_manager->GetGraphInputs(method_name);
  std::vector<std::shared_ptr<TensorWrapper>> output_tensors =
      qnn_manager->GetGraphOutputs(method_name);
  std::vector<Qnn_Tensor_t> input_tensor_structs;
  std::vector<Qnn_Tensor_t> output_tensor_structs;
  input_tensor_structs.reserve(input_tensors.size());
  output_tensor_structs.reserve(output_tensors.size());

  size_t args_index = 0;
  for (const auto& input_tensor : input_tensors) {
    // Inputs whose name contains "mutbuf_" do not consume an entry of `args`.
    if (input_tensor->GetName().find("mutbuf_") == std::string::npos) {
      ET_CHECK_OR_RETURN_ERROR(
          args_index < args.size(),
          InvalidArgument,
          "Graph %s needs more inputs than the %zu args provided",
          method_name.c_str(),
          args.size());
      auto&& tensor = args[args_index]->toTensor();
      if (qnn_manager->RegisterMem(tensor.mutable_data_ptr(), input_tensor) !=
          Error::Ok) {
        // update data ptr only should be fine
        input_tensor->FillDataBuffer(tensor.const_data_ptr());
        // use the real input shape instead of nominal one to make sure
        // dynamic shape is functional
        auto dims = tensor.sizes();
        input_tensor->SetDims(dims.data(), dims.size());
      }
      ++args_index;
    }
    input_tensor_structs.emplace_back(input_tensor->CloneTensorStruct());
  }

  for (const auto& output_tensor : output_tensors) {
    // pos=0 limits the search to the prefix
    if (output_tensor->GetName().rfind("output_", 0) == 0 &&
        output_tensor->GetName().find("mutbuf_") == std::string::npos) {
      ET_CHECK_OR_RETURN_ERROR(
          args_index < args.size(),
          InvalidArgument,
          "Graph %s needs more outputs than the %zu args provided",
          method_name.c_str(),
          args.size());
      void* mutable_data_ptr = args[args_index]->toTensor().mutable_data_ptr();
      if (qnn_manager->RegisterMem(mutable_data_ptr, output_tensor) !=
          Error::Ok) {
        output_tensor->FillDataBuffer(mutable_data_ptr);
      }
      ++args_index;
    }
    output_tensor_structs.emplace_back(output_tensor->CloneTensorStruct());
  }

  ET_CHECK_OR_RETURN_ERROR(
      qnn_manager->Execute(
          method_name,
          input_tensor_structs,
          output_tensor_structs,
          context.event_tracer()) == Error::Ok,
      Internal,
      "Fail to execute graph");
  ET_CHECK_OR_RETURN_ERROR(
      qnn_manager->ProfileExecuteData(method_name, context.event_tracer()) ==
          Error::Ok,
      Internal,
      "Fail to profile graph");
  return Error::Ok;
}
// Tears down a delegate handle previously returned by init().
//
// Null handles and handles not present in the cache are ignored. Because
// init() may return the same cached handle to several methods, this can be
// called more than once for one handle; only the first call, while the
// handle is still registered, calls QnnManager::Destroy().
void QnnExecuTorchBackend::destroy(DelegateHandle* handle) const {
  if (handle == nullptr) {
    return;
  }
  bool is_registered = false;
  {
    // delegate_map_rev_ is modified under mutex_, so read it under mutex_ too.
    std::lock_guard<std::mutex> guard(mutex_);
    is_registered = delegate_map_rev_.count(handle) != 0;
  }
  if (!is_registered) {
    return;
  }
  QnnManager* qnn_manager = static_cast<QnnManager*>(handle);
  qnn_manager->Destroy();
  erase_cached_delegate(handle);
}
// Stores runtime options for later retrieval through get_option().
//
// Supported keys: log level, HTP performance mode, profile level, and the
// LPAI settings (FPS, FTRT ratio, client perf type, affinity, core
// selection). Each option must use a supported key AND carry an int value.
// Any other option is logged; after all options are processed, the call
// returns Error::Internal. Valid options in the same call are still stored.
executorch::runtime::Error QnnExecuTorchBackend::set_option(
    executorch::runtime::BackendOptionContext& context,
    const executorch::runtime::Span<executorch::runtime::BackendOption>&
        backend_options) {
  std::lock_guard<std::mutex> guard(runtime_option_mutex_);

  // Copies an int-valued option into `runtime_option` and marks it as set.
  // Returns false, leaving `runtime_option` untouched, when the option holds
  // a value of any other type. Previously such options were silently dropped
  // yet still reported as accepted.
  auto store_int_option = [](const executorch::runtime::BackendOption& option,
                             auto& runtime_option) -> bool {
    const auto* val = std::get_if<int>(&option.value);
    if (val == nullptr) {
      ET_LOG(
          Error,
          "Runtime option %s for QnnExecuTorchBackend expects an int value.",
          option.key);
      return false;
    }
    runtime_option.value = *val;
    runtime_option.is_set = true;
    return true;
  };

  size_t matches = 0;
  for (const auto& option : backend_options) {
    bool applied = false;
    if (strcmp(option.key, QNN_RUNTIME_LOG_LEVEL) == 0) {
      applied = store_int_option(option, qnn_runtime_log_level_);
    } else if (strcmp(option.key, QNN_RUNTIME_HTP_PERFORMANCE_MODE) == 0) {
      applied = store_int_option(option, qnn_runtime_performance_mode_);
    } else if (strcmp(option.key, QNN_RUNTIME_PROFILE_LEVEL) == 0) {
      applied = store_int_option(option, qnn_runtime_profile_level_);
    } else if (strcmp(option.key, QNN_RUNTIME_LPAI_FPS) == 0) {
      applied = store_int_option(option, qnn_runtime_lpai_fps_);
    } else if (strcmp(option.key, QNN_RUNTIME_LPAI_FTRT_RATIO) == 0) {
      applied = store_int_option(option, qnn_runtime_lpai_ftrt_ratio_);
    } else if (strcmp(option.key, QNN_RUNTIME_LPAI_CLIENT_PERF_TYPE) == 0) {
      applied = store_int_option(option, qnn_runtime_lpai_client_perf_type_);
    } else if (strcmp(option.key, QNN_RUNTIME_LPAI_AFFINITY) == 0) {
      applied = store_int_option(option, qnn_runtime_lpai_affinity_);
    } else if (strcmp(option.key, QNN_RUNTIME_LPAI_CORE_SELECTION) == 0) {
      applied = store_int_option(option, qnn_runtime_lpai_core_selection_);
    } else {
      ET_LOG(
          Error,
          "Unable to set the following runtime option for QnnExecuTorchBackend: %s.",
          option.key);
    }
    if (applied) {
      ++matches;
    }
  }
  ET_CHECK_OR_RETURN_ERROR(
      matches == backend_options.size(),
      Internal,
      "Some set options are not supported by QnnExecuTorchBackend. %zu options provided but only %zu is supported.",
      backend_options.size(),
      matches);
  return Error::Ok;
}
// Fills each requested option with the value previously stored by
// set_option().
//
// Returns Error::Internal if any requested key is unknown, or is known but
// was never set. Matching options in the same call are still filled in.
executorch::runtime::Error QnnExecuTorchBackend::get_option(
    executorch::runtime::BackendOptionContext& context,
    executorch::runtime::Span<executorch::runtime::BackendOption>&
        backend_options) {
  // set_option() writes the stored options under runtime_option_mutex_, so
  // they must be read under the same lock to avoid a data race.
  std::lock_guard<std::mutex> guard(runtime_option_mutex_);

  // Copies `runtime_option` into `option` if set_option() has stored it.
  auto load_if_set = [](executorch::runtime::BackendOption& option,
                        const auto& runtime_option) -> bool {
    if (!runtime_option.is_set) {
      return false;
    }
    option.value = runtime_option.value;
    return true;
  };

  size_t matches = 0;
  for (size_t i = 0; i < backend_options.size(); ++i) {
    executorch::runtime::BackendOption& option = backend_options[i];
    bool found = false;
    if (strcmp(option.key, QNN_RUNTIME_LOG_LEVEL) == 0) {
      found = load_if_set(option, qnn_runtime_log_level_);
    } else if (strcmp(option.key, QNN_RUNTIME_HTP_PERFORMANCE_MODE) == 0) {
      found = load_if_set(option, qnn_runtime_performance_mode_);
    } else if (strcmp(option.key, QNN_RUNTIME_PROFILE_LEVEL) == 0) {
      found = load_if_set(option, qnn_runtime_profile_level_);
    } else if (strcmp(option.key, QNN_RUNTIME_LPAI_FPS) == 0) {
      found = load_if_set(option, qnn_runtime_lpai_fps_);
    } else if (strcmp(option.key, QNN_RUNTIME_LPAI_FTRT_RATIO) == 0) {
      found = load_if_set(option, qnn_runtime_lpai_ftrt_ratio_);
    } else if (strcmp(option.key, QNN_RUNTIME_LPAI_CLIENT_PERF_TYPE) == 0) {
      found = load_if_set(option, qnn_runtime_lpai_client_perf_type_);
    } else if (strcmp(option.key, QNN_RUNTIME_LPAI_AFFINITY) == 0) {
      found = load_if_set(option, qnn_runtime_lpai_affinity_);
    } else if (strcmp(option.key, QNN_RUNTIME_LPAI_CORE_SELECTION) == 0) {
      found = load_if_set(option, qnn_runtime_lpai_core_selection_);
    }
    // Either the runtime never called set_option for this key, or the key
    // does not exist.
    if (found) {
      ++matches;
    }
  }
  if (matches != backend_options.size()) {
    return Error::Internal;
  }
  return Error::Ok;
}
// Reports whether this backend can be used. No device or library probing is
// done here: the answer is always true.
bool QnnExecuTorchBackend::is_available() const {
  constexpr bool kAlwaysAvailable = true;
  return kAlwaysAvailable;
}
// Registers `handle` under `signature` in both lookup directions, under
// mutex_:
//   - signature -> handle, used by init() to reuse a delegate;
//   - handle -> signature, used to validate the handle and later erase it.
// Existing entries for the same key are overwritten.
void QnnExecuTorchBackend::add_cached_delegate(
    const std::int64_t& signature,
    executorch::runtime::DelegateHandle* handle) const {
  const std::lock_guard<std::mutex> lock(mutex_);
  delegate_map_.insert_or_assign(signature, handle);
  delegate_map_rev_.insert_or_assign(handle, signature);
}
// Removes `handle` from the delegate cache under mutex_. Unknown handles are
// ignored.
//
// The signature -> handle entry is removed only if it still points at
// `handle`. If add_cached_delegate() later overwrote that signature with a
// different handle, erasing it here would evict a live delegate from the
// cache.
void QnnExecuTorchBackend::erase_cached_delegate(
    executorch::runtime::DelegateHandle* handle) const {
  std::lock_guard<std::mutex> guard(mutex_);
  auto rev_iter = delegate_map_rev_.find(handle);
  if (rev_iter == delegate_map_rev_.end()) {
    return;
  }
  auto fwd_iter = delegate_map_.find(rev_iter->second);
  if (fwd_iter != delegate_map_.end() && fwd_iter->second == handle) {
    delegate_map_.erase(fwd_iter);
  }
  // Erase through the iterator we already hold instead of looking up again.
  delegate_map_rev_.erase(rev_iter);
}
namespace {
auto cls = QnnExecuTorchBackend();
executorch::runtime::Backend backend{QNN_BACKEND, &cls};
static auto success_with_compiler = register_backend(backend);
} // namespace
} // namespace qnn
} // namespace backends
} // namespace executorch