Skip to content

Commit 2505ea9

Browse files
committed
Add more hardware attributes to support CINN.
1 parent 255718a commit 2505ea9

1 file changed

Lines changed: 29 additions & 0 deletions

File tree

backends/metax_gpu/runtime/runtime.cc

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -419,6 +419,16 @@ C_Status GetMaxThreadsPerBlock(const C_Device device,
419419
*threads_per_block = count;
420420
return C_SUCCESS;
421421
}
422+
423+
// Queries the device attribute cudaDevAttrMaxBlocksPerMultiprocessor for
// `device` and stores the result in `*blocks_per_mp`.
//
// Returns C_SUCCESS when the query succeeds. Returns C_ERROR when `device` or
// `blocks_per_mp` is null, when cudaDeviceGetAttribute reports a failure, or
// when the reported value is negative; in all error cases `*blocks_per_mp` is
// left unmodified.
C_Status GetMaxBlocksPerMultiProcessor(const C_Device device,
                                       size_t *blocks_per_mp) {
  if (!device || !blocks_per_mp) {
    return C_ERROR;
  }

  int count = 0;
  const cudaError_t status = cudaDeviceGetAttribute(
      &count, cudaDevAttrMaxBlocksPerMultiprocessor, device->id);
  // The status used to be ignored, so a failed query was reported as success
  // with a value of 0. A negative count would also wrap to a huge size_t.
  if (status != cudaSuccess || count < 0) {
    return C_ERROR;
  }

  *blocks_per_mp = static_cast<size_t>(count);
  return C_SUCCESS;
}
422432

423433
C_Status GetMaxGridDimSize(const C_Device device,
424434
std::array<unsigned int, 3> *grid_dim_size) {
@@ -436,6 +446,22 @@ C_Status GetMaxGridDimSize(const C_Device device,
436446
return C_SUCCESS;
437447
}
438448

449+
// Queries the per-dimension block size limits of `device`
// (cudaDevAttrMaxBlockDimX / Y / Z) and stores them, in x, y, z order, in
// `*block_dim_size`.
//
// Returns C_SUCCESS when all three queries succeed. Returns C_ERROR when
// `device` or `block_dim_size` is null, or as soon as any query fails; in
// that case `*block_dim_size` is left unmodified.
C_Status GetMaxBlockDimSize(const C_Device device,
                            std::array<unsigned int, 3> *block_dim_size) {
  if (!device || !block_dim_size) {
    return C_ERROR;
  }

  const int id = device->id;
  std::array<unsigned int, 3> ret = {};
  // Initialized so that no indeterminate value can ever be observed; the
  // original left this uninitialized and read it even after a failed query.
  int size = 0;

  if (cudaDeviceGetAttribute(&size, cudaDevAttrMaxBlockDimX, id) !=
      cudaSuccess) {
    return C_ERROR;
  }
  ret[0] = static_cast<unsigned int>(size);

  if (cudaDeviceGetAttribute(&size, cudaDevAttrMaxBlockDimY, id) !=
      cudaSuccess) {
    return C_ERROR;
  }
  ret[1] = static_cast<unsigned int>(size);

  if (cudaDeviceGetAttribute(&size, cudaDevAttrMaxBlockDimZ, id) !=
      cudaSuccess) {
    return C_ERROR;
  }
  ret[2] = static_cast<unsigned int>(size);

  // Publish only after every dimension was read successfully, so the caller
  // never sees a partially filled result.
  *block_dim_size = ret;
  return C_SUCCESS;
}
464+
439465
C_Status InitDevice(const C_Device device) {
440466
if (!device || device->id < 0) {
441467
return C_ERROR;
@@ -1469,7 +1495,10 @@ void InitPlugin(CustomRuntimeParams *params) {
14691495
params->interface->get_multi_process = GetMultiProcessors;
14701496
params->interface->get_max_threads_per_mp = GetMaxThreadsPerMultiProcessor;
14711497
params->interface->get_max_threads_per_block = GetMaxThreadsPerBlock;
1498+
params->interface->get_max_shared_mem_per_block = GetMaxSharedMemPerBlock;
1499+
params->interface->get_max_blocks_per_mp = GetMaxBlocksPerMultiProcessor;
14721500
params->interface->get_max_grid_dim_size = GetMaxGridDimSize;
1501+
params->interface->get_max_block_dim_size = GetMaxBlockDimSize;
14731502

14741503
params->interface->init_device = InitDevice;
14751504
params->interface->set_device = SetDevice;

0 commit comments

Comments
 (0)