Skip to content

Commit 57f6780

Browse files
committed
fix: Avoid invalid casting into loader objects when DDI extension is supported
Related-To: NEO-15615 Signed-off-by: Vishnu Khanth <vishnu.khanth.b@intel.com>
1 parent 1357333 commit 57f6780

1 file changed

Lines changed: 120 additions & 34 deletions

File tree

source/loader/ze_loader_api.cpp

Lines changed: 120 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -101,84 +101,170 @@ zelLoaderTranslateHandleInternal(
101101

102102
*handleOut = handleIn;
103103
switch(handleType){
104-
case ZEL_HANDLE_DRIVER:
105-
if (loader::context->ze_driver_factory.hasInstance(reinterpret_cast<loader::ze_driver_object_t*>(handleIn)->handle)) {
106-
*handleOut = reinterpret_cast<loader::ze_driver_object_t*>( handleIn )->handle;
104+
case ZEL_HANDLE_DRIVER:
105+
{
106+
bool legacy_ldr_intercept_enabled = (!loader::context->driverDDIPathDefault) ||
107+
(reinterpret_cast<ze_handle_t *>(handleIn)->pCore->Driver->pfnGet ==
108+
loader::loaderDispatch->pCore->Driver->pfnGet);
109+
if (legacy_ldr_intercept_enabled && loader::context->ze_driver_factory.hasInstance(reinterpret_cast<loader::ze_driver_object_t *>(handleIn)->handle))
110+
{
111+
*handleOut = reinterpret_cast<loader::ze_driver_object_t *>(handleIn)->handle;
107112
}
108113
break;
109-
case ZEL_HANDLE_DEVICE:
110-
if (loader::context->ze_device_factory.hasInstance(reinterpret_cast<loader::ze_device_object_t*>(handleIn)->handle)){
111-
*handleOut = reinterpret_cast<loader::ze_device_object_t*>( handleIn )->handle;
114+
}
115+
case ZEL_HANDLE_DEVICE:
116+
{
117+
bool legacy_ldr_intercept_enabled = (!loader::context->driverDDIPathDefault) ||
118+
(reinterpret_cast<ze_handle_t *>(handleIn)->pCore->Device->pfnGet ==
119+
loader::loaderDispatch->pCore->Device->pfnGet);
120+
if (legacy_ldr_intercept_enabled && loader::context->ze_device_factory.hasInstance(reinterpret_cast<loader::ze_device_object_t *>(handleIn)->handle))
121+
{
122+
*handleOut = reinterpret_cast<loader::ze_device_object_t *>(handleIn)->handle;
112123
}
113124
break;
114-
case ZEL_HANDLE_CONTEXT:
115-
if (loader::context->ze_context_factory.hasInstance(reinterpret_cast<loader::ze_context_object_t*>(handleIn)->handle)) {
116-
*handleOut = reinterpret_cast<loader::ze_context_object_t*>( handleIn )->handle;
125+
}
126+
case ZEL_HANDLE_CONTEXT:
127+
{
128+
bool legacy_ldr_intercept_enabled = (!loader::context->driverDDIPathDefault) ||
129+
(reinterpret_cast<ze_handle_t *>(handleIn)->pCore->Context->pfnCreate ==
130+
loader::loaderDispatch->pCore->Context->pfnCreate);
131+
if (legacy_ldr_intercept_enabled &&
132+
loader::context->ze_context_factory.hasInstance(reinterpret_cast<loader::ze_context_object_t *>(handleIn)->handle))
133+
{
134+
*handleOut = reinterpret_cast<loader::ze_context_object_t *>(handleIn)->handle;
117135
}
118136
break;
137+
}
119138
case ZEL_HANDLE_COMMAND_QUEUE:
120-
if (loader::context->ze_command_queue_factory.hasInstance(reinterpret_cast<loader::ze_command_queue_object_t*>(handleIn)->handle)) {
121-
*handleOut = reinterpret_cast<loader::ze_command_queue_object_t*>( handleIn )->handle;
139+
{
140+
bool legacy_ldr_intercept_enabled = (!loader::context->driverDDIPathDefault) ||
141+
(reinterpret_cast<ze_handle_t *>(handleIn)->pCore->CommandQueue->pfnCreate ==
142+
loader::loaderDispatch->pCore->CommandQueue->pfnCreate);
143+
if (legacy_ldr_intercept_enabled && loader::context->ze_command_queue_factory.hasInstance(reinterpret_cast<loader::ze_command_queue_object_t *>(handleIn)->handle))
144+
{
145+
*handleOut = reinterpret_cast<loader::ze_command_queue_object_t *>(handleIn)->handle;
122146
}
123147
break;
148+
}
124149
case ZEL_HANDLE_COMMAND_LIST:
125-
if (loader::context->ze_command_list_factory.hasInstance(reinterpret_cast<loader::ze_command_list_object_t*>(handleIn)->handle)) {
126-
*handleOut = reinterpret_cast<loader::ze_command_list_object_t*>( handleIn )->handle;
150+
{
151+
bool legacy_ldr_intercept_enabled = (!loader::context->driverDDIPathDefault) ||
152+
(reinterpret_cast<ze_handle_t *>(handleIn)->pCore->CommandList->pfnCreate ==
153+
loader::loaderDispatch->pCore->CommandList->pfnCreate);
154+
if (legacy_ldr_intercept_enabled && loader::context->ze_command_list_factory.hasInstance(reinterpret_cast<loader::ze_command_list_object_t *>(handleIn)->handle))
155+
{
156+
*handleOut = reinterpret_cast<loader::ze_command_list_object_t *>(handleIn)->handle;
127157
}
128158
break;
159+
}
129160
case ZEL_HANDLE_FENCE:
130-
if (loader::context->ze_fence_factory.hasInstance(reinterpret_cast<loader::ze_fence_object_t*>(handleIn)->handle)) {
131-
*handleOut = reinterpret_cast<loader::ze_fence_object_t*>( handleIn )->handle;
161+
{
162+
bool legacy_ldr_intercept_enabled = (!loader::context->driverDDIPathDefault) ||
163+
(reinterpret_cast<ze_handle_t *>(handleIn)->pCore->Fence->pfnCreate ==
164+
loader::loaderDispatch->pCore->Fence->pfnCreate);
165+
if (legacy_ldr_intercept_enabled && loader::context->ze_fence_factory.hasInstance(reinterpret_cast<loader::ze_fence_object_t *>(handleIn)->handle))
166+
{
167+
*handleOut = reinterpret_cast<loader::ze_fence_object_t *>(handleIn)->handle;
132168
}
133169
break;
170+
}
134171
case ZEL_HANDLE_EVENT_POOL:
135-
if (loader::context->ze_event_pool_factory.hasInstance(reinterpret_cast<loader::ze_event_pool_object_t*>(handleIn)->handle)) {
136-
*handleOut = reinterpret_cast<loader::ze_event_pool_object_t*>( handleIn )->handle;
172+
{
173+
bool legacy_ldr_intercept_enabled = (!loader::context->driverDDIPathDefault) ||
174+
(reinterpret_cast<ze_handle_t *>(handleIn)->pCore->EventPool->pfnCreate ==
175+
loader::loaderDispatch->pCore->EventPool->pfnCreate);
176+
if (legacy_ldr_intercept_enabled && loader::context->ze_event_pool_factory.hasInstance(reinterpret_cast<loader::ze_event_pool_object_t *>(handleIn)->handle))
177+
{
178+
*handleOut = reinterpret_cast<loader::ze_event_pool_object_t *>(handleIn)->handle;
137179
}
138180
break;
181+
}
139182
case ZEL_HANDLE_EVENT:
140-
if (loader::context->ze_event_factory.hasInstance(reinterpret_cast<loader::ze_event_object_t*>(handleIn)->handle)) {
141-
*handleOut = reinterpret_cast<loader::ze_event_object_t*>( handleIn )->handle;
183+
{
184+
bool legacy_ldr_intercept_enabled = (!loader::context->driverDDIPathDefault) ||
185+
(reinterpret_cast<ze_handle_t *>(handleIn)->pCore->Event->pfnCreate ==
186+
loader::loaderDispatch->pCore->Event->pfnCreate);
187+
if (legacy_ldr_intercept_enabled && loader::context->ze_event_factory.hasInstance(reinterpret_cast<loader::ze_event_object_t *>(handleIn)->handle))
188+
{
189+
*handleOut = reinterpret_cast<loader::ze_event_object_t *>(handleIn)->handle;
142190
}
143191
break;
192+
}
144193
case ZEL_HANDLE_IMAGE:
145-
if (loader::context->ze_image_factory.hasInstance(reinterpret_cast<loader::ze_image_object_t*>(handleIn)->handle)) {
146-
*handleOut = reinterpret_cast<loader::ze_image_object_t*>( handleIn )->handle;
194+
{
195+
bool legacy_ldr_intercept_enabled = (!loader::context->driverDDIPathDefault) ||
196+
(reinterpret_cast<ze_handle_t *>(handleIn)->pCore->Image->pfnCreate ==
197+
loader::loaderDispatch->pCore->Image->pfnCreate);
198+
if (legacy_ldr_intercept_enabled && loader::context->ze_image_factory.hasInstance(reinterpret_cast<loader::ze_image_object_t *>(handleIn)->handle))
199+
{
200+
*handleOut = reinterpret_cast<loader::ze_image_object_t *>(handleIn)->handle;
147201
}
148202
break;
203+
}
149204
case ZEL_HANDLE_MODULE:
150-
if (loader::context->ze_module_factory.hasInstance(reinterpret_cast<loader::ze_module_object_t*>(handleIn)->handle)) {
151-
*handleOut = reinterpret_cast<loader::ze_module_object_t*>( handleIn )->handle;
205+
{
206+
bool legacy_ldr_intercept_enabled = (!loader::context->driverDDIPathDefault) ||
207+
(reinterpret_cast<ze_handle_t *>(handleIn)->pCore->Module->pfnCreate ==
208+
loader::loaderDispatch->pCore->Module->pfnCreate);
209+
if (legacy_ldr_intercept_enabled && loader::context->ze_module_factory.hasInstance(reinterpret_cast<loader::ze_module_object_t *>(handleIn)->handle))
210+
{
211+
*handleOut = reinterpret_cast<loader::ze_module_object_t *>(handleIn)->handle;
152212
}
153213
break;
214+
}
154215
case ZEL_HANDLE_MODULE_BUILD_LOG:
155-
if (loader::context->ze_module_build_log_factory.hasInstance(reinterpret_cast<loader::ze_module_build_log_object_t*>(handleIn)->handle)) {
156-
*handleOut = reinterpret_cast<loader::ze_module_build_log_object_t*>( handleIn )->handle;
216+
{
217+
bool legacy_ldr_intercept_enabled = (!loader::context->driverDDIPathDefault) ||
218+
(reinterpret_cast<ze_handle_t *>(handleIn)->pCore->ModuleBuildLog->pfnGetString ==
219+
loader::loaderDispatch->pCore->ModuleBuildLog->pfnGetString);
220+
if (legacy_ldr_intercept_enabled && loader::context->ze_module_build_log_factory.hasInstance(reinterpret_cast<loader::ze_module_build_log_object_t *>(handleIn)->handle))
221+
{
222+
*handleOut = reinterpret_cast<loader::ze_module_build_log_object_t *>(handleIn)->handle;
157223
}
158224
break;
225+
}
159226
case ZEL_HANDLE_KERNEL:
160-
if (loader::context->ze_kernel_factory.hasInstance(reinterpret_cast<loader::ze_kernel_object_t*>(handleIn)->handle)) {
161-
*handleOut = reinterpret_cast<loader::ze_kernel_object_t*>( handleIn )->handle;
227+
{
228+
bool legacy_ldr_intercept_enabled = (!loader::context->driverDDIPathDefault) ||
229+
(reinterpret_cast<ze_handle_t *>(handleIn)->pCore->Kernel->pfnCreate ==
230+
loader::loaderDispatch->pCore->Kernel->pfnCreate);
231+
if (legacy_ldr_intercept_enabled && loader::context->ze_kernel_factory.hasInstance(reinterpret_cast<loader::ze_kernel_object_t *>(handleIn)->handle))
232+
{
233+
*handleOut = reinterpret_cast<loader::ze_kernel_object_t *>(handleIn)->handle;
162234
}
163235
break;
236+
}
164237
case ZEL_HANDLE_SAMPLER:
165-
if (loader::context->ze_sampler_factory.hasInstance(reinterpret_cast<loader::ze_sampler_object_t*>(handleIn)->handle)) {
166-
*handleOut = reinterpret_cast<loader::ze_sampler_object_t*>( handleIn )->handle;
238+
{
239+
bool legacy_ldr_intercept_enabled = (!loader::context->driverDDIPathDefault) ||
240+
(reinterpret_cast<ze_handle_t *>(handleIn)->pCore->Sampler->pfnCreate ==
241+
loader::loaderDispatch->pCore->Sampler->pfnCreate);
242+
if (legacy_ldr_intercept_enabled && loader::context->ze_sampler_factory.hasInstance(reinterpret_cast<loader::ze_sampler_object_t *>(handleIn)->handle))
243+
{
244+
*handleOut = reinterpret_cast<loader::ze_sampler_object_t *>(handleIn)->handle;
167245
}
168246
break;
247+
}
169248
case ZEL_HANDLE_PHYSICAL_MEM:
170-
if (loader::context->ze_physical_mem_factory.hasInstance(reinterpret_cast<loader::ze_physical_mem_object_t*>(handleIn)->handle)) {
171-
*handleOut = reinterpret_cast<loader::ze_physical_mem_object_t*>( handleIn )->handle;
249+
{
250+
bool legacy_ldr_intercept_enabled = (!loader::context->driverDDIPathDefault) ||
251+
(reinterpret_cast<ze_handle_t *>(handleIn)->pCore->PhysicalMem->pfnCreate ==
252+
loader::loaderDispatch->pCore->PhysicalMem->pfnCreate);
253+
if (legacy_ldr_intercept_enabled && loader::context->ze_physical_mem_factory.hasInstance(reinterpret_cast<loader::ze_physical_mem_object_t *>(handleIn)->handle))
254+
{
255+
*handleOut = reinterpret_cast<loader::ze_physical_mem_object_t *>(handleIn)->handle;
172256
}
173257
break;
258+
}
174259
default:
175260
return ZE_RESULT_ERROR_INVALID_ENUMERATION;
176-
}
177-
261+
}
262+
263+
178264
return ZE_RESULT_SUCCESS;
179265
}
180266

181267

182268
#if defined(__cplusplus)
183269
}
184-
#endif
270+
#endif

0 commit comments

Comments
 (0)