Skip to content

Commit a54312b

Browse files
wfaderhold21 authored and janjust committed
TL/NCCL: add user buffer registration via memmap
1 parent 081bb98 commit a54312b

8 files changed

Lines changed: 804 additions & 12 deletions

File tree

src/components/tl/nccl/tl_nccl.c

Lines changed: 7 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -45,8 +45,7 @@ static ucs_config_field_t ucc_tl_nccl_context_config_table[] = {
4545
{"SYNC", "auto",
4646
"Determines how UCC tests completion of NCCL collective",
4747
ucs_offsetof(ucc_tl_nccl_context_config_t, sync_type),
48-
UCS_CONFIG_TYPE_ENUM(ucc_tl_nccl_completion_sync_names)
49-
},
48+
UCS_CONFIG_TYPE_ENUM(ucc_tl_nccl_completion_sync_names)},
5049

5150
{"BLOCKING", "yes",
5251
"If set to no will use non-blocking mode communicator behavior, "
@@ -59,6 +58,12 @@ static ucs_config_field_t ucc_tl_nccl_context_config_table[] = {
5958
ucc_offsetof(ucc_tl_nccl_context_config_t, nccl_lazy_init),
6059
UCC_CONFIG_TYPE_BOOL},
6160

61+
{"ENABLE_UBR", "try",
62+
"Enable NCCL User Buffer Registration for zero-copy operations. "
63+
"Requires NCCL v2.19+.",
64+
ucc_offsetof(ucc_tl_nccl_context_config_t, enable_ubr),
65+
UCC_CONFIG_TYPE_TERNARY},
66+
6267
{NULL}};
6368

