@@ -86,82 +86,16 @@ void Context::closeExternalHandle(uint64_t handle) {
8686
8787std::pair<NEO::GraphicsAllocation *, void *> Context::getMemHandlePtr (ze_device_handle_t hDevice, uint64_t handle, NEO::AllocationType allocationType, bool isHostIpcAllocation, unsigned int processId, ze_ipc_memory_flags_t flags, uint64_t cacheID, void *reservedHandleData, bool compressedMemory, bool isOpaqueHandle) {
8888 auto neoDevice = Device::fromHandle (hDevice)->getNEODevice ();
89- uint64_t importHandle = handle;
9089 uint64_t effectiveCacheID = cacheID;
91- bool pidfdSuccess = false ;
92- bool socketFallbackSuccess = false ;
93-
94- if (isOpaqueHandle && settings.useOpaqueHandle && processId != 0 ) {
95- // Check cache first for opaque handles
96- if (this ->driverHandle ->tryGetCachedImportHandle (cacheID, importHandle)) {
97- pidfdSuccess = true ; // Mark as successful to skip import logic
98- }
99-
100- if (!pidfdSuccess && reservedHandleData) {
101- int importHandleFromReserved = -1 ;
102- importHandleFromReserved = this ->driverHandle ->getMemoryManager ()->getImportHandleFromReservedHandleData (reservedHandleData, neoDevice->getRootDeviceIndex ());
103- if (importHandleFromReserved != -1 ) {
104- importHandle = static_cast <uint64_t >(importHandleFromReserved);
105- pidfdSuccess = true ;
106- }
107- }
108-
109- // Try pidfd approach first (unless forced to use socket fallback or already cached)
110- if (!pidfdSuccess && (settings.useOpaqueHandle & OpaqueHandlingType::pidfd) && !NEO::debugManager.flags .ForceIpcSocketFallback .get ()) {
111- pid_t exporterPid = static_cast <pid_t >(processId);
112- unsigned int pidfdFlags = 0u ;
113- int pidfd = NEO::SysCalls::pidfdopen (exporterPid, pidfdFlags);
114- if (pidfd == -1 ) {
115- PRINT_STRING (NEO::debugManager.flags .PrintDebugMessages .get (), stderr, " pidfd_open Syscall failed: %s\n " , strerror (errno));
116- } else {
117- unsigned int getfdFlags = 0u ;
118- int newfd = NEO::SysCalls::pidfdgetfd (pidfd, static_cast <int >(handle), getfdFlags);
119- NEO::SysCalls::close (pidfd);
120- if (newfd < 0 ) {
121- PRINT_STRING (NEO::debugManager.flags .PrintDebugMessages .get (), stderr, " pidfd_getfd Syscall failed: %s\n " , strerror (errno));
122- } else {
123- importHandle = static_cast <uint64_t >(newfd);
124- pidfdSuccess = true ;
125- // Cache the imported handle for future use
126- this ->driverHandle ->setCachedImportHandle (cacheID, importHandle);
127- PRINT_STRING (NEO::debugManager.flags .PrintDebugMessages .get (), stderr,
128- " Cached import handle %lu for cache ID %lu\n " ,
129- importHandle, cacheID);
130- }
131- }
132- }
133-
134- // Try socket fallback if pidfd failed and socket fallback is enabled
135- if (!pidfdSuccess && (settings.useOpaqueHandle == OpaqueHandlingType::sockets)) {
136- pid_t exporterPid = static_cast <pid_t >(processId);
137- std::string socketPath = " neo_ipc_" + std::to_string (exporterPid);
138- NEO::IpcSocketClient socketClient;
139-
140- if (socketClient.connectToServer (socketPath)) {
141- int receivedFd = socketClient.requestHandle (handle);
142- if (receivedFd != -1 ) {
143- importHandle = static_cast <uint64_t >(receivedFd);
144- socketFallbackSuccess = true ;
145- // Cache the imported handle for future use
146- this ->driverHandle ->setCachedImportHandle (cacheID, importHandle);
147- PRINT_STRING (NEO::debugManager.flags .PrintDebugMessages .get (), stderr,
148- " IPC socket fallback successful for handle %lu, cached as %lu\n " ,
149- handle, importHandle);
150- } else {
151- PRINT_STRING (NEO::debugManager.flags .PrintDebugMessages .get (), stderr,
152- " IPC socket fallback failed for handle %lu\n " , handle);
153- }
154- } else {
155- PRINT_STRING (NEO::debugManager.flags .PrintDebugMessages .get (), stderr,
156- " Failed to connect to IPC socket server at %s\n " , socketPath.c_str ());
157- }
90+ uint64_t importHandle = handle;
15891
159- if (!socketFallbackSuccess ) {
160- PRINT_STRING (NEO::debugManager. flags . PrintDebugMessages . get (), stderr,
161- " Socket fallback failed for handle %lu, returning nullptr \n " , handle );
162- return { nullptr , nullptr };
163- }
92+ if (isOpaqueHandle && settings. useOpaqueHandle ) {
93+ // Use helper to import opaque handle with fallback
94+ auto importResult = importOpaqueHandleWithFallback ( handle, processId, cacheID, reservedHandleData, neoDevice );
95+ if (!importResult. success ) {
96+ return { nullptr , nullptr };
16497 }
98+ importHandle = importResult.importHandle ;
16599 }
166100
167101 NEO::GraphicsAllocation *alloc = nullptr ;
@@ -246,4 +180,86 @@ ze_result_t Context::systemBarrier(ze_device_handle_t hDevice) {
246180 NEO::SysCalls::munmap (ptr, MemoryConstants::pageSize);
247181 return ZE_RESULT_SUCCESS;
248182}
183+
184+ Context::OpaqueHandleImportResult Context::importOpaqueHandleWithFallback (uint64_t handle,
185+ unsigned int processId,
186+ uint64_t cacheID,
187+ void *reservedHandleData,
188+ NEO::Device *neoDevice) {
189+ uint64_t importHandle = handle;
190+ bool handleRetrieved = false ;
191+ bool socketFallbackSuccess = false ;
192+
193+ // Check cache first for opaque handles
194+ if (this ->driverHandle ->tryGetCachedImportHandle (cacheID, importHandle)) {
195+ handleRetrieved = true ; // Mark as successful to skip import logic
196+ }
197+
198+ if (!handleRetrieved && reservedHandleData) {
199+ int importHandleFromReserved = -1 ;
200+ importHandleFromReserved = this ->driverHandle ->getMemoryManager ()->getImportHandleFromReservedHandleData (reservedHandleData, neoDevice->getRootDeviceIndex ());
201+ if (importHandleFromReserved != -1 ) {
202+ importHandle = static_cast <uint64_t >(importHandleFromReserved);
203+ handleRetrieved = true ;
204+ }
205+ }
206+
207+ // Try pidfd approach first (unless forced to use socket fallback or already cached)
208+ if (!handleRetrieved && (settings.useOpaqueHandle & OpaqueHandlingType::pidfd) && !NEO::debugManager.flags .ForceIpcSocketFallback .get ()) {
209+ pid_t exporterPid = static_cast <pid_t >(processId);
210+ unsigned int pidfdFlags = 0u ;
211+ int pidfd = NEO::SysCalls::pidfdopen (exporterPid, pidfdFlags);
212+ if (pidfd == -1 ) {
213+ PRINT_STRING (NEO::debugManager.flags .PrintDebugMessages .get (), stderr, " pidfd_open Syscall failed: %s\n " , strerror (errno));
214+ } else {
215+ unsigned int getfdFlags = 0u ;
216+ int newfd = NEO::SysCalls::pidfdgetfd (pidfd, static_cast <int >(handle), getfdFlags);
217+ NEO::SysCalls::close (pidfd);
218+ if (newfd < 0 ) {
219+ PRINT_STRING (NEO::debugManager.flags .PrintDebugMessages .get (), stderr, " pidfd_getfd Syscall failed: %s\n " , strerror (errno));
220+ return {0 , false };
221+ } else {
222+ importHandle = static_cast <uint64_t >(newfd);
223+ handleRetrieved = true ;
224+ // Cache the imported handle for future use
225+ this ->driverHandle ->setCachedImportHandle (cacheID, importHandle);
226+ }
227+ }
228+ }
229+
230+ // Try socket fallback if pidfd failed and socket fallback is enabled
231+ if (!handleRetrieved && (settings.useOpaqueHandle & OpaqueHandlingType::sockets)) {
232+ pid_t exporterPid = static_cast <pid_t >(processId);
233+ std::string socketPath = " neo_ipc_" + std::to_string (exporterPid);
234+
235+ NEO::IpcSocketClient socketClient;
236+ if (socketClient.connectToServer (socketPath)) {
237+ int receivedFd = socketClient.requestHandle (handle);
238+ if (receivedFd != -1 ) {
239+ importHandle = static_cast <uint64_t >(receivedFd);
240+ socketFallbackSuccess = true ;
241+ // Cache the imported handle for future use
242+ this ->driverHandle ->setCachedImportHandle (cacheID, importHandle);
243+ PRINT_STRING (NEO::debugManager.flags .PrintDebugMessages .get (), stderr,
244+ " IPC socket fallback successful for handle %lu, cached as %lu\n " ,
245+ handle, importHandle);
246+ } else {
247+ PRINT_STRING (NEO::debugManager.flags .PrintDebugMessages .get (), stderr,
248+ " IPC socket fallback failed for handle %lu\n " , handle);
249+ }
250+ } else {
251+ PRINT_STRING (NEO::debugManager.flags .PrintDebugMessages .get (), stderr,
252+ " Failed to connect to IPC socket server at %s\n " , socketPath.c_str ());
253+ }
254+
255+ // If socket fallback was attempted but failed, return failure
256+ if (!socketFallbackSuccess) {
257+ PRINT_STRING (NEO::debugManager.flags .PrintDebugMessages .get (), stderr,
258+ " Socket fallback failed for handle %lu, returning failure\n " , handle);
259+ return {0 , false };
260+ }
261+ }
262+
263+ return {importHandle, true };
264+ }
249265} // namespace L0
0 commit comments