@@ -52,17 +52,6 @@ struct NDRDescT {
 };
 } // namespace native_cpu
 
-#ifdef NATIVECPU_USE_OCK
-static native_cpu::state getResizedState(const native_cpu::NDRDescT &ndr,
-                                         size_t itemsPerThread) {
-  native_cpu::state resized_state(
-      ndr.GlobalSize[0], ndr.GlobalSize[1], ndr.GlobalSize[2], itemsPerThread,
-      ndr.LocalSize[1], ndr.LocalSize[2], ndr.GlobalOffset[0],
-      ndr.GlobalOffset[1], ndr.GlobalOffset[2]);
-  return resized_state;
-}
-#endif
-
 UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
     ur_queue_handle_t hQueue, ur_kernel_handle_t hKernel, uint32_t workDim,
     const size_t *pGlobalWorkOffset, const size_t *pGlobalWorkSize,
@@ -112,6 +101,21 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
   // TODO: add proper error checking
   native_cpu::NDRDescT ndr(workDim, pGlobalWorkOffset, pGlobalWorkSize,
                            pLocalWorkSize);
+  unsigned long long numWI;
+  auto umulll_overflow = [](unsigned long long a, unsigned long long b,
+                            unsigned long long *c) -> bool {
+#ifdef __GNUC__
+    return __builtin_umulll_overflow(a, b, c);
+#else
+    *c = a * b;
+    return a != 0 && b != *c / a;
+#endif
+  };
+  if (umulll_overflow(ndr.GlobalSize[0], ndr.GlobalSize[1], &numWI) ||
+      umulll_overflow(numWI, ndr.GlobalSize[2], &numWI) || numWI > SIZE_MAX) {
+    return UR_RESULT_ERROR_OUT_OF_RESOURCES;
+  }
+
   auto &tp = hQueue->getDevice()->tp;
   const size_t numParallelThreads = tp.num_threads();
   std::vector<std::future<void>> futures;
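
Note: the guard added in this hunk computes the total work-item count as an unsigned long long and rejects launches whose product overflows or exceeds SIZE_MAX. Below is a minimal standalone sketch of the same pattern, assuming only GCC/Clang's __builtin_umulll_overflow plus a division-based fallback; the checked_mul name and the main driver are illustrative and not part of the patch.

#include <cstdio>

// Portable "multiply and detect overflow" helper: uses the GCC/Clang builtin
// when available, otherwise a division-based check.
static bool checked_mul(unsigned long long a, unsigned long long b,
                        unsigned long long *out) {
#ifdef __GNUC__
  return __builtin_umulll_overflow(a, b, out); // returns true on overflow
#else
  *out = a * b;                   // unsigned wrap-around is well defined
  return a != 0 && b != *out / a; // true if the product wrapped
#endif
}

int main() {
  unsigned long long n = 0;
  // 2^33 * 2^32 = 2^65 cannot fit in 64 bits, so the check reports overflow.
  bool overflow = checked_mul(1ull << 33, 1ull << 32, &n);
  std::printf("overflow: %s\n", overflow ? "yes" : "no");
  return 0;
}
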
@@ -130,131 +134,56 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
   auto kernel = std::make_unique<ur_kernel_handle_t_>(*hKernel);
   kernel->updateMemPool(numParallelThreads);
 
+  const size_t numWG = numWG0 * numWG1 * numWG2;
+  const size_t numWGPerThread = numWG / numParallelThreads;
+  const size_t remainderWG = numWG - numWGPerThread * numParallelThreads;
+  // The fourth value is the linearized value.
+  std::array<size_t, 4> rangeStart = {0, 0, 0, 0};
+  for (unsigned t = 0; t < numParallelThreads; ++t) {
+    auto rangeEnd = rangeStart;
+    rangeEnd[3] += numWGPerThread + (t < remainderWG);
+    if (rangeEnd[3] == rangeStart[3])
+      break;
+    rangeEnd[0] = rangeEnd[3] % numWG0;
+    rangeEnd[1] = (rangeEnd[3] / numWG0) % numWG1;
+    rangeEnd[2] = rangeEnd[3] / (numWG0 * numWG1);
+    futures.emplace_back(
+        tp.schedule_task([state, &kernel = *kernel, rangeStart,
+                          rangeEnd = rangeEnd[3], numWG0, numWG1,
 #ifndef NATIVECPU_USE_OCK
-  for (unsigned g2 = 0; g2 < numWG2; g2++) {
-    for (unsigned g1 = 0; g1 < numWG1; g1++) {
-      for (unsigned g0 = 0; g0 < numWG0; g0++) {
-        for (unsigned local2 = 0; local2 < ndr.LocalSize[2]; local2++) {
-          for (unsigned local1 = 0; local1 < ndr.LocalSize[1]; local1++) {
-            for (unsigned local0 = 0; local0 < ndr.LocalSize[0]; local0++) {
-              state.update(g0, g1, g2, local0, local1, local2);
-              kernel->_subhandler(kernel->getArgs(1, 0).data(), &state);
-            }
-          }
-        }
-      }
-    }
-  }
+                          localSize = ndr.LocalSize,
+#endif
+                          numParallelThreads](size_t threadId) mutable {
+          for (size_t g0 = rangeStart[0], g1 = rangeStart[1],
+                      g2 = rangeStart[2], g3 = rangeStart[3];
+               g3 < rangeEnd; ++g3) {
+#ifdef NATIVECPU_USE_OCK
+            state.update(g0, g1, g2);
+            kernel._subhandler(
+                kernel.getArgs(numParallelThreads, threadId).data(), &state);
 #else
-  bool isLocalSizeOne =
-      ndr.LocalSize[0] == 1 && ndr.LocalSize[1] == 1 && ndr.LocalSize[2] == 1;
-  if (isLocalSizeOne && ndr.GlobalSize[0] > numParallelThreads &&
-      !kernel->hasLocalArgs()) {
-    // If the local size is one, we make the assumption that we are running a
-    // parallel_for over a sycl::range.
-    // Todo: we could add more compiler checks and
-    // kernel properties for this (e.g. check that no barriers are called).
-
-    // Todo: this assumes that dim 0 is the best dimension over which we want to
-    // parallelize
-
-    // Since we also vectorize the kernel, and vectorization happens within the
-    // work group loop, it's better to have a large-ish local size. We can
-    // divide the global range by the number of threads, set that as the local
-    // size and peel everything else.
-
-    size_t new_num_work_groups_0 = numParallelThreads;
-    size_t itemsPerThread = ndr.GlobalSize[0] / numParallelThreads;
-
-    for (unsigned g2 = 0; g2 < numWG2; g2++) {
-      for (unsigned g1 = 0; g1 < numWG1; g1++) {
-        for (unsigned g0 = 0; g0 < new_num_work_groups_0; g0 += 1) {
-          futures.emplace_back(tp.schedule_task(
-              [ndr, itemsPerThread, &kernel = *kernel, g0, g1, g2](size_t) {
-                native_cpu::state resized_state =
-                    getResizedState(ndr, itemsPerThread);
-                resized_state.update(g0, g1, g2);
-                kernel._subhandler(kernel.getArgs().data(), &resized_state);
-              }));
-        }
-        // Peel the remaining work items. Since the local size is 1, we iterate
-        // over the work groups.
-        for (unsigned g0 = new_num_work_groups_0 * itemsPerThread; g0 < numWG0;
-             g0++) {
-          state.update(g0, g1, g2);
-          kernel->_subhandler(kernel->getArgs().data(), &state);
-        }
-      }
-    }
-
-  } else {
-    // We are running a parallel_for over an nd_range
-
-    if (numWG1 * numWG2 >= numParallelThreads) {
-      // Dimensions 1 and 2 have enough work, split them across the threadpool
-      for (unsigned g2 = 0; g2 < numWG2; g2++) {
-        for (unsigned g1 = 0; g1 < numWG1; g1++) {
-          futures.emplace_back(
-              tp.schedule_task([state, &kernel = *kernel, numWG0, g1, g2,
-                                numParallelThreads](size_t threadId) mutable {
-                for (unsigned g0 = 0; g0 < numWG0; g0++) {
-                  state.update(g0, g1, g2);
+            for (size_t local2 = 0; local2 < localSize[2]; ++local2) {
+              for (size_t local1 = 0; local1 < localSize[1]; ++local1) {
+                for (size_t local0 = 0; local0 < localSize[0]; ++local0) {
+                  state.update(g0, g1, g2, local0, local1, local2);
                   kernel._subhandler(
                       kernel.getArgs(numParallelThreads, threadId).data(),
                       &state);
                 }
-              }));
-        }
-      }
-    } else {
-      // Split dimension 0 across the threadpool
-      // Here we try to create groups of workgroups in order to reduce
-      // synchronization overhead
-      for (unsigned g2 = 0; g2 < numWG2; g2++) {
-        for (unsigned g1 = 0; g1 < numWG1; g1++) {
-          for (unsigned g0 = 0; g0 < numWG0; g0++) {
-            groups.push_back([state, g0, g1, g2, numParallelThreads](
-                                 size_t threadId,
-                                 ur_kernel_handle_t_ &kernel) mutable {
-              state.update(g0, g1, g2);
-              kernel._subhandler(
-                  kernel.getArgs(numParallelThreads, threadId).data(), &state);
-            });
-          }
-        }
-      }
-      auto numGroups = groups.size();
-      auto groupsPerThread = numGroups / numParallelThreads;
-      if (groupsPerThread) {
-        for (unsigned thread = 0; thread < numParallelThreads; thread++) {
-          futures.emplace_back(
-              tp.schedule_task([groups, thread, groupsPerThread,
-                                &kernel = *kernel](size_t threadId) {
-                for (unsigned i = 0; i < groupsPerThread; i++) {
-                  auto index = thread * groupsPerThread + i;
-                  groups[index](threadId, kernel);
-                }
-              }));
-        }
-      }
-
-      // schedule the remaining tasks
-      auto remainder = numGroups % numParallelThreads;
-      if (remainder) {
-        futures.emplace_back(
-            tp.schedule_task([groups, remainder,
-                              scheduled = numParallelThreads * groupsPerThread,
-                              &kernel = *kernel](size_t threadId) {
-              for (unsigned i = 0; i < remainder; i++) {
-                auto index = scheduled + i;
-                groups[index](threadId, kernel);
               }
-            }));
-      }
-    }
+            }
+#endif
+            if (++g0 == numWG0) {
+              g0 = 0;
+              if (++g1 == numWG1) {
+                g1 = 0;
+                ++g2;
+              }
+            }
+          }
+        }));
+    rangeStart = rangeEnd;
   }
-
-#endif // NATIVECPU_USE_OCK
   event->set_futures(futures);
 
   if (phEvent) {
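
Note: the scheduling introduced in the last hunk linearizes the work-group grid, hands each thread a contiguous slice of near-equal size (the first remainderWG threads take one extra group), and walks each slice by bumping the fastest-moving dimension and carrying into the next. The standalone sketch below only mirrors that arithmetic; the 4x3x2 grid, the 5-thread count, and the printing are illustrative assumptions, not part of the patch.

#include <array>
#include <cstddef>
#include <cstdio>

int main() {
  // A hypothetical 4x3x2 grid of work-groups shared by 5 threads.
  const std::size_t numWG0 = 4, numWG1 = 3, numWG2 = 2;
  const std::size_t numThreads = 5;

  const std::size_t numWG = numWG0 * numWG1 * numWG2;             // 24 groups
  const std::size_t perThread = numWG / numThreads;               // 4
  const std::size_t remainder = numWG - perThread * numThreads;   // 4

  // Index 3 is the linearized group index; 0..2 are its 3D decomposition.
  std::array<std::size_t, 4> rangeStart = {0, 0, 0, 0};
  for (std::size_t t = 0; t < numThreads; ++t) {
    auto rangeEnd = rangeStart;
    // The first `remainder` threads take one extra group each.
    rangeEnd[3] += perThread + (t < remainder);
    if (rangeEnd[3] == rangeStart[3])
      break; // nothing left to hand out
    rangeEnd[0] = rangeEnd[3] % numWG0;
    rangeEnd[1] = (rangeEnd[3] / numWG0) % numWG1;
    rangeEnd[2] = rangeEnd[3] / (numWG0 * numWG1);

    std::printf("thread %zu runs groups [%zu, %zu):\n", t, rangeStart[3],
                rangeEnd[3]);
    // Walk the contiguous slice: increment the fastest dimension and carry,
    // like the loop inside the scheduled task in the patch.
    for (std::size_t g0 = rangeStart[0], g1 = rangeStart[1],
                     g2 = rangeStart[2], g3 = rangeStart[3];
         g3 < rangeEnd[3]; ++g3) {
      std::printf("  wg (%zu, %zu, %zu)\n", g0, g1, g2);
      if (++g0 == numWG0) {
        g0 = 0;
        if (++g1 == numWG1) {
          g1 = 0;
          ++g2;
        }
      }
    }
    rangeStart = rangeEnd;
  }
  return 0;
}

Unlike the removed paths, which either split dimensions 1 and 2 across the thread pool or built per-work-group closures, this partitioning spreads all groups evenly regardless of which dimension carries the work.
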