|
9 | 9 | #include "components/mc/ucc_mc.h" |
10 | 10 | #include "components/ec/ucc_ec.h" |
11 | 11 | #include "core/ucc_ee.h" |
| 12 | +#include "core/ucc_context.h" |
12 | 13 | #include "utils/ucc_compiler_def.h" |
13 | 14 | #include "utils/ucc_math.h" |
14 | 15 | #include "utils/ucc_coll_utils.h" |
@@ -127,6 +128,164 @@ static inline ucc_status_t ucc_tl_nccl_check_and_convert_buffer_reduction( |
127 | 128 | return UCC_OK; |
128 | 129 | } |
129 | 130 |
|
| 131 | +#if NCCL_HAS_UBR |
| 132 | +/* Helper function to lazily register a memory region with NCCL communicator */ |
| 133 | +static inline ucc_status_t ucc_tl_nccl_lazy_register_memh( |
| 134 | + void *buffer, size_t length, ucc_tl_nccl_team_t *team, |
| 135 | + ucc_mem_map_mem_h memh) |
| 136 | +{ |
| 137 | + ucc_tl_nccl_context_t *ctx = UCC_TL_NCCL_TEAM_CTX(team); |
| 138 | + ucc_tl_nccl_memh_data_t *m_data; |
| 139 | + ucc_mem_map_memh_t *mem_handle; |
| 140 | + ncclResult_t nccl_status; |
| 141 | + ncclComm_t *new_comms; |
| 142 | + void **new_handles; |
| 143 | + void *nccl_handle; |
| 144 | + int i, new_max; |
| 145 | + uintptr_t buf_start, buf_end, region_start, region_end; |
| 146 | + |
| 147 | + /* Skip if UBR is not available or memh not provided */ |
| 148 | + if (!ctx->ubr_available || !memh) { |
| 149 | + return UCC_OK; |
| 150 | + } |
| 151 | + |
| 152 | + mem_handle = (ucc_mem_map_memh_t *)memh; |
| 153 | + m_data = NULL; |
| 154 | + for (i = 0; i < mem_handle->num_tls; i++) { |
| 155 | + if (strcmp(mem_handle->tl_h[i].tl_name, "nccl") == 0) { |
| 156 | + m_data = (ucc_tl_nccl_memh_data_t *)mem_handle->tl_h[i].tl_data; |
| 157 | + break; |
| 158 | + } |
| 159 | + } |
| 160 | + |
| 161 | + if (!m_data) { |
| 162 | + /* No NCCL memh data - buffer not registered with TL/NCCL */ |
| 163 | + return UCC_OK; |
| 164 | + } |
| 165 | + |
| 166 | + if (length > (UINTPTR_MAX - (uintptr_t)buffer)) { |
| 167 | + tl_error(UCC_TL_TEAM_LIB(team), "NCCL UBR: buffer size causes overflow"); |
| 168 | + return UCC_ERR_INVALID_PARAM; |
| 169 | + } |
| 170 | + |
| 171 | + /* Verify that the entire buffer is within the registered memory region */ |
| 172 | + buf_start = (uintptr_t)buffer; |
| 173 | + buf_end = buf_start + length; |
| 174 | + region_start = (uintptr_t)m_data->address; |
| 175 | + region_end = region_start + m_data->length; |
| 176 | + |
| 177 | + if (buf_start < region_start || buf_end > region_end) { |
| 178 | + tl_error( |
| 179 | + UCC_TL_TEAM_LIB(team), |
| 180 | + "NCCL UBR: buffer [%p, %p) is outside registered region [%p, %p)", |
| 181 | + buffer, |
| 182 | + (void *)buf_end, |
| 183 | + m_data->address, |
| 184 | + (void *)region_end); |
| 185 | + return UCC_ERR_INVALID_PARAM; |
| 186 | + } |
| 187 | + |
| 188 | + /* Verify team communicator is initialized */ |
| 189 | + if (!team->nccl_comm) { |
| 190 | + tl_debug(UCC_TL_TEAM_LIB(team), |
| 191 | + "NCCL UBR: communicator not initialized, skipping registration"); |
| 192 | + return UCC_OK; |
| 193 | + } |
| 194 | + |
| 195 | + ucc_spin_lock(&m_data->lock); |
| 196 | + |
| 197 | + /* Check if already registered with this communicator */ |
| 198 | + for (i = 0; i < m_data->num_comms; i++) { |
| 199 | + if (m_data->registered_comms[i] == team->nccl_comm) { |
| 200 | + /* Already registered */ |
| 201 | + ucc_spin_unlock(&m_data->lock); |
| 202 | + return UCC_OK; |
| 203 | + } |
| 204 | + } |
| 205 | + |
| 206 | + /* Need to register the memory region with this communicator. |
| 207 | + * Release the lock while calling into NCCL (potentially slow), then |
| 208 | + * re-acquire to update the bookkeeping arrays. */ |
| 209 | + ucc_spin_unlock(&m_data->lock); |
| 210 | + |
| 211 | + nccl_status = ncclCommRegister( |
| 212 | + team->nccl_comm, m_data->address, m_data->length, &nccl_handle); |
| 213 | + if (nccl_status != ncclSuccess) { |
| 214 | + tl_warn( |
| 215 | + UCC_TL_TEAM_LIB(team), |
| 216 | + "NCCL UBR: failed to register region %p, size %zu: %s", |
| 217 | + m_data->address, |
| 218 | + m_data->length, |
| 219 | + ncclGetErrorString(nccl_status)); |
| 220 | + /* Don't fail - UBR is an optimization */ |
| 221 | + return UCC_OK; |
| 222 | + } |
| 223 | + |
| 224 | + ucc_spin_lock(&m_data->lock); |
| 225 | + |
| 226 | + /* Re-check after re-acquiring the lock: another thread may have registered |
| 227 | + * while we were in ncclCommRegister. */ |
| 228 | + for (i = 0; i < m_data->num_comms; i++) { |
| 229 | + if (m_data->registered_comms[i] == team->nccl_comm) { |
| 230 | + ucc_spin_unlock(&m_data->lock); |
| 231 | + /* Undo the redundant registration */ |
| 232 | + ncclCommDeregister(team->nccl_comm, nccl_handle); |
| 233 | + return UCC_OK; |
| 234 | + } |
| 235 | + } |
| 236 | + |
| 237 | + /* Add this comm and handle to the registered lists */ |
| 238 | + if (m_data->num_comms >= m_data->max_comms) { |
| 239 | + /* Need to grow the arrays. Allocate both new buffers before touching |
| 240 | + * m_data so that a failure on either leaves the struct untouched. */ |
| 241 | + new_max = (m_data->max_comms == 0) ? 4 : (m_data->max_comms * 2); |
| 242 | + new_comms = (ncclComm_t *)ucc_malloc(new_max * sizeof(ncclComm_t), |
| 243 | + "nccl_registered_comms"); |
| 244 | + new_handles = (void **)ucc_malloc(new_max * sizeof(void *), |
| 245 | + "nccl_handles"); |
| 246 | + if (!new_comms || !new_handles) { |
| 247 | + ucc_free(new_comms); |
| 248 | + ucc_free(new_handles); |
| 249 | + ucc_spin_unlock(&m_data->lock); |
| 250 | + tl_error(UCC_TL_TEAM_LIB(team), |
| 251 | + "failed to grow NCCL UBR bookkeeping arrays"); |
| 252 | + ncclCommDeregister(team->nccl_comm, nccl_handle); |
| 253 | + return UCC_ERR_NO_MEMORY; |
| 254 | + } |
| 255 | + /* Copy existing entries into the new buffers, then swap. */ |
| 256 | + if (m_data->num_comms > 0) { |
| 257 | + memcpy(new_comms, m_data->registered_comms, |
| 258 | + m_data->num_comms * sizeof(ncclComm_t)); |
| 259 | + memcpy(new_handles, m_data->nccl_handles, |
| 260 | + m_data->num_comms * sizeof(void *)); |
| 261 | + } |
| 262 | + ucc_free(m_data->registered_comms); |
| 263 | + ucc_free(m_data->nccl_handles); |
| 264 | + m_data->registered_comms = new_comms; |
| 265 | + m_data->nccl_handles = new_handles; |
| 266 | + m_data->max_comms = new_max; |
| 267 | + } |
| 268 | + |
| 269 | + m_data->registered_comms[m_data->num_comms] = team->nccl_comm; |
| 270 | + m_data->nccl_handles[m_data->num_comms] = nccl_handle; |
| 271 | + m_data->num_comms++; |
| 272 | + |
| 273 | + ucc_spin_unlock(&m_data->lock); |
| 274 | + |
| 275 | + tl_debug( |
| 276 | + UCC_TL_TEAM_LIB(team), |
| 277 | + "NCCL UBR: lazily registered region %p, size %zu with comm %p " |
| 278 | + "(for buffer [%p, %p))", |
| 279 | + m_data->address, |
| 280 | + m_data->length, |
| 281 | + team->nccl_comm, |
| 282 | + buffer, |
| 283 | + (void *)buf_end); |
| 284 | + |
| 285 | + return UCC_OK; |
| 286 | +} |
| 287 | +#endif |
| 288 | + |
130 | 289 | ucc_status_t ucc_tl_nccl_init_task(ucc_base_coll_args_t *coll_args, |
131 | 290 | ucc_base_team_t *team, |
132 | 291 | ucc_tl_nccl_task_t **coll_task) |
@@ -176,6 +335,102 @@ ucc_status_t ucc_tl_nccl_init_task(ucc_base_coll_args_t *coll_args, |
176 | 335 | } |
177 | 336 | } |
178 | 337 |
|
| 338 | +#if NCCL_HAS_UBR |
| 339 | + /* Lazily register memory regions if they were pre-mapped and UBR is enabled */ |
| 340 | + if (nccl_ctx->ubr_available) { |
| 341 | + ucc_mem_map_mem_h src_memh = NULL; |
| 342 | + ucc_mem_map_mem_h dst_memh = NULL; |
| 343 | + ucc_rank_t grank = UCC_TL_TEAM_RANK(nccl_team); |
| 344 | + ucc_count_t total_count; |
| 345 | + ucc_datatype_t dt; |
| 346 | + int is_src_v_type; |
| 347 | + int is_dst_v_type; |
| 348 | + |
| 349 | + /* SCATTERV: src is info_v (root distributes variable chunks), |
| 350 | + * dst is info (each rank receives one contiguous chunk). |
| 351 | + * GATHERV: src is info (each rank contributes one contiguous chunk), |
| 352 | + * dst is info_v (root collects variable chunks). |
| 353 | + * ALLGATHERV: src is info, dst is info_v. |
| 354 | + * ALLTOALLV: src is info_v, dst is info_v. */ |
| 355 | + is_src_v_type = (coll_args->args.coll_type == UCC_COLL_TYPE_ALLTOALLV || |
| 356 | + coll_args->args.coll_type == UCC_COLL_TYPE_SCATTERV); |
| 357 | + is_dst_v_type = (coll_args->args.coll_type == UCC_COLL_TYPE_ALLTOALLV || |
| 358 | + coll_args->args.coll_type == UCC_COLL_TYPE_ALLGATHERV || |
| 359 | + coll_args->args.coll_type == UCC_COLL_TYPE_GATHERV); |
| 360 | + |
| 361 | + /* Register source buffer's memory region if memh provided */ |
| 362 | + if (coll_args->args.mask & UCC_COLL_ARGS_FIELD_MEM_MAP_SRC_MEMH) { |
| 363 | + /* Check if global or local memh */ |
| 364 | + if ((coll_args->args.mask & UCC_COLL_ARGS_FIELD_FLAGS) && |
| 365 | + (coll_args->args.flags & UCC_COLL_ARGS_FLAG_SRC_MEMH_GLOBAL)) { |
| 366 | + src_memh = coll_args->args.src_memh.global_memh[grank]; |
| 367 | + } else { |
| 368 | + src_memh = coll_args->args.src_memh.local_memh; |
| 369 | + } |
| 370 | + |
| 371 | + if (is_src_v_type) { |
| 372 | + total_count = ucc_coll_args_get_v_buffer_size( |
| 373 | + &coll_args->args, |
| 374 | + coll_args->args.src.info_v.counts, |
| 375 | + coll_args->args.src.info_v.displacements, |
| 376 | + UCC_TL_TEAM_SIZE(nccl_team)); |
| 377 | + dt = coll_args->args.src.info_v.datatype; |
| 378 | + } else { |
| 379 | + total_count = coll_args->args.src.info.count; |
| 380 | + dt = coll_args->args.src.info.datatype; |
| 381 | + } |
| 382 | + status = ucc_tl_nccl_lazy_register_memh( |
| 383 | + is_src_v_type ? coll_args->args.src.info_v.buffer |
| 384 | + : coll_args->args.src.info.buffer, |
| 385 | + total_count * ucc_dt_size(dt), |
| 386 | + nccl_team, |
| 387 | + src_memh); |
| 388 | + if (ucc_unlikely(status != UCC_OK)) { |
| 389 | + tl_error( |
| 390 | + team->context->lib, |
| 391 | + "NCCL UBR: lazy_register failed with status %d", |
| 392 | + status); |
| 393 | + ucc_tl_nccl_free_task(task); |
| 394 | + return status; |
| 395 | + } |
| 396 | + } |
| 397 | + |
| 398 | + /* Register destination buffer's memory region if memh provided */ |
| 399 | + if (coll_args->args.mask & UCC_COLL_ARGS_FIELD_MEM_MAP_DST_MEMH) { |
| 400 | + /* Check if global or local memh */ |
| 401 | + if ((coll_args->args.mask & UCC_COLL_ARGS_FIELD_FLAGS) && |
| 402 | + (coll_args->args.flags & UCC_COLL_ARGS_FLAG_DST_MEMH_GLOBAL)) { |
| 403 | + dst_memh = coll_args->args.dst_memh.global_memh[grank]; |
| 404 | + } else { |
| 405 | + dst_memh = coll_args->args.dst_memh.local_memh; |
| 406 | + } |
| 407 | + |
| 408 | + if (is_dst_v_type) { |
| 409 | + total_count = ucc_coll_args_get_v_buffer_size( |
| 410 | + &coll_args->args, |
| 411 | + coll_args->args.dst.info_v.counts, |
| 412 | + coll_args->args.dst.info_v.displacements, |
| 413 | + UCC_TL_TEAM_SIZE(nccl_team)); |
| 414 | + dt = coll_args->args.dst.info_v.datatype; |
| 415 | + } else { |
| 416 | + total_count = coll_args->args.dst.info.count; |
| 417 | + dt = coll_args->args.dst.info.datatype; |
| 418 | + } |
| 419 | + |
| 420 | + status = ucc_tl_nccl_lazy_register_memh( |
| 421 | + is_dst_v_type ? coll_args->args.dst.info_v.buffer |
| 422 | + : coll_args->args.dst.info.buffer, |
| 423 | + total_count * ucc_dt_size(dt), |
| 424 | + nccl_team, |
| 425 | + dst_memh); |
| 426 | + if (ucc_unlikely(status != UCC_OK)) { |
| 427 | + ucc_tl_nccl_free_task(task); |
| 428 | + return status; |
| 429 | + } |
| 430 | + } |
| 431 | + } |
| 432 | +#endif |
| 433 | + |
179 | 434 | *coll_task = task; |
180 | 435 | return UCC_OK; |
181 | 436 | } |
|
0 commit comments