@@ -419,6 +419,16 @@ C_Status GetMaxThreadsPerBlock(const C_Device device,
419419 *threads_per_block = count;
420420 return C_SUCCESS;
421421}
422+
423+ C_Status GetMaxBlocksPerMultiProcessor (const C_Device device,
424+ size_t *blocks_per_mp) {
425+ int id = device->id ;
426+ int count = 0 ;
427+ cudaError_t status =
428+ cudaDeviceGetAttribute (&count, cudaDevAttrMaxBlocksPerMultiprocessor, id);
429+ *blocks_per_mp = count;
430+ return C_SUCCESS;
431+ }
422432
423433C_Status GetMaxGridDimSize (const C_Device device,
424434 std::array<unsigned int , 3 > *grid_dim_size) {
@@ -436,6 +446,22 @@ C_Status GetMaxGridDimSize(const C_Device device,
436446 return C_SUCCESS;
437447}
438448
449+ C_Status GetMaxBlockDimSize (const C_Device device,
450+ std::array<unsigned int , 3 > *block_dim_size) {
451+ int id = device->id ;
452+ std::array<unsigned int , 3 > ret = {};
453+ int size;
454+ auto error_code_x = cudaDeviceGetAttribute (&size, cudaDevAttrMaxBlockDimX, id);
455+ ret[0 ] = size;
456+ auto error_code_y = cudaDeviceGetAttribute (&size, cudaDevAttrMaxBlockDimY, id);
457+ ret[1 ] = size;
458+ auto error_code_z = cudaDeviceGetAttribute (&size, cudaDevAttrMaxBlockDimZ, id);
459+ ret[2 ] = size;
460+
461+ *block_dim_size = ret;
462+ return C_SUCCESS;
463+ }
464+
439465C_Status InitDevice (const C_Device device) {
440466 if (!device || device->id < 0 ) {
441467 return C_ERROR;
@@ -1469,7 +1495,10 @@ void InitPlugin(CustomRuntimeParams *params) {
14691495 params->interface ->get_multi_process = GetMultiProcessors;
14701496 params->interface ->get_max_threads_per_mp = GetMaxThreadsPerMultiProcessor;
14711497 params->interface ->get_max_threads_per_block = GetMaxThreadsPerBlock;
1498+ params->interface ->get_max_shared_mem_per_block = GetMaxSharedMemPerBlock;
1499+ params->interface ->get_max_blocks_per_mp = GetMaxBlocksPerMultiProcessor;
14721500 params->interface ->get_max_grid_dim_size = GetMaxGridDimSize;
1501+ params->interface ->get_max_block_dim_size = GetMaxBlockDimSize;
14731502
14741503 params->interface ->init_device = InitDevice;
14751504 params->interface ->set_device = SetDevice;
0 commit comments