@@ -801,7 +801,8 @@ static float _half_to_float(uint16_t h)
801801static gboolean _try_provider (OrtSessionOptions * session_opts ,
802802 const char * symbol_name ,
803803 const char * provider_name ,
804- const char * device_type )
804+ const char * device_type ,
805+ uint32_t flags )
805806{
806807 OrtStatus * status = NULL ;
807808 gboolean ok = FALSE;
@@ -851,7 +852,7 @@ static gboolean _try_provider(OrtSessionOptions *session_opts,
851852 // integer-argument providers (CUDA, CoreML, DML, MIGraphX, ROCm)
852853 typedef OrtStatus * (* ProviderAppenderInt )(OrtSessionOptions * , uint32_t );
853854 ProviderAppenderInt appender = (ProviderAppenderInt )func_ptr ;
854- status = appender (session_opts , 0 );
855+ status = appender (session_opts , flags );
855856 }
856857 if (!status )
857858 {
@@ -880,7 +881,9 @@ static gboolean _try_provider(OrtSessionOptions *session_opts,
880881}
881882
882883static void
883- _enable_acceleration (OrtSessionOptions * session_opts , dt_ai_provider_t provider )
884+ _enable_acceleration (OrtSessionOptions * session_opts ,
885+ dt_ai_provider_t provider ,
886+ uint32_t coreml_flags )
884887{
885888 switch (provider )
886889 {
@@ -894,36 +897,36 @@ _enable_acceleration(OrtSessionOptions *session_opts, dt_ai_provider_t provider)
894897 _try_provider (
895898 session_opts ,
896899 "OrtSessionOptionsAppendExecutionProvider_CoreML" ,
897- "Apple CoreML" , NULL );
900+ "Apple CoreML" , NULL , coreml_flags );
898901#else
899902 dt_print (DT_DEBUG_AI , "[darktable_ai] apple CoreML not available on this platform" );
900903#endif
901904 break ;
902905
903906 case DT_AI_PROVIDER_CUDA :
904- _try_provider (session_opts , "OrtSessionOptionsAppendExecutionProvider_CUDA" , "NVIDIA CUDA" , NULL );
907+ _try_provider (session_opts , "OrtSessionOptionsAppendExecutionProvider_CUDA" , "NVIDIA CUDA" , NULL , 0 );
905908 break ;
906909
907910 case DT_AI_PROVIDER_MIGRAPHX :
908911 // MIGraphX reads its cache env vars once at provider library
909912 // load time, so they must be set before CreateEnv() — see
910913 // _setup_amd_caches() above. OpenVINO (below) takes options
911914 // per-session, so its cache path is passed inline here
912- if (!_try_provider (session_opts , "OrtSessionOptionsAppendExecutionProvider_MIGraphX" , "AMD MIGraphX" , NULL ))
913- _try_provider (session_opts , "OrtSessionOptionsAppendExecutionProvider_ROCM" , "AMD ROCm (legacy)" , NULL );
915+ if (!_try_provider (session_opts , "OrtSessionOptionsAppendExecutionProvider_MIGraphX" , "AMD MIGraphX" , NULL , 0 ))
916+ _try_provider (session_opts , "OrtSessionOptionsAppendExecutionProvider_ROCM" , "AMD ROCm (legacy)" , NULL , 0 );
914917 break ;
915918
916919 case DT_AI_PROVIDER_OPENVINO :
917920 if (!_try_openvino_with_cache (session_opts ))
918- _try_provider (session_opts , "OrtSessionOptionsAppendExecutionProvider_OpenVINO" , "Intel OpenVINO" , "AUTO" );
921+ _try_provider (session_opts , "OrtSessionOptionsAppendExecutionProvider_OpenVINO" , "Intel OpenVINO" , "AUTO" , 0 );
919922 break ;
920923
921924 case DT_AI_PROVIDER_DIRECTML :
922925#if defined(_WIN32 )
923926 _try_provider (
924927 session_opts ,
925928 "OrtSessionOptionsAppendExecutionProvider_DML" ,
926- "Windows DirectML" , NULL );
929+ "Windows DirectML" , NULL , 0 );
927930#else
928931 dt_print (DT_DEBUG_AI , "[darktable_ai] windows DirectML not available on this platform" );
929932#endif
@@ -936,27 +939,27 @@ _enable_acceleration(OrtSessionOptions *session_opts, dt_ai_provider_t provider)
936939 _try_provider (
937940 session_opts ,
938941 "OrtSessionOptionsAppendExecutionProvider_CoreML" ,
939- "Apple CoreML" , NULL );
942+ "Apple CoreML" , NULL , coreml_flags );
940943#elif defined(_WIN32 )
941944 _try_provider (
942945 session_opts ,
943946 "OrtSessionOptionsAppendExecutionProvider_DML" ,
944- "Windows DirectML" , NULL );
947+ "Windows DirectML" , NULL , 0 );
945948#elif defined(__linux__ )
946949 // try CUDA first, then MIGraphX (cache configured at env init)
947950 if (!_try_provider (
948951 session_opts ,
949952 "OrtSessionOptionsAppendExecutionProvider_CUDA" ,
950- "NVIDIA CUDA" , NULL ))
953+ "NVIDIA CUDA" , NULL , 0 ))
951954 {
952955 if (!_try_provider (
953956 session_opts ,
954957 "OrtSessionOptionsAppendExecutionProvider_MIGraphX" ,
955- "AMD MIGraphX" , NULL ))
958+ "AMD MIGraphX" , NULL , 0 ))
956959 _try_provider (
957960 session_opts ,
958961 "OrtSessionOptionsAppendExecutionProvider_ROCM" ,
959- "AMD ROCm (legacy)" , NULL );
962+ "AMD ROCm (legacy)" , NULL , 0 );
960963 }
961964#endif
962965 break ;
@@ -996,20 +999,20 @@ int dt_ai_probe_provider(dt_ai_provider_t provider)
996999 switch (provider )
9971000 {
9981001 case DT_AI_PROVIDER_COREML :
999- ok = _try_provider (opts , "OrtSessionOptionsAppendExecutionProvider_CoreML" , "Apple CoreML" , NULL );
1002+ ok = _try_provider (opts , "OrtSessionOptionsAppendExecutionProvider_CoreML" , "Apple CoreML" , NULL , 0 );
10001003 break ;
10011004 case DT_AI_PROVIDER_CUDA :
1002- ok = _try_provider (opts , "OrtSessionOptionsAppendExecutionProvider_CUDA" , "NVIDIA CUDA" , NULL );
1005+ ok = _try_provider (opts , "OrtSessionOptionsAppendExecutionProvider_CUDA" , "NVIDIA CUDA" , NULL , 0 );
10031006 break ;
10041007 case DT_AI_PROVIDER_MIGRAPHX :
1005- ok = _try_provider (opts , "OrtSessionOptionsAppendExecutionProvider_MIGraphX" , "AMD MIGraphX" , NULL )
1006- || _try_provider (opts , "OrtSessionOptionsAppendExecutionProvider_ROCM" , "AMD ROCm (legacy)" , NULL );
1008+ ok = _try_provider (opts , "OrtSessionOptionsAppendExecutionProvider_MIGraphX" , "AMD MIGraphX" , NULL , 0 )
1009+ || _try_provider (opts , "OrtSessionOptionsAppendExecutionProvider_ROCM" , "AMD ROCm (legacy)" , NULL , 0 );
10071010 break ;
10081011 case DT_AI_PROVIDER_OPENVINO :
1009- ok = _try_provider (opts , "OrtSessionOptionsAppendExecutionProvider_OpenVINO" , "Intel OpenVINO" , "AUTO" );
1012+ ok = _try_provider (opts , "OrtSessionOptionsAppendExecutionProvider_OpenVINO" , "Intel OpenVINO" , "AUTO" , 0 );
10101013 break ;
10111014 case DT_AI_PROVIDER_DIRECTML :
1012- ok = _try_provider (opts , "OrtSessionOptionsAppendExecutionProvider_DML" , "Windows DirectML" , NULL );
1015+ ok = _try_provider (opts , "OrtSessionOptionsAppendExecutionProvider_DML" , "Windows DirectML" , NULL , 0 );
10131016 break ;
10141017 default :
10151018 break ;
@@ -1026,7 +1029,8 @@ int dt_ai_probe_provider(dt_ai_provider_t provider)
10261029dt_ai_context_t *
10271030dt_ai_onnx_load_ext (const char * model_dir , const char * model_file ,
10281031 dt_ai_provider_t provider , dt_ai_opt_level_t opt_level ,
1029- const dt_ai_dim_override_t * dim_overrides , int n_overrides )
1032+ const dt_ai_dim_override_t * dim_overrides , int n_overrides ,
1033+ uint32_t ep_flags )
10301034{
10311035 if (!model_dir )
10321036 return NULL ;
@@ -1111,7 +1115,7 @@ dt_ai_onnx_load_ext(const char *model_dir, const char *model_file,
11111115 }
11121116
11131117 // optimize: enable hardware acceleration (AMD caches set at env init)
1114- _enable_acceleration (session_opts , provider );
1118+ _enable_acceleration (session_opts , provider , ep_flags );
11151119
11161120#ifdef _WIN32
11171121 // on windows, CreateSession expects a wide character string
@@ -1176,7 +1180,7 @@ dt_ai_onnx_load_ext(const char *model_dir, const char *model_file,
11761180 if (s ) g_ort -> ReleaseStatus (s );
11771181 }
11781182 if (fallbacks [fb ].prov != DT_AI_PROVIDER_CPU )
1179- _enable_acceleration (session_opts , fallbacks [fb ].prov );
1183+ _enable_acceleration (session_opts , fallbacks [fb ].prov , ep_flags );
11801184#ifdef _WIN32
11811185 status = g_ort -> CreateSession (g_env , onnx_path_wide , session_opts , & ctx -> session );
11821186#else
0 commit comments