6469
UCC_CLASS_DEFINE_NEW_FUNC(ucc_tl_nccl_lib_t, ucc_base_lib_t,

src/components/tl/nccl/tl_nccl.h

Lines changed: 27 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,8 @@
1212
#include "components/tl/ucc_tl.h"
1313
#include "components/tl/ucc_tl_log.h"
1414
#include "utils/ucc_mpool.h"
15+
#include "utils/ucc_list.h"
16+
#include "utils/ucc_spinlock.h"
1517

1618
#include <cuda_runtime.h>
1719
#if CUDART_VERSION >= 11000
@@ -44,6 +46,8 @@
4446
#define UCC_TL_NCCL_PROFILE_REQUEST_FREE UCC_PROFILE_REQUEST_FREE
4547
#define NCCL_VERSION_COMM_INIT_NB NCCL_VERSION(2,14,3)
4648
#define NCCL_USE_NON_BLOCKING NCCL_VERSION_CODE >= NCCL_VERSION_COMM_INIT_NB
49+
#define NCCL_VERSION_UBR NCCL_VERSION(2,19,0)
50+
#define NCCL_HAS_UBR (NCCL_VERSION_CODE >= NCCL_VERSION_UBR)
4751

4852
enum {
4953
TL_NCCL_COMM_STATE_ERROR,
@@ -76,6 +80,7 @@ typedef struct ucc_tl_nccl_context_config {
7680
ucc_tl_nccl_completion_sync_type_t sync_type;
7781
int nccl_cfg_blocking;
7882
int nccl_lazy_init;
83+
int enable_ubr;
7984
} ucc_tl_nccl_context_config_t;
8085

8186
typedef struct ucc_tl_nccl_lib {
@@ -84,11 +89,33 @@ typedef struct ucc_tl_nccl_lib {
8489
UCC_CLASS_DECLARE(ucc_tl_nccl_lib_t, const ucc_base_lib_params_t *,
8590
const ucc_base_config_t *);
8691

92+
/* Per-region bookkeeping for a memory handle mapped through TL/NCCL.
 * The region itself is registered with a NCCL communicator lazily, the
 * first time a collective on that communicator uses a buffer inside it;
 * each successful registration is recorded as a (comm, handle) pair in the
 * two parallel arrays below. */
93+
typedef struct ucc_tl_nccl_memh_data {
94+
/* Base address of the region; this is what is passed to ncclCommRegister */
void *address;
95+
/* Size of the region in bytes, passed to ncclCommRegister with address */
size_t length;
96+
/* Array of NCCL comms this buffer is registered with (NULL = invalidated).
 * Entry i pairs with nccl_handles[i]. */
97+
ncclComm_t *registered_comms;
98+
/* Array of NCCL handles from ncclCommRegister, one per registered comm */
99+
void **nccl_handles;
100+
/* Number of valid entries in registered_comms/nccl_handles */
101+
int num_comms;
102+
/* Allocated capacity (in entries) of registered_comms/nccl_handles */
103+
int max_comms;
104+
/* Protects registered_comms/nccl_handles/num_comms/max_comms */
105+
ucc_spinlock_t lock;
106+
/* Linked into ucc_tl_nccl_context_t::memh_list */
107+
ucc_list_link_t list_elem;
108+
} ucc_tl_nccl_memh_data_t;
109+
87110
/* TL/NCCL context: base TL context, parsed configuration, and the state
 * used for NCCL user buffer registration (UBR). */
typedef struct ucc_tl_nccl_context {
88111
ucc_tl_context_t super;
89112
ucc_tl_nccl_context_config_t cfg;
90113
ucc_mpool_t req_mp;
91114
void *scratch_buf;
115+
/* Non-zero when UBR may be used; lazy registration of mapped buffers is
 * skipped when this is zero.
 * NOTE(review): where this is set is not visible here - presumably derived
 * from cfg.enable_ubr and the NCCL version; confirm. */
int ubr_available;
116+
/* List of all live ucc_tl_nccl_memh_data_t objects; protected by memh_lock */
117+
ucc_list_link_t memh_list;
118+
/* Guards memh_list */
ucc_spinlock_t memh_lock;
92119
} ucc_tl_nccl_context_t;
93120
UCC_CLASS_DECLARE(ucc_tl_nccl_context_t, const ucc_base_context_params_t *,
94121
const ucc_base_config_t *);

src/components/tl/nccl/tl_nccl_coll.c

Lines changed: 255 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -9,6 +9,7 @@
99
#include "components/mc/ucc_mc.h"
1010
#include "components/ec/ucc_ec.h"
1111
#include "core/ucc_ee.h"
12+
#include "core/ucc_context.h"
1213
#include "utils/ucc_compiler_def.h"
1314
#include "utils/ucc_math.h"
1415
#include "utils/ucc_coll_utils.h"
@@ -127,6 +128,164 @@ static inline ucc_status_t ucc_tl_nccl_check_and_convert_buffer_reduction(
127128
return UCC_OK;
128129
}
129130

131+
#if NCCL_HAS_UBR
/* Returns 1 if @comm is already recorded in m_data->registered_comms,
 * 0 otherwise. The caller must hold m_data->lock. */
static inline int
ucc_tl_nccl_memh_has_comm(const ucc_tl_nccl_memh_data_t *m_data,
                          ncclComm_t comm)
{
    int i;

    for (i = 0; i < m_data->num_comms; i++) {
        if (m_data->registered_comms[i] == comm) {
            return 1;
        }
    }
    return 0;
}

/* Ensures the (comm, handle) bookkeeping arrays of @m_data have room for one
 * more entry, growing them (4 entries initially, then doubling) if needed.
 * Both new arrays are allocated before m_data is touched, so a failure on
 * either leaves the struct unchanged. The caller must hold m_data->lock.
 *
 * Returns UCC_OK on success, UCC_ERR_NO_MEMORY if an allocation failed. */
static inline ucc_status_t
ucc_tl_nccl_memh_reserve_slot(ucc_tl_nccl_memh_data_t *m_data)
{
    ncclComm_t *new_comms;
    void      **new_handles;
    int         new_max;

    if (m_data->num_comms < m_data->max_comms) {
        return UCC_OK;
    }

    new_max     = (m_data->max_comms == 0) ? 4 : (m_data->max_comms * 2);
    new_comms   = (ncclComm_t *)ucc_malloc(new_max * sizeof(ncclComm_t),
                                           "nccl_registered_comms");
    new_handles = (void **)ucc_malloc(new_max * sizeof(void *),
                                      "nccl_handles");
    if (!new_comms || !new_handles) {
        ucc_free(new_comms);
        ucc_free(new_handles);
        return UCC_ERR_NO_MEMORY;
    }

    /* Copy existing entries into the new buffers, then swap. */
    if (m_data->num_comms > 0) {
        memcpy(new_comms, m_data->registered_comms,
               m_data->num_comms * sizeof(ncclComm_t));
        memcpy(new_handles, m_data->nccl_handles,
               m_data->num_comms * sizeof(void *));
    }
    ucc_free(m_data->registered_comms);
    ucc_free(m_data->nccl_handles);
    m_data->registered_comms = new_comms;
    m_data->nccl_handles     = new_handles;
    m_data->max_comms        = new_max;
    return UCC_OK;
}

/* Lazily registers the memory region behind @memh with the team's NCCL
 * communicator, so that collectives using [buffer, buffer + length) can take
 * the NCCL user-buffer-registration path.
 *
 * @param buffer  start of the buffer used by the collective
 * @param length  size of that buffer in bytes
 * @param team    TL/NCCL team whose communicator the region is registered with
 * @param memh    UCC memory handle covering the buffer (may be NULL)
 *
 * @return UCC_OK when the region is registered, was already registered, or
 *         registration was skipped (UBR unavailable, no memh, no TL/NCCL data
 *         in the memh, communicator not initialized, or ncclCommRegister
 *         failed - UBR is only an optimization);
 *         UCC_ERR_INVALID_PARAM when the buffer or region range overflows or
 *         the buffer is not fully inside the registered region;
 *         UCC_ERR_NO_MEMORY when the bookkeeping arrays cannot grow. */
static inline ucc_status_t ucc_tl_nccl_lazy_register_memh(
    void *buffer, size_t length, ucc_tl_nccl_team_t *team,
    ucc_mem_map_mem_h memh)
{
    ucc_tl_nccl_context_t   *ctx    = UCC_TL_NCCL_TEAM_CTX(team);
    ucc_tl_nccl_memh_data_t *m_data = NULL;
    ucc_mem_map_memh_t      *mem_handle;
    ncclResult_t             nccl_status;
    ucc_status_t             status;
    void                    *nccl_handle;
    int                      i, already_registered;
    uintptr_t                buf_start, buf_end, region_start, region_end;

    /* Skip if UBR is not available or memh not provided */
    if (!ctx->ubr_available || !memh) {
        return UCC_OK;
    }

    /* Locate the TL/NCCL private data inside the generic memory handle */
    mem_handle = (ucc_mem_map_memh_t *)memh;
    for (i = 0; i < mem_handle->num_tls; i++) {
        if (strcmp(mem_handle->tl_h[i].tl_name, "nccl") == 0) {
            m_data = (ucc_tl_nccl_memh_data_t *)mem_handle->tl_h[i].tl_data;
            break;
        }
    }

    if (!m_data) {
        /* No NCCL memh data - buffer not registered with TL/NCCL */
        return UCC_OK;
    }

    if (length > (UINTPTR_MAX - (uintptr_t)buffer)) {
        tl_error(UCC_TL_TEAM_LIB(team), "NCCL UBR: buffer size causes overflow");
        return UCC_ERR_INVALID_PARAM;
    }

    /* Same wrap-around guard for the region itself: without it a bogus
     * length would make region_end wrap and defeat the bounds check below */
    region_start = (uintptr_t)m_data->address;
    if (m_data->length > (UINTPTR_MAX - region_start)) {
        tl_error(UCC_TL_TEAM_LIB(team),
                 "NCCL UBR: registered region size causes overflow");
        return UCC_ERR_INVALID_PARAM;
    }

    /* Verify that the entire buffer is within the registered memory region */
    buf_start  = (uintptr_t)buffer;
    buf_end    = buf_start + length;
    region_end = region_start + m_data->length;

    if (buf_start < region_start || buf_end > region_end) {
        tl_error(
            UCC_TL_TEAM_LIB(team),
            "NCCL UBR: buffer [%p, %p) is outside registered region [%p, %p)",
            buffer,
            (void *)buf_end,
            m_data->address,
            (void *)region_end);
        return UCC_ERR_INVALID_PARAM;
    }

    /* Verify team communicator is initialized */
    if (!team->nccl_comm) {
        tl_debug(UCC_TL_TEAM_LIB(team),
                 "NCCL UBR: communicator not initialized, skipping registration");
        return UCC_OK;
    }

    /* Fast path: already registered with this communicator */
    ucc_spin_lock(&m_data->lock);
    already_registered = ucc_tl_nccl_memh_has_comm(m_data, team->nccl_comm);
    ucc_spin_unlock(&m_data->lock);
    if (already_registered) {
        return UCC_OK;
    }

    /* The lock is deliberately not held while calling into NCCL
     * (potentially slow); it is re-acquired below to update the
     * bookkeeping arrays. */
    nccl_status = ncclCommRegister(
        team->nccl_comm, m_data->address, m_data->length, &nccl_handle);
    if (nccl_status != ncclSuccess) {
        tl_warn(
            UCC_TL_TEAM_LIB(team),
            "NCCL UBR: failed to register region %p, size %zu: %s",
            m_data->address,
            m_data->length,
            ncclGetErrorString(nccl_status));
        /* Don't fail - UBR is an optimization */
        return UCC_OK;
    }

    ucc_spin_lock(&m_data->lock);

    /* Re-check after re-acquiring the lock: another thread may have registered
     * while we were in ncclCommRegister. */
    if (ucc_tl_nccl_memh_has_comm(m_data, team->nccl_comm)) {
        ucc_spin_unlock(&m_data->lock);
        /* Undo the redundant registration */
        ncclCommDeregister(team->nccl_comm, nccl_handle);
        return UCC_OK;
    }

    /* Make room for, then record, this comm and its handle */
    status = ucc_tl_nccl_memh_reserve_slot(m_data);
    if (status != UCC_OK) {
        ucc_spin_unlock(&m_data->lock);
        tl_error(UCC_TL_TEAM_LIB(team),
                 "failed to grow NCCL UBR bookkeeping arrays");
        /* The handle cannot be tracked, so drop the registration */
        ncclCommDeregister(team->nccl_comm, nccl_handle);
        return status;
    }

    m_data->registered_comms[m_data->num_comms] = team->nccl_comm;
    m_data->nccl_handles[m_data->num_comms]     = nccl_handle;
    m_data->num_comms++;

    ucc_spin_unlock(&m_data->lock);

    tl_debug(
        UCC_TL_TEAM_LIB(team),
        "NCCL UBR: lazily registered region %p, size %zu with comm %p "
        "(for buffer [%p, %p))",
        m_data->address,
        m_data->length,
        (void *)team->nccl_comm,
        buffer,
        (void *)buf_end);

    return UCC_OK;
}
#endif
288+
130289
ucc_status_t ucc_tl_nccl_init_task(ucc_base_coll_args_t *coll_args,
131290
ucc_base_team_t *team,
132291
ucc_tl_nccl_task_t **coll_task)
@@ -176,6 +335,102 @@ ucc_status_t ucc_tl_nccl_init_task(ucc_base_coll_args_t *coll_args,
176335
}
177336
}
178337

338+
#if NCCL_HAS_UBR
339+
/* Lazily register memory regions if they were pre-mapped and UBR is enabled */
340+
if (nccl_ctx->ubr_available) {
341+
ucc_mem_map_mem_h src_memh = NULL;
342+
ucc_mem_map_mem_h dst_memh = NULL;
343+
ucc_rank_t grank = UCC_TL_TEAM_RANK(nccl_team);
344+
ucc_count_t total_count;
345+
ucc_datatype_t dt;
346+
int is_src_v_type;
347+
int is_dst_v_type;
348+
349+
/* SCATTERV: src is info_v (root distributes variable chunks),
350+
* dst is info (each rank receives one contiguous chunk).
351+
* GATHERV: src is info (each rank contributes one contiguous chunk),
352+
* dst is info_v (root collects variable chunks).
353+
* ALLGATHERV: src is info, dst is info_v.
354+
* ALLTOALLV: src is info_v, dst is info_v. */
355+
is_src_v_type = (coll_args->args.coll_type == UCC_COLL_TYPE_ALLTOALLV ||
356+
coll_args->args.coll_type == UCC_COLL_TYPE_SCATTERV);
357+
is_dst_v_type = (coll_args->args.coll_type == UCC_COLL_TYPE_ALLTOALLV ||
358+
coll_args->args.coll_type == UCC_COLL_TYPE_ALLGATHERV ||
359+
coll_args->args.coll_type == UCC_COLL_TYPE_GATHERV);
360+
361+
/* Register source buffer's memory region if memh provided */
362+
if (coll_args->args.mask & UCC_COLL_ARGS_FIELD_MEM_MAP_SRC_MEMH) {
363+
/* Check if global or local memh */
364+
if ((coll_args->args.mask & UCC_COLL_ARGS_FIELD_FLAGS) &&
365+
(coll_args->args.flags & UCC_COLL_ARGS_FLAG_SRC_MEMH_GLOBAL)) {
366+
src_memh = coll_args->args.src_memh.global_memh[grank];
367+
} else {
368+
src_memh = coll_args->args.src_memh.local_memh;
369+
}
370+
371+
if (is_src_v_type) {
372+
total_count = ucc_coll_args_get_v_buffer_size(
373+
&coll_args->args,
374+
coll_args->args.src.info_v.counts,
375+
coll_args->args.src.info_v.displacements,
376+
UCC_TL_TEAM_SIZE(nccl_team));
377+
dt = coll_args->args.src.info_v.datatype;
378+
} else {
379+
total_count = coll_args->args.src.info.count;
380+
dt = coll_args->args.src.info.datatype;
381+
}
382+
status = ucc_tl_nccl_lazy_register_memh(
383+
is_src_v_type ? coll_args->args.src.info_v.buffer
384+
: coll_args->args.src.info.buffer,
385+
total_count * ucc_dt_size(dt),
386+
nccl_team,
387+
src_memh);
388+
if (ucc_unlikely(status != UCC_OK)) {
389+
tl_error(
390+
team->context->lib,
391+
"NCCL UBR: lazy_register failed with status %d",
392+
status);
393+
ucc_tl_nccl_free_task(task);
394+
return status;
395+
}
396+
}
397+
398+
/* Register destination buffer's memory region if memh provided */
399+
if (coll_args->args.mask & UCC_COLL_ARGS_FIELD_MEM_MAP_DST_MEMH) {
400+
/* Check if global or local memh */
401+
if ((coll_args->args.mask & UCC_COLL_ARGS_FIELD_FLAGS) &&
402+
(coll_args->args.flags & UCC_COLL_ARGS_FLAG_DST_MEMH_GLOBAL)) {
403+
dst_memh = coll_args->args.dst_memh.global_memh[grank];
404+
} else {
405+
dst_memh = coll_args->args.dst_memh.local_memh;
406+
}
407+
408+
if (is_dst_v_type) {
409+
total_count = ucc_coll_args_get_v_buffer_size(
410+
&coll_args->args,
411+
coll_args->args.dst.info_v.counts,
412+
coll_args->args.dst.info_v.displacements,
413+
UCC_TL_TEAM_SIZE(nccl_team));
414+
dt = coll_args->args.dst.info_v.datatype;
415+
} else {
416+
total_count = coll_args->args.dst.info.count;
417+
dt = coll_args->args.dst.info.datatype;
418+
}
419+
420+
status = ucc_tl_nccl_lazy_register_memh(
421+
is_dst_v_type ? coll_args->args.dst.info_v.buffer
422+
: coll_args->args.dst.info.buffer,
423+
total_count * ucc_dt_size(dt),
424+
nccl_team,
425+
dst_memh);
426+
if (ucc_unlikely(status != UCC_OK)) {
427+
ucc_tl_nccl_free_task(task);
428+
return status;
429+
}
430+
}
431+
}
432+
#endif
433+
179434
*coll_task = task;
180435
return UCC_OK;
181436
}

0 commit comments

Comments
 (0)