Skip to content

Commit 0c882a0

Browse files
uditagarwal97Copiloticlsrc
authored andcommitted
Move max_global_work_groups query to UR (#21840)
This PR makes the following changes: 1. Moves implementation of `max_global_work_groups` to UR. This query has been implemented in SYCL RT because there's no backend support for `max_global_work_groups` query. However, it was recently decided that OpenCL will add a corresponding query. See CMPLRLLVM-73572 for more info. 2. Changes `max_global_work_groups` from `INT_MAX` to `SIZE_MAX` for all backends. For CUDA, HIP, OFFLOAD, and L0 adapter, we calculate the value of `max_global_work_groups` by taking minimum of `SIZE_MAX` and multiplication of per-dimension max group size. 3. Changed `max_work_groups<3>` so that `max_global_work_groups` no longer limits per-dimension max work group size. --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: iclsrc <iclsrc@intel.com>
1 parent e6be14f commit 0c882a0

12 files changed

Lines changed: 149 additions & 7 deletions

File tree

include/unified-runtime/ur_api.h

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2380,24 +2380,26 @@ typedef enum ur_device_info_t {
23802380
UR_DEVICE_INFO_PREFERRED_VECTOR_WIDTH_LONG_LONG = 131,
23812381
/// [uint32_t] native vector width for long long
23822382
UR_DEVICE_INFO_NATIVE_VECTOR_WIDTH_LONG_LONG = 132,
2383+
/// [size_t] return max total number of work groups
2384+
UR_DEVICE_INFO_MAX_WORK_GROUPS = 133,
23832385
/// [uint32_t][optional-query] return Intel GPU number of
23842386
/// stacks/chiplets/tiles
2385-
UR_DEVICE_INFO_XE_STACK_COUNT = 133,
2387+
UR_DEVICE_INFO_XE_STACK_COUNT = 134,
23862388
/// [uint32_t][optional-query] return Intel GPU number of regions sharing
23872389
/// local L2/L3 (XE_CU) per stack
2388-
UR_DEVICE_INFO_XE_REGIONS_PER_STACK = 134,
2390+
UR_DEVICE_INFO_XE_REGIONS_PER_STACK = 135,
23892391
/// [uint32_t][optional-query] return Intel GPU number of clusters
23902392
/// (slices) per region
2391-
UR_DEVICE_INFO_XE_CLUSTERS_PER_REGION = 135,
2393+
UR_DEVICE_INFO_XE_CLUSTERS_PER_REGION = 136,
23922394
/// [uint32_t][optional-query] return Intel GPU number of XE cores per
23932395
/// cluster
2394-
UR_DEVICE_INFO_XE_CORES_PER_CLUSTER = 136,
2396+
UR_DEVICE_INFO_XE_CORES_PER_CLUSTER = 137,
23952397
/// [uint32_t][optional-query] return Intel GPU number of execution
23962398
/// engines (EUs) per XE Core
2397-
UR_DEVICE_INFO_EUS_PER_XE_CORE = 137,
2399+
UR_DEVICE_INFO_EUS_PER_XE_CORE = 138,
23982400
/// [uint32_t][optional-query] return Intel GPU maximal number of lanes
23992401
/// (virtual SIMD size) per hardware thread
2400-
UR_DEVICE_INFO_MAX_LANES_PER_HW_THREAD = 138,
2402+
UR_DEVICE_INFO_MAX_LANES_PER_HW_THREAD = 139,
24012403
/// [::ur_bool_t] Returns true if the device supports the use of
24022404
/// command-buffers.
24032405
UR_DEVICE_INFO_COMMAND_BUFFER_SUPPORT_EXP = 0x1000,

include/unified-runtime/ur_print.hpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3180,6 +3180,9 @@ inline std::ostream &operator<<(std::ostream &os, enum ur_device_info_t value) {
31803180
case UR_DEVICE_INFO_NATIVE_VECTOR_WIDTH_LONG_LONG:
31813181
os << "UR_DEVICE_INFO_NATIVE_VECTOR_WIDTH_LONG_LONG";
31823182
break;
3183+
case UR_DEVICE_INFO_MAX_WORK_GROUPS:
3184+
os << "UR_DEVICE_INFO_MAX_WORK_GROUPS";
3185+
break;
31833186
case UR_DEVICE_INFO_XE_STACK_COUNT:
31843187
os << "UR_DEVICE_INFO_XE_STACK_COUNT";
31853188
break;
@@ -5038,6 +5041,19 @@ inline ur_result_t printTagged(std::ostream &os, const void *ptr,
50385041

50395042
os << ")";
50405043
} break;
5044+
case UR_DEVICE_INFO_MAX_WORK_GROUPS: {
5045+
const size_t *tptr = (const size_t *)ptr;
5046+
if (sizeof(size_t) > size) {
5047+
os << "invalid size (is: " << size << ", expected: >=" << sizeof(size_t)
5048+
<< ")";
5049+
return UR_RESULT_ERROR_INVALID_SIZE;
5050+
}
5051+
os << (const void *)(tptr) << " (";
5052+
5053+
os << *tptr;
5054+
5055+
os << ")";
5056+
} break;
50415057
case UR_DEVICE_INFO_XE_STACK_COUNT: {
50425058
const uint32_t *tptr = (const uint32_t *)ptr;
50435059
if (sizeof(uint32_t) > size) {

scripts/core/device.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -473,6 +473,8 @@ etors:
473473
desc: "[uint32_t] preferred vector width for long long"
474474
- name: NATIVE_VECTOR_WIDTH_LONG_LONG
475475
desc: "[uint32_t] native vector width for long long"
476+
- name: MAX_WORK_GROUPS
477+
desc: "[size_t] return max total number of work groups"
476478
- name: XE_STACK_COUNT
477479
desc: "[uint32_t][optional-query] return Intel GPU number of stacks/chiplets/tiles"
478480
- name: XE_REGIONS_PER_STACK

source/adapters/cuda/device.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,25 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetInfo(ur_device_handle_t hDevice,
110110
return ReturnValue(ReturnSizes);
111111
}
112112

113+
case UR_DEVICE_INFO_MAX_WORK_GROUPS: {
114+
int MaxX = 0, MaxY = 0, MaxZ = 0;
115+
UR_CHECK_ERROR(cuDeviceGetAttribute(
116+
&MaxX, CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_X, hDevice->get()));
117+
assert(MaxX >= 0);
118+
119+
UR_CHECK_ERROR(cuDeviceGetAttribute(
120+
&MaxY, CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Y, hDevice->get()));
121+
assert(MaxY >= 0);
122+
123+
UR_CHECK_ERROR(cuDeviceGetAttribute(
124+
&MaxZ, CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Z, hDevice->get()));
125+
assert(MaxZ >= 0);
126+
127+
return ReturnValue(multiplyWithOverflowCheck(static_cast<size_t>(MaxX),
128+
static_cast<size_t>(MaxY),
129+
static_cast<size_t>(MaxZ)));
130+
}
131+
113132
case UR_DEVICE_INFO_MAX_WORK_GROUP_SIZE: {
114133
int MaxWorkGroupSize = 0;
115134
UR_CHECK_ERROR(cuDeviceGetAttribute(

source/adapters/hip/device.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,25 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetInfo(ur_device_handle_t hDevice,
117117
return ReturnValue(return_sizes);
118118
}
119119

120+
case UR_DEVICE_INFO_MAX_WORK_GROUPS: {
121+
int MaxX = 0, MaxY = 0, MaxZ = 0;
122+
UR_CHECK_ERROR(hipDeviceGetAttribute(&MaxX, hipDeviceAttributeMaxGridDimX,
123+
hDevice->get()));
124+
assert(MaxX >= 0);
125+
126+
UR_CHECK_ERROR(hipDeviceGetAttribute(&MaxY, hipDeviceAttributeMaxGridDimY,
127+
hDevice->get()));
128+
assert(MaxY >= 0);
129+
130+
UR_CHECK_ERROR(hipDeviceGetAttribute(&MaxZ, hipDeviceAttributeMaxGridDimZ,
131+
hDevice->get()));
132+
assert(MaxZ >= 0);
133+
134+
return ReturnValue(multiplyWithOverflowCheck(static_cast<size_t>(MaxX),
135+
static_cast<size_t>(MaxY),
136+
static_cast<size_t>(MaxZ)));
137+
}
138+
120139
case UR_DEVICE_INFO_MAX_WORK_GROUP_SIZE: {
121140
int MaxWorkGroupSize = 0;
122141
UR_CHECK_ERROR(hipDeviceGetAttribute(&MaxWorkGroupSize,

source/adapters/level_zero/device.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -471,6 +471,14 @@ ur_result_t urDeviceGetInfo(
471471
Device->ZeDeviceComputeProperties->maxGroupCountZ}};
472472
return ReturnValue(MaxGroupCounts);
473473
}
474+
case UR_DEVICE_INFO_MAX_WORK_GROUPS: {
475+
// Multiply the max group counts in each dimension to get the total max
476+
// number of work groups. Prevent overflow.
477+
return ReturnValue(multiplyWithOverflowCheck(
478+
Device->ZeDeviceComputeProperties->maxGroupCountX,
479+
Device->ZeDeviceComputeProperties->maxGroupCountY,
480+
Device->ZeDeviceComputeProperties->maxGroupCountZ));
481+
}
474482
case UR_DEVICE_INFO_MAX_CLOCK_FREQUENCY:
475483
return ReturnValue(uint32_t{Device->ZeDeviceProperties->coreClockRate});
476484
case UR_DEVICE_INFO_ADDRESS_BITS: {

source/adapters/native_cpu/device.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -363,6 +363,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetInfo(ur_device_handle_t hDevice,
363363
case UR_DEVICE_INFO_SUB_GROUP_INDEPENDENT_FORWARD_PROGRESS:
364364
case UR_DEVICE_INFO_IL_VERSION:
365365
case UR_DEVICE_INFO_MAX_WORK_GROUPS_3D:
366+
case UR_DEVICE_INFO_MAX_WORK_GROUPS:
366367
case UR_DEVICE_INFO_MEMORY_CLOCK_RATE:
367368
case UR_DEVICE_INFO_MEMORY_BUS_WIDTH:
368369
case UR_DEVICE_INFO_GLOBAL_MEM_FREE:

source/adapters/offload/device.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
//===----------------------------------------------------------------------===//
99

1010
#include <OffloadAPI.h>
11+
#include <limits>
1112
#include <unified-runtime/ur_api.h>
1213
#include <ur/ur.hpp>
1314

@@ -211,6 +212,27 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetInfo(ur_device_handle_t hDevice,
211212

212213
return UR_RESULT_SUCCESS;
213214
}
215+
case UR_DEVICE_INFO_MAX_WORK_GROUPS: {
216+
// OL dimensions are uint32_t while UR is size_t, so they need to be mapped.
217+
if (pPropSizeRet) {
218+
*pPropSizeRet = sizeof(size_t);
219+
}
220+
221+
if (pPropValue) {
222+
ol_dimensions_t olVec;
223+
OL_RETURN_ON_ERR(olGetDeviceInfo(
224+
hDevice->OffloadDevice, OL_DEVICE_INFO_MAX_WORK_SIZE_PER_DIMENSION,
225+
sizeof(olVec), &olVec));
226+
227+
// Multiply the max group counts in each dimension to get the total max
228+
// number of work groups. Prevent overflow.
229+
*reinterpret_cast<size_t *>(pPropValue) = multiplyWithOverflowCheck(
230+
static_cast<size_t>(olVec.x), static_cast<size_t>(olVec.y),
231+
static_cast<size_t>(olVec.z));
232+
}
233+
234+
return UR_RESULT_SUCCESS;
235+
}
214236

215237
// Unimplemented features
216238
case UR_DEVICE_INFO_PROGRAM_SET_SPECIALIZATION_CONSTANTS:

source/adapters/opencl/device.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetInfo(ur_device_handle_t hDevice,
269269
ReturnSizes.sizes[2] = Max;
270270
return ReturnValue(ReturnSizes);
271271
}
272+
case UR_DEVICE_INFO_MAX_WORK_GROUPS: {
273+
return ReturnValue(std::numeric_limits<size_t>::max());
274+
}
272275
case UR_DEVICE_INFO_MAX_COMPUTE_QUEUE_INDICES: {
273276
return ReturnValue(static_cast<uint32_t>(1u));
274277
}

source/common/ur_util.hpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -322,6 +322,37 @@ inline ur_result_t exceptionToResult(std::exception_ptr eptr) {
322322
}
323323
}
324324

325+
// Multiply a, b and c, and check for overflow. If overflow occurs, return
326+
// MAX_SIZE_T.
327+
inline size_t multiplyWithOverflowCheck(size_t a, size_t b, size_t c) {
328+
329+
size_t Product = 0;
330+
size_t MaxSizeTVal = std::numeric_limits<size_t>::max();
331+
332+
if (a == 0 || b == 0 || c == 0) {
333+
return 0;
334+
}
335+
336+
#ifndef _MSC_VER
337+
if (__builtin_mul_overflow(a, b, &Product) ||
338+
__builtin_mul_overflow(Product, c, &Product)) {
339+
return MaxSizeTVal; // Overflow occurred, return max possible value.
340+
}
341+
#else
342+
if (b > MaxSizeTVal / a) {
343+
return MaxSizeTVal; // Overflow occurred, return max possible value.
344+
}
345+
Product = a * b;
346+
347+
if (c > MaxSizeTVal / Product) {
348+
return MaxSizeTVal; // Overflow occurred, return max possible value.
349+
}
350+
Product *= c;
351+
#endif
352+
353+
return Product;
354+
}
355+
325356
template <class> inline constexpr bool ur_always_false_t = false;
326357

327358
namespace {

0 commit comments

Comments
 (0)