1919#include <uct/ib/base/ib_verbs.h>
2020#include <uct/ib/mlx5/rc/rc_mlx5.h>
2121#include <uct/cuda/cuda_copy/cuda_copy_md.h>
22+ #include <uct/cuda/base/cuda_ctx.h>
2223#include <uct/cuda/base/cuda_util.h>
2324
2425#include "gpunetio/common/doca_gpunetio_verbs_def.h"
@@ -33,6 +34,7 @@ typedef struct {
3334 uct_rc_iface_common_config_t super ;
3435 uct_rc_mlx5_iface_common_config_t mlx5 ;
3536 unsigned num_channels ;
37+ int retain_primary_ctx ;
3638} uct_rc_gdaki_iface_config_t ;
3739
3840ucs_config_field_t uct_rc_gdaki_iface_config_table [] = {
@@ -50,6 +52,11 @@ ucs_config_field_t uct_rc_gdaki_iface_config_table[] = {
5052 ucs_offsetof (uct_rc_gdaki_iface_config_t , num_channels ),
5153 UCS_CONFIG_TYPE_UINT },
5254
55+ {"RETAIN_PRIMARY_CTX" , "n" ,
56+ "Retain and use an inactive CUDA primary context for memory allocation" ,
57+ ucs_offsetof (uct_rc_gdaki_iface_config_t , retain_primary_ctx ),
58+ UCS_CONFIG_TYPE_BOOL },
59+
5360 {NULL }
5461};
5562
@@ -763,17 +770,13 @@ static UCS_CLASS_INIT_FUNC(uct_rc_gdaki_iface_t, uct_md_h tl_md,
763770 return status ;
764771 }
765772
766- status = UCT_CUDADRV_FUNC_LOG_ERR (
767- cuDevicePrimaryCtxRetain (& self -> cuda_ctx , self -> cuda_dev ));
773+ status = uct_cuda_ctx_primary_push (self -> cuda_dev ,
774+ config -> retain_primary_ctx ,
775+ UCS_LOG_LEVEL_ERROR );
768776 if (status != UCS_OK ) {
769777 return status ;
770778 }
771779
772- status = UCT_CUDADRV_FUNC_LOG_ERR (cuCtxPushCurrent (self -> cuda_ctx ));
773- if (status != UCS_OK ) {
774- goto err_ctx_release ;
775- }
776-
777780 status = uct_rc_gdaki_alloc (sizeof (uint64_t ), sizeof (uint64_t ),
778781 (void * * )& self -> atomic_buff , & self -> atomic_raw );
779782 if (status != UCS_OK ) {
@@ -797,19 +800,17 @@ static UCS_CLASS_INIT_FUNC(uct_rc_gdaki_iface_t, uct_md_h tl_md,
797800err_lock :
798801 ibv_dereg_mr (self -> atomic_mr );
799802err_atomic :
800- cuMemFree (self -> atomic_raw );
803+ ( void ) UCT_CUDADRV_FUNC_LOG_WARN ( cuMemFree (self -> atomic_raw ) );
801804err_ctx :
802- (void )UCT_CUDADRV_FUNC_LOG_WARN (cuCtxPopCurrent (NULL ));
803- err_ctx_release :
804- (void )UCT_CUDADRV_FUNC_LOG_WARN (cuDevicePrimaryCtxRelease (self -> cuda_dev ));
805+ uct_cuda_ctx_primary_pop_and_release (self -> cuda_dev );
805806 return status ;
806807}
807808
808809static UCS_CLASS_CLEANUP_FUNC (uct_rc_gdaki_iface_t )
809810{
810811 pthread_mutex_destroy (& self -> ep_init_lock );
811812 ibv_dereg_mr (self -> atomic_mr );
812- cuMemFree (self -> atomic_raw );
813+ ( void ) UCT_CUDADRV_FUNC_LOG_WARN ( cuMemFree (self -> atomic_raw ) );
813814 (void )UCT_CUDADRV_FUNC_LOG_WARN (cuDevicePrimaryCtxRelease (self -> cuda_dev ));
814815}
815816
0 commit comments