diff --git a/src/uct/ib/mlx5/gdaki/gdaki.c b/src/uct/ib/mlx5/gdaki/gdaki.c index 414b5c9c856..bad16e53b8b 100644 --- a/src/uct/ib/mlx5/gdaki/gdaki.c +++ b/src/uct/ib/mlx5/gdaki/gdaki.c @@ -19,6 +19,7 @@ #include #include #include +#include #include #include "gpunetio/common/doca_gpunetio_verbs_def.h" @@ -33,6 +34,7 @@ typedef struct { uct_rc_iface_common_config_t super; uct_rc_mlx5_iface_common_config_t mlx5; unsigned num_channels; + int retain_primary_ctx; } uct_rc_gdaki_iface_config_t; ucs_config_field_t uct_rc_gdaki_iface_config_table[] = { @@ -50,6 +52,11 @@ ucs_config_field_t uct_rc_gdaki_iface_config_table[] = { ucs_offsetof(uct_rc_gdaki_iface_config_t, num_channels), UCS_CONFIG_TYPE_UINT}, + {"RETAIN_PRIMARY_CTX", "n", + "Retain and use an inactive CUDA primary context for memory allocation", + ucs_offsetof(uct_rc_gdaki_iface_config_t, retain_primary_ctx), + UCS_CONFIG_TYPE_BOOL}, + {NULL} }; @@ -763,17 +770,13 @@ static UCS_CLASS_INIT_FUNC(uct_rc_gdaki_iface_t, uct_md_h tl_md, return status; } - status = UCT_CUDADRV_FUNC_LOG_ERR( - cuDevicePrimaryCtxRetain(&self->cuda_ctx, self->cuda_dev)); + status = uct_cuda_ctx_primary_push(self->cuda_dev, + config->retain_primary_ctx, + UCS_LOG_LEVEL_ERROR); if (status != UCS_OK) { return status; } - status = UCT_CUDADRV_FUNC_LOG_ERR(cuCtxPushCurrent(self->cuda_ctx)); - if (status != UCS_OK) { - goto err_ctx_release; - } - status = uct_rc_gdaki_alloc(sizeof(uint64_t), sizeof(uint64_t), (void**)&self->atomic_buff, &self->atomic_raw); if (status != UCS_OK) { @@ -797,11 +800,9 @@ static UCS_CLASS_INIT_FUNC(uct_rc_gdaki_iface_t, uct_md_h tl_md, err_lock: ibv_dereg_mr(self->atomic_mr); err_atomic: - cuMemFree(self->atomic_raw); + (void)UCT_CUDADRV_FUNC_LOG_WARN(cuMemFree(self->atomic_raw)); err_ctx: - (void)UCT_CUDADRV_FUNC_LOG_WARN(cuCtxPopCurrent(NULL)); -err_ctx_release: - (void)UCT_CUDADRV_FUNC_LOG_WARN(cuDevicePrimaryCtxRelease(self->cuda_dev)); + uct_cuda_ctx_primary_pop_and_release(self->cuda_dev); return status; } @@ -809,7 +810,7 @@ static UCS_CLASS_CLEANUP_FUNC(uct_rc_gdaki_iface_t) { pthread_mutex_destroy(&self->ep_init_lock); ibv_dereg_mr(self->atomic_mr); - cuMemFree(self->atomic_raw); + (void)UCT_CUDADRV_FUNC_LOG_WARN(cuMemFree(self->atomic_raw)); (void)UCT_CUDADRV_FUNC_LOG_WARN(cuDevicePrimaryCtxRelease(self->cuda_dev)); }