@@ -56,263 +56,6 @@ ur_result_t urKernelGetSuggestedLocalWorkSize(
5656 return UR_RESULT_SUCCESS;
5757}
5858
59- inline ur_result_t EnqueueCooperativeKernelLaunchHelper (
60- // / [in] handle of the queue object
61- ur_queue_handle_t Queue,
62- // / [in] handle of the kernel object
63- ur_kernel_handle_t Kernel,
64- // / [in] number of dimensions, from 1 to 3, to specify the global and
65- // / work-group work-items
66- uint32_t WorkDim,
67- // / [in][optional] pointer to an array of workDim unsigned values that
68- // / specify the offset used to calculate the global ID of a work-item
69- const size_t *GlobalWorkOffset,
70- // / [in] pointer to an array of workDim unsigned values that specify the
71- // / number of global work-items in workDim that will execute the kernel
72- // / function
73- const size_t *GlobalWorkSize,
74- // / [in][optional] pointer to an array of workDim unsigned values that
75- // / specify the number of local work-items forming a work-group that
76- // / will execute the kernel function. If nullptr, the runtime
77- // / implementation will choose the work-group size.
78- const size_t *LocalWorkSize,
79- // / [in] size of the event wait list
80- uint32_t NumEventsInWaitList,
81- // / [in][optional][range(0, numEventsInWaitList)] pointer to a list of
82- // / events that must be complete before the kernel execution. If
83- // / nullptr, the numEventsInWaitList must be 0, indicating that no wait
84- // / event.
85- const ur_event_handle_t *EventWaitList,
86- // / [in,out][optional] return an event object that identifies this
87- // / particular kernel execution instance.
88- ur_event_handle_t *OutEvent) {
89- UR_ASSERT (WorkDim > 0 , UR_RESULT_ERROR_INVALID_WORK_DIMENSION);
90- UR_ASSERT (WorkDim < 4 , UR_RESULT_ERROR_INVALID_WORK_DIMENSION);
91-
92- auto ZeDevice = Queue->Device ->ZeDevice ;
93-
94- ze_kernel_handle_t ZeKernel{};
95- if (Kernel->ZeKernelMap .empty ()) {
96- ZeKernel = Kernel->ZeKernel ;
97- } else {
98- auto It = Kernel->ZeKernelMap .find (ZeDevice);
99- if (It == Kernel->ZeKernelMap .end ()) {
100- /* kernel and queue don't match */
101- return UR_RESULT_ERROR_INVALID_QUEUE;
102- }
103- ZeKernel = It->second ;
104- }
105- // Lock automatically releases when this goes out of scope.
106- std::scoped_lock<ur_shared_mutex, ur_shared_mutex, ur_shared_mutex> Lock (
107- Queue->Mutex , Kernel->Mutex , Kernel->Program ->Mutex );
108- if (GlobalWorkOffset != NULL ) {
109- UR_CALL (setKernelGlobalOffset (Queue->Context , ZeKernel, WorkDim,
110- GlobalWorkOffset));
111- }
112-
113- // If there are any pending arguments set them now.
114- for (auto &Arg : Kernel->PendingArguments ) {
115- // The ArgValue may be a NULL pointer in which case a NULL value is used for
116- // the kernel argument declared as a pointer to global or constant memory.
117- char **ZeHandlePtr = nullptr ;
118- if (Arg.Value ) {
119- UR_CALL (Arg.Value ->getZeHandlePtr (ZeHandlePtr, Arg.AccessMode ,
120- Queue->Device , EventWaitList,
121- NumEventsInWaitList));
122- }
123- ZE2UR_CALL (zeKernelSetArgumentValue,
124- (ZeKernel, Arg.Index , Arg.Size , ZeHandlePtr));
125- }
126- Kernel->PendingArguments .clear ();
127-
128- ze_group_count_t ZeThreadGroupDimensions{1 , 1 , 1 };
129- uint32_t WG[3 ]{};
130-
131- // New variable needed because GlobalWorkSize parameter might not be of size 3
132- size_t GlobalWorkSize3D[3 ]{1 , 1 , 1 };
133- std::copy (GlobalWorkSize, GlobalWorkSize + WorkDim, GlobalWorkSize3D);
134-
135- if (LocalWorkSize) {
136- // L0
137- for (uint32_t I = 0 ; I < WorkDim; I++) {
138- UR_ASSERT (LocalWorkSize[I] < (std::numeric_limits<uint32_t >::max)(),
139- UR_RESULT_ERROR_INVALID_VALUE);
140- WG[I] = static_cast <uint32_t >(LocalWorkSize[I]);
141- }
142- } else {
143- // We can't call to zeKernelSuggestGroupSize if 64-bit GlobalWorkSize
144- // values do not fit to 32-bit that the API only supports currently.
145- bool SuggestGroupSize = true ;
146- for (int I : {0 , 1 , 2 }) {
147- if (GlobalWorkSize3D[I] > UINT32_MAX) {
148- SuggestGroupSize = false ;
149- }
150- }
151- if (SuggestGroupSize) {
152- ZE2UR_CALL (zeKernelSuggestGroupSize,
153- (ZeKernel, GlobalWorkSize3D[0 ], GlobalWorkSize3D[1 ],
154- GlobalWorkSize3D[2 ], &WG[0 ], &WG[1 ], &WG[2 ]));
155- } else {
156- for (int I : {0 , 1 , 2 }) {
157- // Try to find a I-dimension WG size that the GlobalWorkSize[I] is
158- // fully divisable with. Start with the max possible size in
159- // each dimension.
160- uint32_t GroupSize[] = {
161- Queue->Device ->ZeDeviceComputeProperties ->maxGroupSizeX ,
162- Queue->Device ->ZeDeviceComputeProperties ->maxGroupSizeY ,
163- Queue->Device ->ZeDeviceComputeProperties ->maxGroupSizeZ };
164- GroupSize[I] = (std::min)(size_t (GroupSize[I]), GlobalWorkSize3D[I]);
165- while (GlobalWorkSize3D[I] % GroupSize[I]) {
166- --GroupSize[I];
167- }
168-
169- if (GlobalWorkSize3D[I] / GroupSize[I] > UINT32_MAX) {
170- UR_LOG (ERR,
171- " urEnqueueCooperativeKernelLaunchExp: can't find a WG size "
172- " suitable for global work size > UINT32_MAX" );
173- return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
174- }
175- WG[I] = GroupSize[I];
176- }
177- UR_LOG (DEBUG,
178- " urEnqueueCooperativeKernelLaunchExp: using computed WG "
179- " size = {{{}, {}, {}}}" ,
180- WG[0 ], WG[1 ], WG[2 ]);
181- }
182- }
183-
184- // TODO: assert if sizes do not fit into 32-bit?
185-
186- switch (WorkDim) {
187- case 3 :
188- ZeThreadGroupDimensions.groupCountX =
189- static_cast <uint32_t >(GlobalWorkSize3D[0 ] / WG[0 ]);
190- ZeThreadGroupDimensions.groupCountY =
191- static_cast <uint32_t >(GlobalWorkSize3D[1 ] / WG[1 ]);
192- ZeThreadGroupDimensions.groupCountZ =
193- static_cast <uint32_t >(GlobalWorkSize3D[2 ] / WG[2 ]);
194- break ;
195- case 2 :
196- ZeThreadGroupDimensions.groupCountX =
197- static_cast <uint32_t >(GlobalWorkSize3D[0 ] / WG[0 ]);
198- ZeThreadGroupDimensions.groupCountY =
199- static_cast <uint32_t >(GlobalWorkSize3D[1 ] / WG[1 ]);
200- WG[2 ] = 1 ;
201- break ;
202- case 1 :
203- ZeThreadGroupDimensions.groupCountX =
204- static_cast <uint32_t >(GlobalWorkSize3D[0 ] / WG[0 ]);
205- WG[1 ] = WG[2 ] = 1 ;
206- break ;
207-
208- default :
209- UR_LOG (ERR, " urEnqueueCooperativeKernelLaunchExp: unsupported work_dim" );
210- return UR_RESULT_ERROR_INVALID_VALUE;
211- }
212-
213- // Error handling for non-uniform group size case
214- if (GlobalWorkSize3D[0 ] !=
215- size_t (ZeThreadGroupDimensions.groupCountX ) * WG[0 ]) {
216- UR_LOG (ERR,
217- " urEnqueueCooperativeKernelLaunchExp: invalid work_dim. The "
218- " range is not a multiple of the group size in the 1st dimension" );
219- return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
220- }
221- if (GlobalWorkSize3D[1 ] !=
222- size_t (ZeThreadGroupDimensions.groupCountY ) * WG[1 ]) {
223- UR_LOG (ERR,
224- " urEnqueueCooperativeKernelLaunchExp: invalid work_dim. The "
225- " range is not a multiple of the group size in the 2nd dimension" );
226- return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
227- }
228- if (GlobalWorkSize3D[2 ] !=
229- size_t (ZeThreadGroupDimensions.groupCountZ ) * WG[2 ]) {
230- UR_LOG (DEBUG,
231- " urEnqueueCooperativeKernelLaunchExp: invalid work_dim. The "
232- " range is not a multiple of the group size in the 3rd dimension" );
233- return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
234- }
235-
236- ZE2UR_CALL (zeKernelSetGroupSize, (ZeKernel, WG[0 ], WG[1 ], WG[2 ]));
237-
238- bool UseCopyEngine = false ;
239- ur_ze_event_list_t TmpWaitList;
240- UR_CALL (TmpWaitList.createAndRetainUrZeEventList (
241- NumEventsInWaitList, EventWaitList, Queue, UseCopyEngine));
242-
243- // Get a new command list to be used on this call
244- ur_command_list_ptr_t CommandList{};
245- UR_CALL (Queue->Context ->getAvailableCommandList (
246- Queue, CommandList, UseCopyEngine, NumEventsInWaitList, EventWaitList,
247- true /* AllowBatching */ , nullptr /* ForcedCmdQueue*/ ));
248-
249- ze_event_handle_t ZeEvent = nullptr ;
250- ur_event_handle_t InternalEvent{};
251- bool IsInternal = OutEvent == nullptr ;
252- ur_event_handle_t *Event = OutEvent ? OutEvent : &InternalEvent;
253-
254- UR_CALL (createEventAndAssociateQueue (Queue, Event, UR_COMMAND_KERNEL_LAUNCH,
255- CommandList, IsInternal, false ));
256- UR_CALL (setSignalEvent (Queue, UseCopyEngine, &ZeEvent, Event,
257- NumEventsInWaitList, EventWaitList,
258- CommandList->second .ZeQueue ));
259- (*Event)->WaitList = TmpWaitList;
260-
261- // Save the kernel in the event, so that when the event is signalled
262- // the code can do a urKernelRelease on this kernel.
263- (*Event)->CommandData = (void *)Kernel;
264-
265- // Increment the reference count of the Kernel and indicate that the Kernel
266- // is in use. Once the event has been signalled, the code in
267- // CleanupCompletedEvent(Event) will do a urKernelRelease to update the
268- // reference count on the kernel, using the kernel saved in CommandData.
269- UR_CALL (ur::level_zero::urKernelRetain (Kernel));
270-
271- // Add to list of kernels to be submitted
272- if (IndirectAccessTrackingEnabled)
273- Queue->KernelsToBeSubmitted .push_back (Kernel);
274-
275- if (Queue->UsingImmCmdLists && IndirectAccessTrackingEnabled) {
276- // If using immediate commandlists then gathering of indirect
277- // references and appending to the queue (which means submission)
278- // must be done together.
279- std::unique_lock<ur_shared_mutex> ContextsLock (
280- Queue->Device ->Platform ->ContextsMutex , std::defer_lock);
281- // We are going to submit kernels for execution. If indirect access flag is
282- // set for a kernel then we need to make a snapshot of existing memory
283- // allocations in all contexts in the platform. We need to lock the mutex
284- // guarding the list of contexts in the platform to prevent creation of new
285- // memory alocations in any context before we submit the kernel for
286- // execution.
287- ContextsLock.lock ();
288- Queue->CaptureIndirectAccesses ();
289- // Add the command to the command list, which implies submission.
290- ZE2UR_CALL (zeCommandListAppendLaunchCooperativeKernel,
291- (CommandList->first , ZeKernel, &ZeThreadGroupDimensions, ZeEvent,
292- (*Event)->WaitList .Length , (*Event)->WaitList .ZeEventList ));
293- } else {
294- // Add the command to the command list for later submission.
295- // No lock is needed here, unlike the immediate commandlist case above,
296- // because the kernels are not actually submitted yet. Kernels will be
297- // submitted only when the comamndlist is closed. Then, a lock is held.
298- ZE2UR_CALL (zeCommandListAppendLaunchCooperativeKernel,
299- (CommandList->first , ZeKernel, &ZeThreadGroupDimensions, ZeEvent,
300- (*Event)->WaitList .Length , (*Event)->WaitList .ZeEventList ));
301- }
302-
303- UR_LOG (DEBUG,
304- " calling zeCommandListAppendLaunchCooperativeKernel() with ZeEvent {}" ,
305- ur_cast<std::uintptr_t >(ZeEvent));
306- printZeEventList ((*Event)->WaitList );
307-
308- // Execute command list asynchronously, as the event will be used
309- // to track down its completion.
310- UR_CALL (Queue->executeCommandList (CommandList, false /* IsBlocking*/ ,
311- true /* OKToBatchCommand*/ ));
312-
313- return UR_RESULT_SUCCESS;
314- }
315-
31659ur_result_t urEnqueueKernelLaunch (
31760 // / [in] handle of the queue object
31861 ur_queue_handle_t Queue,
@@ -348,14 +91,16 @@ ur_result_t urEnqueueKernelLaunch(
34891 // / [in,out][optional] return an event object that identifies this
34992 // / particular kernel execution instance.
35093 ur_event_handle_t *OutEvent) {
94+ using ZeKernelLaunchFuncT = ze_result_t (*)(
95+ ze_command_list_handle_t , ze_kernel_handle_t , const ze_group_count_t *,
96+ ze_event_handle_t , uint32_t , ze_event_handle_t *);
97+ ZeKernelLaunchFuncT ZeKernelLaunchFunc = &zeCommandListAppendLaunchKernel;
35198 for (uint32_t PropIndex = 0 ; PropIndex < NumPropsInLaunchPropList;
35299 PropIndex++) {
353100 if (LaunchPropList[PropIndex].id ==
354101 UR_KERNEL_LAUNCH_PROPERTY_ID_COOPERATIVE &&
355102 LaunchPropList[PropIndex].value .cooperative ) {
356- return EnqueueCooperativeKernelLaunchHelper (
357- Queue, Kernel, WorkDim, GlobalWorkOffset, GlobalWorkSize,
358- LocalWorkSize, NumEventsInWaitList, EventWaitList, OutEvent);
103+ ZeKernelLaunchFunc = &zeCommandListAppendLaunchCooperativeKernel;
359104 }
360105 if (LaunchPropList[PropIndex].id != UR_KERNEL_LAUNCH_PROPERTY_ID_IGNORE &&
361106 LaunchPropList[PropIndex].id !=
@@ -454,15 +199,15 @@ ur_result_t urEnqueueKernelLaunch(
454199 ContextsLock.lock ();
455200 Queue->CaptureIndirectAccesses ();
456201 // Add the command to the command list, which implies submission.
457- ZE2UR_CALL (zeCommandListAppendLaunchKernel ,
202+ ZE2UR_CALL (ZeKernelLaunchFunc ,
458203 (CommandList->first , ZeKernel, &ZeThreadGroupDimensions, ZeEvent,
459204 (*Event)->WaitList .Length , (*Event)->WaitList .ZeEventList ));
460205 } else {
461206 // Add the command to the command list for later submission.
462207 // No lock is needed here, unlike the immediate commandlist case above,
463208 // because the kernels are not actually submitted yet. Kernels will be
464209 // submitted only when the comamndlist is closed. Then, a lock is held.
465- ZE2UR_CALL (zeCommandListAppendLaunchKernel ,
210+ ZE2UR_CALL (ZeKernelLaunchFunc ,
466211 (CommandList->first , ZeKernel, &ZeThreadGroupDimensions, ZeEvent,
467212 (*Event)->WaitList .Length , (*Event)->WaitList .ZeEventList ));
468213 }