1717#include " device.hpp"
1818#include " platform.hpp"
1919#include " ur/ur.hpp"
20+ #include " ur2offload.hpp"
2021#include " ur_api.h"
2122
22- ur_adapter_handle_t_ Adapter{} ;
23+ ur_adapter_handle_t Adapter = nullptr ;
2324
2425// Initialize liboffload and perform the initial platform and device discovery
2526ur_result_t ur_adapter_handle_t_::init () {
@@ -30,7 +31,7 @@ ur_result_t ur_adapter_handle_t_::init() {
3031 Res = olIterateDevices (
3132 [](ol_device_handle_t D, void *UserData) {
3233 auto *Platforms =
33- reinterpret_cast <decltype (Adapter. Platforms ) *>(UserData);
34+ reinterpret_cast <decltype (Adapter-> Platforms ) *>(UserData);
3435
3536 ol_platform_handle_t Platform;
3637 olGetDeviceInfo (D, OL_DEVICE_INFO_PLATFORM, sizeof (Platform),
@@ -39,7 +40,7 @@ ur_result_t ur_adapter_handle_t_::init() {
3940 olGetPlatformInfo (Platform, OL_PLATFORM_INFO_BACKEND, sizeof (Backend),
4041 &Backend);
4142 if (Backend == OL_PLATFORM_BACKEND_HOST) {
42- Adapter. HostDevice = D;
43+ Adapter-> HostDevice = D;
4344 } else if (Backend != OL_PLATFORM_BACKEND_UNKNOWN) {
4445 auto URPlatform =
4546 std::find_if (Platforms->begin (), Platforms->end (), [&](auto &P) {
@@ -57,37 +58,52 @@ ur_result_t ur_adapter_handle_t_::init() {
5758 }
5859 return false ;
5960 },
60- &Adapter. Platforms );
61+ &Adapter-> Platforms );
6162
62- (void )Res;
63-
64- return UR_RESULT_SUCCESS;
63+ return offloadResultToUR (Res);
6564}
6665
6766UR_APIEXPORT ur_result_t UR_APICALL urAdapterGet (
6867 uint32_t , ur_adapter_handle_t *phAdapters, uint32_t *pNumAdapters) {
68+ std::mutex InitMutex{};
69+
6970 if (phAdapters) {
70- if (++Adapter.RefCount == 1 ) {
71- Adapter.init ();
71+ std::lock_guard Guard{InitMutex};
72+
73+ // We explicitly only initialize the adapter when outputting it
74+ if (!Adapter) {
75+ Adapter = new ur_adapter_handle_t_{};
76+ auto Res = Adapter->init ();
77+ if (Res) {
78+ delete Adapter;
79+ Adapter = nullptr ;
80+ return Res;
81+ }
7282 }
73- *phAdapters = &Adapter;
83+ Adapter->RefCount ++;
84+ *phAdapters = Adapter;
7485 }
86+
7587 if (pNumAdapters) {
7688 *pNumAdapters = 1 ;
7789 }
90+
7891 return UR_RESULT_SUCCESS;
7992}
8093
8194UR_APIEXPORT ur_result_t UR_APICALL urAdapterRelease (ur_adapter_handle_t ) {
82- if (--Adapter.RefCount == 0 ) {
95+ // Doesn't need protecting by a lock - There is no way to reinitialize the
96+ // adapter after the final reference is released
97+ if (--Adapter->RefCount == 0 ) {
8398 // This can crash when tracing is enabled.
8499 // olShutDown();
100+ delete Adapter;
85101 };
86102 return UR_RESULT_SUCCESS;
87103}
88104
89105UR_APIEXPORT ur_result_t UR_APICALL urAdapterRetain (ur_adapter_handle_t ) {
90- Adapter. RefCount ++;
106+ Adapter-> RefCount ++;
91107 return UR_RESULT_SUCCESS;
92108}
93109
@@ -102,7 +118,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterGetInfo(ur_adapter_handle_t,
102118 case UR_ADAPTER_INFO_BACKEND:
103119 return ReturnValue (UR_BACKEND_OFFLOAD);
104120 case UR_ADAPTER_INFO_REFERENCE_COUNT:
105- return ReturnValue (Adapter. RefCount .load ());
121+ return ReturnValue (Adapter-> RefCount .load ());
106122 case UR_ADAPTER_INFO_VERSION:
107123 return ReturnValue (1 );
108124 default :
@@ -124,15 +140,15 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterSetLoggerCallback(
124140 ur_adapter_handle_t , ur_logger_callback_t pfnLoggerCallback,
125141 void *pUserData, ur_logger_level_t level = UR_LOGGER_LEVEL_QUIET) {
126142
127- Adapter. Logger .setCallbackSink (pfnLoggerCallback, pUserData, level);
143+ Adapter-> Logger .setCallbackSink (pfnLoggerCallback, pUserData, level);
128144
129145 return UR_RESULT_SUCCESS;
130146}
131147
132148UR_APIEXPORT ur_result_t UR_APICALL
133149urAdapterSetLoggerCallbackLevel (ur_adapter_handle_t , ur_logger_level_t level) {
134150
135- Adapter. Logger .setCallbackLevel (level);
151+ Adapter-> Logger .setCallbackLevel (level);
136152
137153 return UR_RESULT_SUCCESS;
138154}
0 commit comments