@@ -76,16 +76,19 @@ def _run_for_config(
7676 ) -> Union [ONNXModelHandler , CompositeModelHandler ]:
7777 # session created using providers argument so will use the ort.get_available_providers()
7878 # TODO(jambayk): consider switching to the new EP API for Windows
79+ from onnxruntime import __version__ as OrtVersion
7980 from onnxruntime import get_available_providers
8081
8182 # TODO(jambayk): validate and support other NPU EPs
8283 assert self .accelerator_spec .execution_provider == ExecutionProvider .QNNExecutionProvider , (
8384 "Only QNNExecutionProvider is supported for now."
8485 )
85- assert self .accelerator_spec .execution_provider in get_available_providers (), (
86- f"Execution provider { self .accelerator_spec .execution_provider } is not available. Available providers:"
87- f" { get_available_providers ()} "
88- )
86+
87+ if version .parse (OrtVersion ).release <= version .parse ("1.23.2" ).release :
88+ assert self .accelerator_spec .execution_provider in get_available_providers (), (
89+ f"Execution provider { self .accelerator_spec .execution_provider } is not available. Available providers:"
90+ f" { get_available_providers ()} "
91+ )
8992
9093 result = self ._run_single_target (model , config , output_model_path )
9194
@@ -257,17 +260,22 @@ def _generate_context_binary(
257260 import onnxruntime as ort
258261 from onnxruntime import __version__ as OrtVersion
259262
263+ is_abi = (
264+ "QNNExecutionProvider" not in ort .get_available_providers ()
265+ or version .parse (OrtVersion ).release >= version .parse ("1.25.0" ).release
266+ )
267+ logger .debug (" Using ABI EP: %s" , str (is_abi ))
268+
260269 # prepare provider options
261270 provider_options = provider_options or {}
262271 if execution_provider == ExecutionProvider .QNNExecutionProvider :
263272 if str (device ).lower () == "gpu" :
264273 provider_options ["backend_path" ] = "libQnnGpu.so" if platform .system () == "Linux" else "QnnGpu.dll"
265274 update_llm_pipeline_genai_config_gpu_ctxbin (model_path )
266275 else :
267- if version .parse (OrtVersion ).release < version .parse ("1.22.0" ).release :
268- provider_options ["backend_path" ] = "libQnnHtp.so" if platform .system () == "Linux" else "QnnHtp.dll"
269- if share_ep_contexts :
270- provider_options ["enable_htp_weight_sharing" ] = "1"
276+ provider_options ["backend_path" ] = "libQnnHtp.so" if platform .system () == "Linux" else "QnnHtp.dll"
277+ if share_ep_contexts :
278+ provider_options ["enable_htp_weight_sharing" ] = "1"
271279
272280 # prepare session options
273281 session_options = session_options or {}
@@ -299,9 +307,40 @@ def _generate_context_binary(
299307 # create the inference session
300308 # requires regular onnxruntime package, not winml (not tested with winml)
301309 logger .debug ("Creating context binary for model %s" , str (model_path ))
302- ort .InferenceSession (
303- model_path , sess_options = sess_options , providers = [execution_provider ], provider_options = [provider_options ]
304- )
310+
311+ if is_abi :
312+ try :
313+ import onnxruntime_qnn as qnn_ep
314+
315+ ep_lib_path = qnn_ep .get_library_path ()
316+ ep_registration_name = "QNNExecutionProvider"
317+ ort .register_execution_provider_library (ep_registration_name , ep_lib_path )
318+ except Exception as e :
319+ if "already registered" in str (e ):
320+ logger .debug (
321+ "Execution provider %s is already registered, skipping registration." , ep_registration_name
322+ )
323+ else :
324+ raise
325+ all_ep_devices = ort .get_ep_devices ()
326+ selected_ep_devices = [
327+ ep_device for ep_device in all_ep_devices if ep_device .ep_name == ExecutionProvider .QNNExecutionProvider
328+ ]
329+
330+ # Add QNN EP to session for abi ep
331+ sess_options .add_provider_for_devices (selected_ep_devices , provider_options )
332+ ort .InferenceSession (
333+ model_path ,
334+ sess_options = sess_options ,
335+ )
336+ ort .unregister_execution_provider_library (ep_registration_name )
337+ else :
338+ ort .InferenceSession (
339+ model_path ,
340+ sess_options = sess_options ,
341+ providers = [execution_provider ],
342+ provider_options = [provider_options ],
343+ )
305344
306345 assert output_model_path .exists (), f"Context binary not found at { output_model_path } "
307346
0 commit comments