Skip to content

Commit 36192cb

Browse files
RonakrM-planet
authored and committed
olive changes to support QNN ABI EP
1 parent fe20993 commit 36192cb

3 files changed

Lines changed: 76 additions & 13 deletions

File tree

olive/common/ort_inference.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,15 @@ def maybe_register_ep_libraries(ep_paths: dict[str, str]):
5353
if (Path(ort.__file__).parent / "capi" / builtin_library_name).exists():
5454
ep_paths[provider] = builtin_library_name
5555

56+
# ABI EP
57+
if "QNNExecutionProvider" in ep_paths and ep_paths["QNNExecutionProvider"] is None:
58+
try:
59+
import onnxruntime_qnn as qnn_ep
60+
61+
ep_paths["QNNExecutionProvider"] = qnn_ep.get_library_path()
62+
except ImportError:
63+
logger.info("Failed to import onnxruntime_qnn")
64+
5665
for ep_name, ep_path in ep_paths.items():
5766
if ep_path is None:
5867
continue

olive/passes/onnx/context_binary.py

Lines changed: 50 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -76,16 +76,19 @@ def _run_for_config(
7676
) -> Union[ONNXModelHandler, CompositeModelHandler]:
7777
# session created using providers argument so will use the ort.get_available_providers()
7878
# TODO(jambayk): consider switching to the new EP API for Windows
79+
from onnxruntime import __version__ as OrtVersion
7980
from onnxruntime import get_available_providers
8081

8182
# TODO(jambayk): validate and support other NPU EPs
8283
assert self.accelerator_spec.execution_provider == ExecutionProvider.QNNExecutionProvider, (
8384
"Only QNNExecutionProvider is supported for now."
8485
)
85-
assert self.accelerator_spec.execution_provider in get_available_providers(), (
86-
f"Execution provider {self.accelerator_spec.execution_provider} is not available. Available providers:"
87-
f" {get_available_providers()}"
88-
)
86+
87+
if version.parse(OrtVersion).release <= version.parse("1.23.2").release:
88+
assert self.accelerator_spec.execution_provider in get_available_providers(), (
89+
f"Execution provider {self.accelerator_spec.execution_provider} is not available. Available providers:"
90+
f" {get_available_providers()}"
91+
)
8992

9093
result = self._run_single_target(model, config, output_model_path)
9194

@@ -257,17 +260,22 @@ def _generate_context_binary(
257260
import onnxruntime as ort
258261
from onnxruntime import __version__ as OrtVersion
259262

263+
is_abi = (
264+
"QNNExecutionProvider" not in ort.get_available_providers()
265+
or version.parse(OrtVersion).release >= version.parse("1.25.0").release
266+
)
267+
logger.debug(" Using ABI EP: %s", str(is_abi))
268+
260269
# prepare provider options
261270
provider_options = provider_options or {}
262271
if execution_provider == ExecutionProvider.QNNExecutionProvider:
263272
if str(device).lower() == "gpu":
264273
provider_options["backend_path"] = "libQnnGpu.so" if platform.system() == "Linux" else "QnnGpu.dll"
265274
update_llm_pipeline_genai_config_gpu_ctxbin(model_path)
266275
else:
267-
if version.parse(OrtVersion).release < version.parse("1.22.0").release:
268-
provider_options["backend_path"] = "libQnnHtp.so" if platform.system() == "Linux" else "QnnHtp.dll"
269-
if share_ep_contexts:
270-
provider_options["enable_htp_weight_sharing"] = "1"
276+
provider_options["backend_path"] = "libQnnHtp.so" if platform.system() == "Linux" else "QnnHtp.dll"
277+
if share_ep_contexts:
278+
provider_options["enable_htp_weight_sharing"] = "1"
271279

272280
# prepare session options
273281
session_options = session_options or {}
@@ -299,9 +307,40 @@ def _generate_context_binary(
299307
# create the inference session
300308
# requires regular onnxruntime package, not winml (not tested with winml)
301309
logger.debug("Creating context binary for model %s", str(model_path))
302-
ort.InferenceSession(
303-
model_path, sess_options=sess_options, providers=[execution_provider], provider_options=[provider_options]
304-
)
310+
311+
if is_abi:
312+
try:
313+
import onnxruntime_qnn as qnn_ep
314+
315+
ep_lib_path = qnn_ep.get_library_path()
316+
ep_registration_name = "QNNExecutionProvider"
317+
ort.register_execution_provider_library(ep_registration_name, ep_lib_path)
318+
except Exception as e:
319+
if "already registered" in str(e):
320+
logger.debug(
321+
"Execution provider %s is already registered, skipping registration.", ep_registration_name
322+
)
323+
else:
324+
raise
325+
all_ep_devices = ort.get_ep_devices()
326+
selected_ep_devices = [
327+
ep_device for ep_device in all_ep_devices if ep_device.ep_name == ExecutionProvider.QNNExecutionProvider
328+
]
329+
330+
# Add QNN EP to session for abi ep
331+
sess_options.add_provider_for_devices(selected_ep_devices, provider_options)
332+
ort.InferenceSession(
333+
model_path,
334+
sess_options=sess_options,
335+
)
336+
ort.unregister_execution_provider_library(ep_registration_name)
337+
else:
338+
ort.InferenceSession(
339+
model_path,
340+
sess_options=sess_options,
341+
providers=[execution_provider],
342+
provider_options=[provider_options],
343+
)
305344

306345
assert output_model_path.exists(), f"Context binary not found at {output_model_path}"
307346

olive/systems/utils/available_providers_runner.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,13 @@
55
# NOTE: Only onnxruntime and its dependencies can be imported in this file.
66
import argparse
77
import json
8+
import logging
89
from pathlib import Path
910

1011
import onnxruntime as ort
1112

13+
logger = logging.getLogger(__name__)
14+
1215

1316
def get_args(raw_args):
1417
parser = argparse.ArgumentParser(description="Get available execution providers")
@@ -19,10 +22,22 @@ def get_args(raw_args):
1922

2023
def main(raw_args=None):
2124
args = get_args(raw_args)
22-
25+
available_eps = []
26+
try:
27+
import onnxruntime_qnn as qnn_ep
28+
29+
ep_lib_path = qnn_ep.get_library_path()
30+
ep_registration_name = "QNNExecutionProvider"
31+
ort.register_execution_provider_library(ep_registration_name, ep_lib_path)
32+
33+
# getting available providers for the ABI EP is broken with ort 1.24; hence the hack below
34+
available_eps.append("QNNExecutionProvider")
35+
ort.unregister_execution_provider_library(ep_registration_name)
36+
except Exception as e:
37+
logger.warning("Failed to register QNNExecutionProvider: %s", str(e))
2338
# get available execution providers
2439
# python environment system doesn't use EP registration yet
25-
available_eps = ort.get_available_providers()
40+
available_eps.extend(ort.get_available_providers())
2641

2742
# save to json
2843
with Path(args.output_path).open("w") as f:

0 commit comments

Comments
 (0)