From fd2d918b05f734168859bce3151d84fc274ebc58 Mon Sep 17 00:00:00 2001 From: Juee14Desai Date: Wed, 25 Mar 2026 00:57:57 +0200 Subject: [PATCH 1/2] TL/UCP: topo aware ring algo for allgather Replace the default ring allgather with a topo aware multi ring implementation that uses team->cuda_ring to route data along NVLink optimal paths (up to 8 parallel rings). Algorithm changes: - Ring rank, peer, and block indices are now derived from the cuda_ring topology pattern instead of flat team rank ordering. - Each ring transfers its own slice of each block, enabling concurrent data movement across multiple NVLink paths. - Algorithm auto selected for CUDA memory >4KB when cuda_ring is available; falls back to knomial otherwise. Also fixes CUDA primary context detection in ucc_sysinfo_cuda.c and decouples the service allgather from the topo aware ring. Signed-off-by: Juee Himalbhai Desai --- src/components/tl/ucp/allgather/allgather.c | 36 +++- .../tl/ucp/allgather/allgather_ring.c | 156 ++++++++++-------- src/components/tl/ucp/tl_ucp_service_coll.c | 80 ++++++++- src/components/topo/cuda/ucc_sysinfo_cuda.c | 32 +++- 4 files changed, 225 insertions(+), 79 deletions(-) diff --git a/src/components/tl/ucp/allgather/allgather.c b/src/components/tl/ucp/allgather/allgather.c index 43d6a2a0e49..ecad3c43340 100644 --- a/src/components/tl/ucp/allgather/allgather.c +++ b/src/components/tl/ucp/allgather/allgather.c @@ -54,9 +54,7 @@ ucc_status_t ucc_tl_ucp_allgather_init(ucc_tl_ucp_task_t *task) char *ucc_tl_ucp_allgather_score_str_get(ucc_tl_ucp_team_t *team) { int max_size = ALLGATHER_MAX_PATTERN_SIZE; - int algo_num = UCC_TL_TEAM_SIZE(team) % 2 - ? UCC_TL_UCP_ALLGATHER_ALG_RING - : UCC_TL_UCP_ALLGATHER_ALG_NEIGHBOR; + int algo_num; char * str = ucc_malloc(max_size * sizeof(char)); ucc_tl_ucp_context_t *ctx = UCC_TL_UCP_TEAM_CTX(team); uint64_t cuda_types = @@ -67,6 +65,10 @@ char *ucc_tl_ucp_allgather_score_str_get(ucc_tl_ucp_team_t *team) char * non_cuda_str; char * cuda_str; + algo_num = UCC_TL_TEAM_SIZE(team) % 2 + ? UCC_TL_UCP_ALLGATHER_ALG_RING + : UCC_TL_UCP_ALLGATHER_ALG_NEIGHBOR; + if (team->cfg.use_reordering) { sbgp = ucc_topo_get_sbgp(team->topo, UCC_SBGP_FULL_HOST_ORDERED); if (!ucc_ep_map_is_identity(&sbgp->map)) { @@ -74,6 +76,10 @@ char *ucc_tl_ucp_allgather_score_str_get(ucc_tl_ucp_team_t *team) } } + if (algo_num == UCC_TL_UCP_ALLGATHER_ALG_RING && !team->cuda_ring) { + algo_num = UCC_TL_UCP_ALLGATHER_ALG_KNOMIAL; + } + if (team->topo && ucc_topo_is_single_ppn(team->topo)) { if (cuda_types) { cuda_str = ucc_malloc(max_size * sizeof(char)); @@ -98,6 +104,30 @@ char *ucc_tl_ucp_allgather_score_str_get(ucc_tl_ucp_team_t *team) return str; } } + + if (team->cuda_ring && cuda_types) { + cuda_str = ucc_malloc(max_size * sizeof(char)); + ucc_mtype_map_to_str(cuda_types, ",", cuda_str, max_size); + if (non_cuda_types) { + non_cuda_str = ucc_malloc(max_size * sizeof(char)); + ucc_mtype_map_to_str(non_cuda_types, ",", non_cuda_str, max_size); + ucc_snprintf_safe(str, max_size, + "allgather:0-4k:@0#allgather:4k-inf:%s:@%d" + "#allgather:4k-inf:%s:@%d", + cuda_str, UCC_TL_UCP_ALLGATHER_ALG_RING, + non_cuda_str, algo_num); + ucc_free(cuda_str); + ucc_free(non_cuda_str); + return str; + } + ucc_snprintf_safe(str, max_size, + "allgather:0-4k:@0#allgather:4k-inf:%s:@%d" + "#allgather:4k-inf:@%d", + cuda_str, UCC_TL_UCP_ALLGATHER_ALG_RING, algo_num); + ucc_free(cuda_str); + return str; + } + ucc_snprintf_safe(str, max_size, UCC_TL_UCP_ALLGATHER_DEFAULT_ALG_SELECT_STR, algo_num); return str; diff --git a/src/components/tl/ucp/allgather/allgather_ring.c b/src/components/tl/ucp/allgather/allgather_ring.c index 07178aea25e..8621a84235f 100644 --- a/src/components/tl/ucp/allgather/allgather_ring.c +++ b/src/components/tl/ucp/allgather/allgather_ring.c @@ -1,5 +1,5 @@ /** - * Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2021-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. */ @@ -12,62 +12,78 @@ #include "utils/ucc_math.h" #include "utils/ucc_coll_utils.h" #include "components/mc/ucc_mc.h" +#include "coll_patterns/ring.h" -static ucc_rank_t ucc_tl_ucp_allgather_ring_get_send_block(ucc_subset_t *subset, - ucc_rank_t trank, - ucc_rank_t tsize, - int step) -{ - return ucc_ep_map_eval(subset->map, (trank - step + tsize) % tsize); -} - -static ucc_rank_t ucc_tl_ucp_allgather_ring_get_recv_block(ucc_subset_t *subset, - ucc_rank_t trank, - ucc_rank_t tsize, - int step) -{ - return ucc_ep_map_eval(subset->map, (trank - step - 1 + tsize) % tsize); -} +#define MAX_RINGS 8 void ucc_tl_ucp_allgather_ring_progress(ucc_coll_task_t *coll_task) { - ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, ucc_tl_ucp_task_t); - ucc_tl_ucp_team_t *team = TASK_TEAM(task); - ucc_rank_t trank = task->subset.myrank; - ucc_rank_t tsize = (ucc_rank_t)task->subset.map.ep_num; - void *rbuf = TASK_ARGS(task).dst.info.buffer; - ucc_memory_type_t rmem = TASK_ARGS(task).dst.info.mem_type; - size_t count = TASK_ARGS(task).dst.info.count; - ucc_datatype_t dt = TASK_ARGS(task).dst.info.datatype; - size_t data_size = (count / tsize) * ucc_dt_size(dt); - ucc_rank_t sendto, recvfrom, sblock, rblock; - int step; - void *buf; + ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, ucc_tl_ucp_task_t); + ucc_tl_ucp_team_t *team = TASK_TEAM(task); + ucc_coll_args_t *args = &TASK_ARGS(task); + void *rbuf = args->dst.info.buffer; + ucc_memory_type_t rmem = args->dst.info.mem_type; + size_t count = args->dst.info.count; + ucc_datatype_t dt = args->dst.info.datatype; + size_t rdt_size = ucc_dt_size(dt); + ucc_ring_pattern_t *ring = team->cuda_ring; + ucc_rank_t ring_id; + ucc_rank_t nrings; + ucc_rank_t rrank; + ucc_rank_t tsize; + ucc_rank_t send_idx, recv_idx, sendto, recvfrom, step; + size_t block_count, ring_offset, ring_count; + size_t data_size, data_displ; + + nrings = ucc_min(MAX_RINGS, ring->num_rings); + tsize = ucc_ring_pattern_size(ring, 0); + block_count = count / tsize; if (UCC_INPROGRESS == ucc_tl_ucp_test(task)) { return; } - sendto = ucc_ep_map_eval(task->subset.map, (trank + 1) % tsize); - recvfrom = ucc_ep_map_eval(task->subset.map, (trank - 1 + tsize) % tsize); - - while (task->tagged.send_posted < tsize - 1) { - step = task->tagged.send_posted; - sblock = task->allgather_ring.get_send_block(&task->subset, trank, - tsize, step); - rblock = task->allgather_ring.get_recv_block(&task->subset, trank, - tsize, step); - buf = PTR_OFFSET(rbuf, sblock * data_size); - UCPCHECK_GOTO( - ucc_tl_ucp_send_nb(buf, data_size, rmem, sendto, team, task), - task, out); - buf = PTR_OFFSET(rbuf, rblock * data_size); - UCPCHECK_GOTO( - ucc_tl_ucp_recv_nb(buf, data_size, rmem, recvfrom, team, task), - task, out); + + while (task->tagged.send_posted < 1 + nrings * (tsize - 1)) { + ucc_assert(task->tagged.send_posted > 0); + ucc_assert(task->tagged.recv_posted > 0); + ucc_assert(task->tagged.send_posted == task->tagged.recv_posted); + step = (ucc_rank_t)((task->tagged.send_posted - 1) / nrings); + for (ring_id = 0; ring_id < nrings; ring_id++) { + rrank = ucc_ring_pattern_rank(ring, ring_id); + sendto = ucc_ring_pattern_get_send_peer(ring, ring_id, rrank); + recvfrom = ucc_ring_pattern_get_recv_peer(ring, ring_id, rrank); + + send_idx = ucc_ring_pattern_get_send_block(ring, ring_id, + rrank, step); + ring_offset = ucc_buffer_block_offset(block_count, nrings, + ring_id); + ring_count = ucc_buffer_block_count(block_count, nrings, + ring_id); + data_displ = (send_idx * block_count + ring_offset) * rdt_size; + data_size = ring_count * rdt_size; + UCPCHECK_GOTO(ucc_tl_ucp_send_nb(PTR_OFFSET(rbuf, data_displ), + data_size, rmem, sendto, team, + task), + task, out); + + recv_idx = ucc_ring_pattern_get_recv_block(ring, ring_id, + rrank, step); + ring_offset = ucc_buffer_block_offset(block_count, nrings, + ring_id); + ring_count = ucc_buffer_block_count(block_count, nrings, + ring_id); + data_displ = (recv_idx * block_count + ring_offset) * rdt_size; + data_size = ring_count * rdt_size; + UCPCHECK_GOTO(ucc_tl_ucp_recv_nb(PTR_OFFSET(rbuf, data_displ), + data_size, rmem, recvfrom, + team, task), + task, out); + } if (UCC_INPROGRESS == ucc_tl_ucp_test(task)) { return; } } + ucc_assert(UCC_TL_UCP_TASK_P2P_COMPLETE(task)); task->super.status = UCC_OK; out: @@ -76,59 +92,53 @@ void ucc_tl_ucp_allgather_ring_progress(ucc_coll_task_t *coll_task) ucc_status_t ucc_tl_ucp_allgather_ring_start(ucc_coll_task_t *coll_task) { - ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, ucc_tl_ucp_task_t); - ucc_tl_ucp_team_t *team = TASK_TEAM(task); - size_t count = TASK_ARGS(task).dst.info.count; - void *sbuf = TASK_ARGS(task).src.info.buffer; - void *rbuf = TASK_ARGS(task).dst.info.buffer; - ucc_memory_type_t smem = TASK_ARGS(task).src.info.mem_type; - ucc_memory_type_t rmem = TASK_ARGS(task).dst.info.mem_type; - ucc_datatype_t dt = TASK_ARGS(task).dst.info.datatype; - ucc_rank_t trank = task->subset.myrank; - ucc_rank_t tsize = (ucc_rank_t)task->subset.map.ep_num; - size_t data_size = (count / tsize) * ucc_dt_size(dt); + ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, ucc_tl_ucp_task_t); + ucc_tl_ucp_team_t *team = TASK_TEAM(task); + ucc_coll_args_t *args = &TASK_ARGS(task); + ucc_ring_pattern_t *ring = team->cuda_ring; + size_t count = args->dst.info.count; + void *sbuf = args->src.info.buffer; + void *rbuf = args->dst.info.buffer; + ucc_memory_type_t rmem = args->dst.info.mem_type; + ucc_memory_type_t smem = args->src.info.mem_type; + ucc_datatype_t dt = args->dst.info.datatype; + ucc_rank_t tsize = ucc_ring_pattern_size(ring, 0); + ucc_rank_t block = UCC_TL_TEAM_RANK(team); + size_t data_size = (count / tsize) * ucc_dt_size(dt); ucc_status_t status; - ucc_rank_t block; UCC_TL_UCP_PROFILE_REQUEST_EVENT(coll_task, "ucp_allgather_ring_start", 0); ucc_tl_ucp_task_reset(task, UCC_INPROGRESS); - if (!UCC_IS_INPLACE(TASK_ARGS(task))) { - block = task->allgather_ring.get_send_block(&task->subset, trank, tsize, - 0); + if (!UCC_IS_INPLACE(*args)) { status = ucc_mc_memcpy(PTR_OFFSET(rbuf, data_size * block), - sbuf, data_size, rmem, smem); + sbuf, data_size, rmem, smem); if (ucc_unlikely(UCC_OK != status)) { return status; } } + task->tagged.send_posted = task->tagged.send_completed = 1; + task->tagged.recv_posted = task->tagged.recv_completed = 1; + return ucc_progress_queue_enqueue(UCC_TL_CORE_CTX(team)->pq, &task->super); } ucc_status_t ucc_tl_ucp_allgather_ring_init_common(ucc_tl_ucp_task_t *task) { ucc_tl_ucp_team_t *team = TASK_TEAM(task); - ucc_sbgp_t *sbgp; if (!ucc_coll_args_is_predefined_dt(&TASK_ARGS(task), UCC_RANK_INVALID)) { tl_error(UCC_TASK_LIB(task), "user defined datatype is not supported"); return UCC_ERR_NOT_SUPPORTED; } - if (!(task->flags & UCC_TL_UCP_TASK_FLAG_SUBSET)) { - if (team->cfg.use_reordering) { - sbgp = ucc_topo_get_sbgp(team->topo, UCC_SBGP_FULL_HOST_ORDERED); - task->subset.myrank = sbgp->group_rank; - task->subset.map = sbgp->map; - } + if (!team->cuda_ring) { + return UCC_ERR_NOT_SUPPORTED; } - task->allgather_ring.get_send_block = ucc_tl_ucp_allgather_ring_get_send_block; - task->allgather_ring.get_recv_block = ucc_tl_ucp_allgather_ring_get_recv_block; - task->super.post = ucc_tl_ucp_allgather_ring_start; - task->super.progress = ucc_tl_ucp_allgather_ring_progress; - + task->super.post = ucc_tl_ucp_allgather_ring_start; + task->super.progress = ucc_tl_ucp_allgather_ring_progress; return UCC_OK; } diff --git a/src/components/tl/ucp/tl_ucp_service_coll.c b/src/components/tl/ucp/tl_ucp_service_coll.c index 1d93a8cd82e..a167b5024b0 100644 --- a/src/components/tl/ucp/tl_ucp_service_coll.c +++ b/src/components/tl/ucp/tl_ucp_service_coll.c @@ -7,6 +7,7 @@ #include "tl_ucp.h" #include "tl_ucp_coll.h" #include "tl_ucp_tag.h" +#include "tl_ucp_sendrecv.h" #include "allreduce/allreduce.h" #include "allgather/allgather.h" #include "bcast/bcast.h" @@ -29,6 +30,81 @@ static ucc_rank_t ucc_tl_ucp_service_ring_get_recv_block(ucc_subset_t *subset, return (trank - step - 1 + tsize) % tsize; } +static void ucc_tl_ucp_service_allgather_ring_progress(ucc_coll_task_t *coll_task) +{ + ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, ucc_tl_ucp_task_t); + ucc_tl_ucp_team_t *team = TASK_TEAM(task); + ucc_rank_t trank = task->subset.myrank; + ucc_rank_t tsize = (ucc_rank_t)task->subset.map.ep_num; + void *rbuf = TASK_ARGS(task).dst.info.buffer; + ucc_memory_type_t rmem = TASK_ARGS(task).dst.info.mem_type; + size_t count = TASK_ARGS(task).dst.info.count; + size_t data_size = (count / tsize) * ucc_dt_size(TASK_ARGS(task).dst.info.datatype); + ucc_rank_t sendto, recvfrom, sblock, rblock; + int step; + + if (UCC_INPROGRESS == ucc_tl_ucp_test(task)) { + return; + } + + sendto = ucc_ep_map_eval(task->subset.map, (trank + 1) % tsize); + recvfrom = ucc_ep_map_eval(task->subset.map, (trank - 1 + tsize) % tsize); + + while (task->tagged.send_posted < tsize - 1) { + step = task->tagged.send_posted; + sblock = task->allgather_ring.get_send_block(&task->subset, trank, + tsize, step); + rblock = task->allgather_ring.get_recv_block(&task->subset, trank, + tsize, step); + UCPCHECK_GOTO( + ucc_tl_ucp_send_nb(PTR_OFFSET(rbuf, sblock * data_size), + data_size, rmem, sendto, team, task), + task, out); + UCPCHECK_GOTO( + ucc_tl_ucp_recv_nb(PTR_OFFSET(rbuf, rblock * data_size), + data_size, rmem, recvfrom, team, task), + task, out); + if (UCC_INPROGRESS == ucc_tl_ucp_test(task)) { + return; + } + } + + ucc_assert(UCC_TL_UCP_TASK_P2P_COMPLETE(task)); + task->super.status = UCC_OK; +out: + return; +} + +static ucc_status_t ucc_tl_ucp_service_allgather_ring_start(ucc_coll_task_t *coll_task) +{ + ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, ucc_tl_ucp_task_t); + ucc_tl_ucp_team_t *team = TASK_TEAM(task); + ucc_rank_t trank = task->subset.myrank; + ucc_rank_t tsize = (ucc_rank_t)task->subset.map.ep_num; + size_t count = TASK_ARGS(task).dst.info.count; + void *sbuf = TASK_ARGS(task).src.info.buffer; + void *rbuf = TASK_ARGS(task).dst.info.buffer; + ucc_memory_type_t rmem = TASK_ARGS(task).dst.info.mem_type; + ucc_memory_type_t smem = TASK_ARGS(task).src.info.mem_type; + size_t data_size = (count / tsize) * ucc_dt_size(TASK_ARGS(task).dst.info.datatype); + ucc_rank_t block; + ucc_status_t status; + + ucc_tl_ucp_task_reset(task, UCC_INPROGRESS); + + if (!UCC_IS_INPLACE(TASK_ARGS(task))) { + block = task->allgather_ring.get_send_block(&task->subset, trank, + tsize, 0); + status = ucc_mc_memcpy(PTR_OFFSET(rbuf, data_size * block), + sbuf, data_size, rmem, smem); + if (ucc_unlikely(UCC_OK != status)) { + return status; + } + } + + return ucc_progress_queue_enqueue(UCC_TL_CORE_CTX(team)->pq, &task->super); +} + static ucc_status_t ucc_tl_ucp_service_coll_start_executor(ucc_coll_task_t *task) { ucc_ee_executor_params_t eparams; @@ -178,10 +254,10 @@ ucc_status_t ucc_tl_ucp_service_allgather(ucc_base_team_t *team, void *sbuf, task->subset = subset; task->tagged.tag = UCC_TL_UCP_SERVICE_TAG; task->n_polls = npolls; - task->super.progress = ucc_tl_ucp_allgather_ring_progress; + task->super.progress = ucc_tl_ucp_service_allgather_ring_progress; task->super.finalize = ucc_tl_ucp_coll_finalize; - status = ucc_tl_ucp_allgather_ring_start(&task->super); + status = ucc_tl_ucp_service_allgather_ring_start(&task->super); if (status != UCC_OK) { goto finalize_coll; } diff --git a/src/components/topo/cuda/ucc_sysinfo_cuda.c b/src/components/topo/cuda/ucc_sysinfo_cuda.c index 1e0856e5a13..1c30cb5f633 100644 --- a/src/components/topo/cuda/ucc_sysinfo_cuda.c +++ b/src/components/topo/cuda/ucc_sysinfo_cuda.c @@ -70,8 +70,35 @@ static ucc_status_t ucc_sysinfo_cuda_set_visible_devices( cuGetErrorString(cu_st, &cu_err_str); ucc_debug("cuCtxGetCurrent failed: %d (%s)", cu_st, cu_err_str ? cu_err_str : "unknown"); + return UCC_OK; + } + { + int num_devs = 0; + int d; + unsigned int flags; + int active; + + cu_st = cuDeviceGetCount(&num_devs); + if (cu_st != CUDA_SUCCESS || num_devs == 0) { + ucc_debug("no CUDA context and cuDeviceGetCount=%d", num_devs); + return UCC_OK; + } + for (d = 0; d < num_devs; d++) { + CUdevice dev; + + cu_st = cuDeviceGet(&dev, d); + if (cu_st != CUDA_SUCCESS) { + continue; + } + cu_st = cuDevicePrimaryCtxGetState(dev, &flags, &active); + if (cu_st == CUDA_SUCCESS && active) { + cu_dev = dev; + goto have_device; + } + } + ucc_debug("no active CUDA primary context found on any device"); + return UCC_OK; } - return UCC_OK; } cu_st = cuCtxGetDevice(&cu_dev); @@ -84,6 +111,7 @@ static ucc_status_t ucc_sysinfo_cuda_set_visible_devices( return UCC_ERR_NO_MESSAGE; } +have_device: cu_st = cuDeviceGetPCIBusId(pci_bus_id, sizeof(pci_bus_id), cu_dev); if (cu_st != CUDA_SUCCESS) { const char *cu_err_str = NULL; @@ -114,6 +142,8 @@ static ucc_status_t ucc_sysinfo_cuda_set_visible_devices( } } + ucc_debug("set_visible_devices: pci=%s dev=%d visible_gpus=0x%x", + pci_bus_id, (int)cu_dev, *visible_devices); return UCC_OK; } From 9342df058d0af9b3141e21023c924c0cda1e2f59 Mon Sep 17 00:00:00 2001 From: Juee14Desai Date: Wed, 25 Mar 2026 00:58:11 +0200 Subject: [PATCH 2/2] TL/UCP: topo aware ring algo for reduce_scatter Replace the default ring reduce_scatter with a topo aware multi ring implementation that uses team->cuda_ring to route data along NVLink optimal paths (up to 8 parallel rings). Algorithm changes: - Ring rank, peer, and block indices are now derived from the cuda_ring topology pattern instead of flat team rank ordering. - Each ring handles its own sub block slice, with per ring GPU reductions via the executor before forwarding to the next peer. - Scratch buffer management simplified to a single mc_alloc/free per task lifetime (removed fragmentation logic). Signed-off-by: Juee Himalbhai Desai --- .../tl/ucp/reduce_scatter/reduce_scatter.c | 50 ++ .../tl/ucp/reduce_scatter/reduce_scatter.h | 9 +- .../ucp/reduce_scatter/reduce_scatter_ring.c | 580 +++++++----------- src/components/tl/ucp/tl_ucp_coll.c | 4 +- src/components/tl/ucp/tl_ucp_task.h | 2 + 5 files changed, 294 insertions(+), 351 deletions(-) diff --git a/src/components/tl/ucp/reduce_scatter/reduce_scatter.c b/src/components/tl/ucp/reduce_scatter/reduce_scatter.c index 39b01f9638e..eb94632e811 100644 --- a/src/components/tl/ucp/reduce_scatter/reduce_scatter.c +++ b/src/components/tl/ucp/reduce_scatter/reduce_scatter.c @@ -7,6 +7,9 @@ #include "tl_ucp.h" #include "reduce_scatter.h" #include "utils/ucc_coll_utils.h" +#include "utils/ucc_string.h" + +#define REDUCE_SCATTER_MAX_PATTERN_SIZE 256 ucc_base_coll_alg_info_t ucc_tl_ucp_reduce_scatter_algs[UCC_TL_UCP_REDUCE_SCATTER_ALG_LAST + 1] = { @@ -20,3 +23,50 @@ ucc_base_coll_alg_info_t .desc = "recursive k-ing with arbitrary radix"}, [UCC_TL_UCP_REDUCE_SCATTER_ALG_LAST] = { .id = 0, .name = NULL, .desc = NULL}}; + +char *ucc_tl_ucp_reduce_scatter_score_str_get(ucc_tl_ucp_team_t *team) +{ + int max_size = REDUCE_SCATTER_MAX_PATTERN_SIZE; + char *str = ucc_malloc(max_size * sizeof(char)); + ucc_tl_ucp_context_t *ctx = UCC_TL_UCP_TEAM_CTX(team); + uint64_t cuda_types = + ctx->ucp_memory_types & + (UCC_BIT(UCC_MEMORY_TYPE_CUDA) | + UCC_BIT(UCC_MEMORY_TYPE_CUDA_MANAGED)); + uint64_t non_cuda_types = ctx->ucp_memory_types & (~cuda_types); + char *non_cuda_str; + char *cuda_str; + + if (team->cuda_ring && cuda_types) { + cuda_str = ucc_malloc(max_size * sizeof(char)); + ucc_mtype_map_to_str(cuda_types, ",", cuda_str, max_size); + if (non_cuda_types) { + non_cuda_str = ucc_malloc(max_size * sizeof(char)); + ucc_mtype_map_to_str(non_cuda_types, ",", non_cuda_str, max_size); + ucc_snprintf_safe(str, max_size, + "reduce_scatter:0-4k:@%d" + "#reduce_scatter:4k-inf:%s:@%d" + "#reduce_scatter:4k-inf:%s:@%d", + UCC_TL_UCP_REDUCE_SCATTER_ALG_KNOMIAL, + cuda_str, UCC_TL_UCP_REDUCE_SCATTER_ALG_RING, + non_cuda_str, UCC_TL_UCP_REDUCE_SCATTER_ALG_KNOMIAL); + ucc_free(cuda_str); + ucc_free(non_cuda_str); + return str; + } + ucc_snprintf_safe(str, max_size, + "reduce_scatter:0-4k:@%d" + "#reduce_scatter:4k-inf:%s:@%d" + "#reduce_scatter:4k-inf:@%d", + UCC_TL_UCP_REDUCE_SCATTER_ALG_KNOMIAL, + cuda_str, UCC_TL_UCP_REDUCE_SCATTER_ALG_RING, + UCC_TL_UCP_REDUCE_SCATTER_ALG_KNOMIAL); + ucc_free(cuda_str); + return str; + } + + ucc_snprintf_safe(str, max_size, + UCC_TL_UCP_REDUCE_SCATTER_DEFAULT_ALG_SELECT_STR, + UCC_TL_UCP_REDUCE_SCATTER_ALG_KNOMIAL); + return str; +} diff --git a/src/components/tl/ucp/reduce_scatter/reduce_scatter.h b/src/components/tl/ucp/reduce_scatter/reduce_scatter.h index 39414172994..f132f6cbbc5 100644 --- a/src/components/tl/ucp/reduce_scatter/reduce_scatter.h +++ b/src/components/tl/ucp/reduce_scatter/reduce_scatter.h @@ -19,7 +19,9 @@ extern ucc_base_coll_alg_info_t ucc_tl_ucp_reduce_scatter_algs[UCC_TL_UCP_REDUCE_SCATTER_ALG_LAST + 1]; #define UCC_TL_UCP_REDUCE_SCATTER_DEFAULT_ALG_SELECT_STR \ - "reduce_scatter:@ring" + "reduce_scatter:@%d" + +char *ucc_tl_ucp_reduce_scatter_score_str_get(ucc_tl_ucp_team_t *team); static inline int ucc_tl_ucp_reduce_scatter_alg_from_str(const char *str) { @@ -48,4 +50,9 @@ ucc_status_t ucc_tl_ucp_reduce_scatter_ring_init(ucc_base_coll_args_t *coll_args, ucc_base_team_t * team, ucc_coll_task_t ** task_h); + +ucc_status_t ucc_tl_ucp_reduce_scatter_ring_init_common( + ucc_tl_ucp_task_t *task); + +void ucc_tl_ucp_reduce_scatter_ring_progress(ucc_coll_task_t *coll_task); #endif diff --git a/src/components/tl/ucp/reduce_scatter/reduce_scatter_ring.c b/src/components/tl/ucp/reduce_scatter/reduce_scatter_ring.c index ea51d19bf69..faf96657ec1 100644 --- a/src/components/tl/ucp/reduce_scatter/reduce_scatter_ring.c +++ b/src/components/tl/ucp/reduce_scatter/reduce_scatter_ring.c @@ -1,9 +1,11 @@ /** - * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. */ +#include "config.h" +#include "tl_ucp.h" #include "reduce_scatter.h" #include "tl_ucp_sendrecv.h" #include "core/ucc_progress_queue.h" @@ -11,414 +13,296 @@ #include "utils/ucc_math.h" #include "utils/ucc_coll_utils.h" #include "utils/ucc_dt_reduce.h" -#include "utils/ucc_atomic.h" +#include "coll_patterns/ring.h" -#define REVERSED_FRAG 1 +#define MAX_RINGS 8 -static inline void send_completion_common(void *request, ucs_status_t status, - void *user_data) +static inline size_t +rs_ring_total_count(ucc_coll_args_t *args) { - ucc_tl_ucp_task_t *task = (ucc_tl_ucp_task_t *)user_data; - - if (ucc_unlikely(UCS_OK != status)) { - tl_error(UCC_TASK_LIB(task), "failure in rs ring completion %s", - ucs_status_string(status)); - task->super.status = ucs_status_to_ucc_status(status); - } - ucc_atomic_add32(&task->tagged.send_completed, 1); - if (request) { - ucp_request_free(request); - } -} - -static void send_completion_1(void *request, ucs_status_t status, - void *user_data) -{ - ucc_tl_ucp_task_t *task = (ucc_tl_ucp_task_t *)user_data; - - task->reduce_scatter_ring.s_scratch_busy[0] = 0; - send_completion_common(request, status, user_data); + return UCC_IS_INPLACE(*args) ? args->dst.info.count + : args->src.info.count; } -static void send_completion_2(void *request, ucs_status_t status, - void *user_data) +void ucc_tl_ucp_reduce_scatter_ring_progress(ucc_coll_task_t *coll_task) { - ucc_tl_ucp_task_t *task = (ucc_tl_ucp_task_t *)user_data; - - task->reduce_scatter_ring.s_scratch_busy[1] = 0; - send_completion_common(request, status, user_data); -} + ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, ucc_tl_ucp_task_t); + ucc_coll_args_t *args = &TASK_ARGS(task); + ucc_tl_ucp_team_t *team = TASK_TEAM(task); + ucc_ring_pattern_t *ring = team->cuda_ring; + ucc_rank_t nrings = ucc_min(MAX_RINGS, ring->num_rings); + ucc_rank_t tsize = ucc_ring_pattern_size(ring, 0); + size_t total_cnt = rs_ring_total_count(args); + size_t block_cnt = total_cnt / tsize; + void *sbuf = UCC_IS_INPLACE(*args) + ? args->dst.info.buffer + : args->src.info.buffer; + void *dst = args->dst.info.buffer; + void *scratch = task->reduce_scatter_ring.scratch; + size_t blk_bytes = task->reduce_scatter_ring.max_block_count * + ucc_dt_size(args->dst.info.datatype); + void *scratch2 = PTR_OFFSET(scratch, blk_bytes); + int persistent = UCC_IS_PERSISTENT(*args); + ucc_memory_type_t mem_type = args->dst.info.mem_type; + ucc_datatype_t dt = args->dst.info.datatype; + size_t dt_size = ucc_dt_size(dt); + ucc_rank_t rrank, adj_rrank, recv_block, send_block; + ucc_rank_t sendto, recvfrom, ring_id, step; + size_t ring_offset, ring_count, data_displ, data_size; + ucc_status_t status; + int is_avg; + void *reduce_target; + void *recv_buf, *send_src, *next_recv_dst; -static inline void ucc_ring_frag_count(ucc_tl_ucp_task_t *task, size_t count, - ucc_rank_t block, size_t *frag_count) -{ - ucc_tl_ucp_team_t *team = TASK_TEAM(task); - size_t size = UCC_TL_TEAM_SIZE(team); - int n_frags, frag; - size_t block_count; - - n_frags = task->reduce_scatter_ring.n_frags; - frag = task->reduce_scatter_ring.frag; - - block_count = ucc_buffer_block_count(count, size, block); - *frag_count = ucc_buffer_block_count(block_count, n_frags, frag); -} - -static inline void ucc_ring_frag_block_offset(ucc_tl_ucp_task_t *task, - size_t count, ucc_rank_t block, - size_t *block_offset, - size_t *frag_offset) -{ - ucc_tl_ucp_team_t *team = TASK_TEAM(task); - size_t size = UCC_TL_TEAM_SIZE(team); - int n_frags, frag; - size_t block_count; - - n_frags = task->reduce_scatter_ring.n_frags; - frag = task->reduce_scatter_ring.frag; - - block_count = ucc_buffer_block_count(count, size, block); - *frag_offset = ucc_buffer_block_offset(block_count, n_frags, frag); - *block_offset = ucc_buffer_block_offset(count, size, block); -} - -static void ucc_tl_ucp_reduce_scatter_ring_progress(ucc_coll_task_t *coll_task) -{ - ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, - ucc_tl_ucp_task_t); - ucc_coll_args_t *args = &TASK_ARGS(task); - ucc_tl_ucp_team_t *team = TASK_TEAM(task); - ucc_rank_t size = task->subset.map.ep_num; - ucc_rank_t rank = task->subset.myrank; - void *sbuf = args->src.info.buffer; - ucc_memory_type_t mem_type = args->dst.info.mem_type; - size_t count = args->dst.info.count * size; - ucc_datatype_t dt = args->dst.info.datatype; - size_t dt_size = ucc_dt_size(dt); - ucc_rank_t sendto = (rank + 1) % size; - ucc_rank_t recvfrom = (rank - 1 + size) % size; - ucp_send_nbx_callback_t cb[2] = {send_completion_1, send_completion_2}; - ucc_rank_t prevblock, recv_data_from; - ucc_status_t status; - size_t max_block_size, block_offset, frag_count, frag_offset, final_offset; - int step, is_avg, id; - void *r_scratch, *s_scratch[2], *reduce_target; - volatile char *busy; - - final_offset = 0; - if (UCC_IS_INPLACE(*args)) { - sbuf = args->dst.info.buffer; - count /= size; - final_offset = - ucc_buffer_block_offset(count, size, UCC_TL_TEAM_RANK(team)); + if (UCC_INPROGRESS == ucc_tl_ucp_test(task)) { + return; } - sendto = ucc_ep_map_eval(task->reduce_scatter_ring.inv_map, sendto); - recvfrom = ucc_ep_map_eval(task->reduce_scatter_ring.inv_map, recvfrom); - if (team->cfg.use_reordering) { - sendto = ucc_ep_map_eval(task->subset.map, sendto); - recvfrom = ucc_ep_map_eval(task->subset.map, recvfrom); - } - max_block_size = task->reduce_scatter_ring.max_block_count * dt_size; - busy = task->reduce_scatter_ring.s_scratch_busy; - r_scratch = task->reduce_scatter_ring.scratch; - s_scratch[0] = PTR_OFFSET(r_scratch, max_block_size); - s_scratch[1] = PTR_OFFSET(s_scratch[0], max_block_size); + while (task->tagged.send_posted < 1 + nrings * (tsize - 1)) { + ucc_assert(task->tagged.send_posted > 0); + ucc_assert(task->tagged.recv_posted > 0); - if (UCC_INPROGRESS == ucc_tl_ucp_test_ring(task)) { - return; - } - while (task->tagged.recv_posted > 0) { - /* always have at least 1 send completion, ie 1 free slot */ - ucc_assert(!busy[0] || !busy[1]); - id = busy[0] ? 1 : 0; - reduce_target = s_scratch[id]; - step = task->tagged.send_posted; - prevblock = (rank - 1 - step + size) % size; - prevblock = - ucc_ep_map_eval(task->reduce_scatter_ring.inv_map, prevblock); - if (team->cfg.use_reordering) { - prevblock = ucc_ep_map_eval(task->subset.map, prevblock); - } - /* reduction */ - ucc_assert(task->tagged.recv_posted == task->tagged.recv_completed); - ucc_assert(task->tagged.recv_posted < size); - - ucc_ring_frag_count(task, count, prevblock, &frag_count); - ucc_ring_frag_block_offset(task, count, prevblock, &block_offset, - &frag_offset); - if (task->tagged.recv_completed == size - 1) { - reduce_target = PTR_OFFSET(args->dst.info.buffer, - (frag_offset + final_offset) * dt_size); - } + step = (ucc_rank_t)((task->tagged.send_posted - 1) / nrings); is_avg = (args->op == UCC_OP_AVG) && - (task->tagged.recv_completed == (size - 1)); - if (UCC_OK != - (status = ucc_dt_reduce( - r_scratch, - PTR_OFFSET(sbuf, (block_offset + frag_offset) * dt_size), - reduce_target, frag_count, dt, args, - is_avg ? UCC_EEE_TASK_FLAG_REDUCE_WITH_ALPHA : 0, - AVG_ALPHA(task), task->reduce_scatter_ring.executor, - &task->reduce_scatter_ring.etask))) { - tl_error(UCC_TASK_LIB(task), "failed to perform dt reduction"); - task->super.status = status; - return; - } - EXEC_TASK_WAIT(task->reduce_scatter_ring.etask); - if (task->tagged.recv_completed == size - 1) { - task->tagged.recv_posted = task->tagged.recv_completed = 0; - break; + (step == (ucc_rank_t)(tsize - 2)); + + recv_buf = persistent ? ((step % 2 == 0) ? scratch : scratch2) + : scratch; + + for (ring_id = 0; ring_id < nrings; ring_id++) { + rrank = ucc_ring_pattern_rank(ring, ring_id); + adj_rrank = (rrank + tsize - 1) % tsize; + recv_block = ucc_ring_pattern_get_recv_block(ring, ring_id, + adj_rrank, step); + ring_offset = ucc_buffer_block_offset(block_cnt, nrings, ring_id); + ring_count = ucc_buffer_block_count(block_cnt, nrings, ring_id); + + if (step == (ucc_rank_t)(tsize - 2)) { + if (UCC_IS_INPLACE(*args)) { + reduce_target = PTR_OFFSET( + dst, + (recv_block * block_cnt + ring_offset) * dt_size); + } else { + reduce_target = PTR_OFFSET(dst, ring_offset * dt_size); + } + } else if (persistent) { + reduce_target = PTR_OFFSET(recv_buf, ring_offset * dt_size); + } else { + reduce_target = PTR_OFFSET( + sbuf, + (recv_block * block_cnt + ring_offset) * dt_size); + } + + status = ucc_dt_reduce( + PTR_OFFSET(recv_buf, ring_offset * dt_size), + PTR_OFFSET(sbuf, + (recv_block * block_cnt + ring_offset) * dt_size), + reduce_target, ring_count, dt, args, + is_avg ? UCC_EEE_TASK_FLAG_REDUCE_WITH_ALPHA : 0, + AVG_ALPHA(task), task->reduce_scatter_ring.executor, + &task->reduce_scatter_ring.etask); + if (UCC_OK != status) { + tl_error(UCC_TASK_LIB(task), + "failed to perform dt reduction"); + task->super.status = status; + return; + } + EXEC_TASK_WAIT(task->reduce_scatter_ring.etask); } - ucc_assert(task->tagged.send_posted - task->tagged.send_completed <= 1); - ucc_assert(task->tagged.send_posted < size); - - busy[id] = 1; - UCPCHECK_GOTO(ucc_tl_ucp_send_cb(reduce_target, frag_count * dt_size, - mem_type, sendto, team, task, cb[id], (void *)task), - task, out); - - recv_data_from = (rank - 2 - step + size) % size; - recv_data_from = - ucc_ep_map_eval(task->reduce_scatter_ring.inv_map, recv_data_from); - if (team->cfg.use_reordering) { - recv_data_from = ucc_ep_map_eval(task->subset.map, recv_data_from); + + if (step + 1 >= (ucc_rank_t)(tsize - 1)) { + task->super.status = UCC_OK; + return; } - ucc_ring_frag_count(task, count, recv_data_from, &frag_count); - UCPCHECK_GOTO(ucc_tl_ucp_recv_nb(r_scratch, frag_count * dt_size, mem_type, - recvfrom, team, task), - task, out); + step++; + for (ring_id = 0; ring_id < nrings; ring_id++) { + rrank = ucc_ring_pattern_rank(ring, ring_id); + sendto = ucc_ring_pattern_get_send_peer(ring, ring_id, rrank); + recvfrom = ucc_ring_pattern_get_recv_peer(ring, ring_id, rrank); + ring_offset = ucc_buffer_block_offset(block_cnt, nrings, ring_id); + ring_count = ucc_buffer_block_count(block_cnt, nrings, ring_id); + + data_size = ring_count * dt_size; + + if (persistent) { + send_src = PTR_OFFSET(recv_buf, ring_offset * dt_size); + next_recv_dst = PTR_OFFSET( + (step % 2 == 0) ? scratch : scratch2, + ring_offset * dt_size); + } else { + adj_rrank = (rrank + tsize - 1) % tsize; + send_block = ucc_ring_pattern_get_recv_block(ring, ring_id, + adj_rrank, + step - 1); + data_displ = (send_block * block_cnt + ring_offset) * dt_size; + send_src = PTR_OFFSET(sbuf, data_displ); + next_recv_dst = PTR_OFFSET(scratch, ring_offset * dt_size); + } + + UCPCHECK_GOTO(ucc_tl_ucp_send_nb(send_src, data_size, + mem_type, sendto, team, task), task, out); + + UCPCHECK_GOTO(ucc_tl_ucp_recv_nb(next_recv_dst, data_size, + mem_type, recvfrom, team, task), task, out); + } - if (UCC_INPROGRESS == ucc_tl_ucp_test_ring(task)) { + if (UCC_INPROGRESS == ucc_tl_ucp_test(task)) { return; } } - if (UCC_INPROGRESS == ucc_tl_ucp_test(task)) { - return; - } + task->super.status = UCC_OK; out: return; } static ucc_status_t -ucc_tl_ucp_reduce_scatter_ring_start(ucc_coll_task_t *coll_task) +reduce_scatter_ring_topo_start(ucc_coll_task_t *coll_task) { - ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, ucc_tl_ucp_task_t); - ucc_coll_args_t * args = &TASK_ARGS(task); - ucc_tl_ucp_team_t *team = TASK_TEAM(task); - ucc_rank_t size = task->subset.map.ep_num; - ucc_rank_t rank = task->subset.myrank; - size_t count = args->dst.info.count * size; - ucc_datatype_t dt = args->dst.info.datatype; - size_t dt_size = ucc_dt_size(dt); - ucc_memory_type_t mem_type = args->dst.info.mem_type; - void * sbuf = args->src.info.buffer; - int step = 0; - ucc_rank_t sendto = (rank + 1) % size; - ucc_rank_t recvfrom = (rank - 1 + size) % size; - ucc_rank_t recv_block = (rank - 2 - step + size) % size; - ucc_rank_t send_block = (rank - 1 - step + size) % size; - size_t block_offset, frag_count, frag_offset; - void *r_scratch; - ucc_status_t status; + ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, ucc_tl_ucp_task_t); + ucc_coll_args_t *args = &TASK_ARGS(task); + ucc_tl_ucp_team_t *team = TASK_TEAM(task); + ucc_ring_pattern_t *ring = team->cuda_ring; + ucc_rank_t nrings = ucc_min(MAX_RINGS, ring->num_rings); + ucc_rank_t tsize = ucc_ring_pattern_size(ring, 0); + size_t total_cnt = rs_ring_total_count(args); + size_t block_cnt = total_cnt / tsize; + ucc_datatype_t dt = args->dst.info.datatype; + size_t dt_size = ucc_dt_size(dt); + ucc_memory_type_t mem_type = args->dst.info.mem_type; + void *sbuf = UCC_IS_INPLACE(*args) + ? args->dst.info.buffer + : args->src.info.buffer; + void *scratch = task->reduce_scatter_ring.scratch; + ucc_rank_t ring_id, rrank, adj_rrank, send_block, sendto, recvfrom; + size_t ring_offset, ring_count, data_displ, data_size; + ucc_status_t status; ucc_tl_ucp_task_reset(task, UCC_INPROGRESS); - if (UCC_IS_INPLACE(*args)) { - sbuf = args->dst.info.buffer; - count /= size; - } status = ucc_coll_task_get_executor(&task->super, &task->reduce_scatter_ring.executor); if (ucc_unlikely(status != UCC_OK)) { return status; } - r_scratch = task->reduce_scatter_ring.scratch; - sendto = ucc_ep_map_eval(task->reduce_scatter_ring.inv_map, sendto); - recvfrom = ucc_ep_map_eval(task->reduce_scatter_ring.inv_map, recvfrom); - recv_block = ucc_ep_map_eval(task->reduce_scatter_ring.inv_map, recv_block); - send_block = ucc_ep_map_eval(task->reduce_scatter_ring.inv_map, send_block); - if (team->cfg.use_reordering) { - sendto = ucc_ep_map_eval(task->subset.map, sendto); - recvfrom = ucc_ep_map_eval(task->subset.map, recvfrom); - recv_block = ucc_ep_map_eval(task->subset.map, recv_block); - send_block = ucc_ep_map_eval(task->subset.map, send_block); + for (ring_id = 0; ring_id < nrings; ring_id++) { + rrank = ucc_ring_pattern_rank(ring, ring_id); + adj_rrank = (rrank + tsize - 1) % tsize; + send_block = ucc_ring_pattern_get_send_block(ring, ring_id, + adj_rrank, 0); + sendto = ucc_ring_pattern_get_send_peer(ring, ring_id, rrank); + recvfrom = ucc_ring_pattern_get_recv_peer(ring, ring_id, rrank); + ring_offset = ucc_buffer_block_offset(block_cnt, nrings, ring_id); + ring_count = ucc_buffer_block_count(block_cnt, nrings, ring_id); + + data_displ = (send_block * block_cnt + ring_offset) * dt_size; + data_size = ring_count * dt_size; + UCPCHECK_GOTO(ucc_tl_ucp_send_nb( + PTR_OFFSET(sbuf, data_displ), data_size, mem_type, sendto, + team, task), task, err); + + data_displ = ring_offset * dt_size; + data_size = ring_count * dt_size; + UCPCHECK_GOTO(ucc_tl_ucp_recv_nb( + PTR_OFFSET(scratch, data_displ), data_size, mem_type, recvfrom, + team, task), task, err); } - ucc_ring_frag_count(task, count, recv_block, &frag_count); - UCPCHECK_GOTO(ucc_tl_ucp_recv_nb(r_scratch, frag_count * dt_size, mem_type, - recvfrom, team, task), - task, out); - - ucc_ring_frag_count(task, count, send_block, &frag_count); - ucc_ring_frag_block_offset(task, count, send_block, &block_offset, - &frag_offset); - UCPCHECK_GOTO(ucc_tl_ucp_send_nb( - PTR_OFFSET(sbuf, (block_offset + frag_offset) * dt_size), - frag_count * dt_size, mem_type, sendto, team, task), - task, out); - return ucc_progress_queue_enqueue(UCC_TL_CORE_CTX(team)->pq, &task->super); -out: +err: return task->super.status; } static ucc_status_t -ucc_tl_ucp_reduce_scatter_ring_finalize(ucc_coll_task_t *coll_task) +reduce_scatter_ring_topo_finalize(ucc_coll_task_t *coll_task) { ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, ucc_tl_ucp_task_t); - if (task->reduce_scatter_ring.frag == REVERSED_FRAG) { - ucc_ep_map_destroy(&task->reduce_scatter_ring.inv_map); - } + + ucc_mc_free(task->reduce_scatter_ring.scratch_mc_header); return ucc_tl_ucp_coll_finalize(coll_task); } -static ucc_status_t ucc_tl_ucp_reduce_scatter_ring_init_subset( - ucc_base_coll_args_t *coll_args, ucc_base_team_t *team, - ucc_coll_task_t **task_h, ucc_subset_t *subsets, int n_frags, int frag, - void *scratch, size_t max_block_count) +ucc_status_t ucc_tl_ucp_reduce_scatter_ring_init_common( + ucc_tl_ucp_task_t *task) { - ucc_tl_ucp_task_t *task; - ucc_tl_ucp_team_t *tl_team; - ucc_status_t status; + ucc_tl_ucp_team_t *team = TASK_TEAM(task); + ucc_coll_args_t *args = &TASK_ARGS(task); + ucc_rank_t tsize; + size_t total_count, block_count, scratch_size; + ucc_datatype_t dt; + size_t dt_size; + ucc_memory_type_t mem_type; + ucc_mc_buffer_header_t *scratch_mc_header; + ucc_status_t status; - task = ucc_tl_ucp_init_task(coll_args, team); - tl_team = TASK_TEAM(task); - task->super.post = ucc_tl_ucp_reduce_scatter_ring_start; - task->super.progress = ucc_tl_ucp_reduce_scatter_ring_progress; - task->super.finalize = ucc_tl_ucp_reduce_scatter_ring_finalize; - task->subset.map = subsets[frag].map; - task->subset.myrank = subsets[frag].myrank; - if (frag == REVERSED_FRAG) { - if (tl_team->cfg.use_reordering) { - task->subset.map = subsets[0].map; - } - status = ucc_ep_map_create_inverse(subsets[frag].map, - &task->reduce_scatter_ring.inv_map, - frag && tl_team->cfg.use_reordering); - if (UCC_OK != status) { - return status; - } - } else { - task->reduce_scatter_ring.inv_map.type = UCC_EP_MAP_FULL; - task->reduce_scatter_ring.inv_map.ep_num = task->subset.map.ep_num; + if (!team->cuda_ring) { + return UCC_ERR_NOT_SUPPORTED; + } + + if (!ucc_coll_args_is_predefined_dt(args, UCC_RANK_INVALID)) { + tl_error(UCC_TASK_LIB(task), "user defined datatype is not supported"); + return UCC_ERR_NOT_SUPPORTED; } - task->reduce_scatter_ring.n_frags = n_frags; - task->reduce_scatter_ring.frag = frag; - task->reduce_scatter_ring.scratch = scratch; - task->reduce_scatter_ring.max_block_count = max_block_count; + + if (UCC_TL_UCP_TEAM_LIB(team)->cfg.reduce_avg_pre_op && + args->op == UCC_OP_AVG) { + return UCC_ERR_NOT_SUPPORTED; + } + + tsize = ucc_ring_pattern_size(team->cuda_ring, 0); + total_count = rs_ring_total_count(args); + if (total_count % tsize != 0) { + return UCC_ERR_NOT_SUPPORTED; + } + + block_count = total_count / tsize; + dt = args->dst.info.datatype; + dt_size = ucc_dt_size(dt); + mem_type = args->dst.info.mem_type; + scratch_size = UCC_IS_PERSISTENT(*args) ? 2 * block_count * dt_size + : block_count * dt_size; + + status = ucc_mc_alloc(&scratch_mc_header, scratch_size, mem_type); + if (ucc_unlikely(UCC_OK != status)) { + tl_error(UCC_TASK_LIB(task), + "failed to allocate scratch for reduce_scatter ring"); + return status; + } + + task->reduce_scatter_ring.scratch = scratch_mc_header->addr; + task->reduce_scatter_ring.scratch_mc_header = scratch_mc_header; + task->reduce_scatter_ring.max_block_count = block_count; task->reduce_scatter_ring.s_scratch_busy[0] = 0; task->reduce_scatter_ring.s_scratch_busy[1] = 0; - *task_h = &task->super; - return UCC_OK; -} -static ucc_status_t -ucc_tl_ucp_reduce_scatter_ring_sched_post(ucc_coll_task_t *coll_task) -{ - return ucc_schedule_start(coll_task); -} + task->super.flags |= UCC_COLL_TASK_FLAG_EXECUTOR; + task->super.post = reduce_scatter_ring_topo_start; + task->super.progress = ucc_tl_ucp_reduce_scatter_ring_progress; + task->super.finalize = reduce_scatter_ring_topo_finalize; -static ucc_status_t -ucc_tl_ucp_reduce_scatter_ring_sched_finalize(ucc_coll_task_t *task) -{ - ucc_tl_ucp_schedule_t *schedule = ucc_derived_of(task, - ucc_tl_ucp_schedule_t); - ucc_status_t status; - - ucc_mc_free(schedule->scratch_mc_header); - status = ucc_schedule_finalize(task); - ucc_tl_ucp_put_schedule(&schedule->super.super); - return status; + return UCC_OK; } ucc_status_t ucc_tl_ucp_reduce_scatter_ring_init(ucc_base_coll_args_t *coll_args, - ucc_base_team_t * team, - ucc_coll_task_t ** task_h) + ucc_base_team_t *team, + ucc_coll_task_t **task_h) { - ucc_tl_ucp_team_t *tl_team = ucc_derived_of(team, ucc_tl_ucp_team_t); - ucc_rank_t size = UCC_TL_TEAM_SIZE(tl_team); - size_t count = coll_args->args.dst.info.count; - ucc_datatype_t dt = coll_args->args.dst.info.datatype; - size_t dt_size = ucc_dt_size(dt); - ucc_memory_type_t mem_type = coll_args->args.dst.info.mem_type; - int bidir = - UCC_TL_UCP_TEAM_LIB(tl_team)->cfg.reduce_scatter_ring_bidirectional; - size_t to_alloc_per_set, max_segcount, count_per_set; - ucc_tl_ucp_schedule_t *tl_schedule; - ucc_schedule_t *schedule; - ucc_coll_task_t *ctask; - ucc_sbgp_t *sbgp; - ucc_status_t status; - ucc_subset_t s[2]; - int i, n_subsets; - - if (UCC_TL_UCP_TEAM_LIB(tl_team)->cfg.reduce_avg_pre_op && - coll_args->args.op == UCC_OP_AVG) { - return UCC_ERR_NOT_SUPPORTED; - } + ucc_tl_ucp_team_t *tl_team = ucc_derived_of(team, ucc_tl_ucp_team_t); + ucc_tl_ucp_task_t *task; + ucc_status_t status; - if (!UCC_IS_INPLACE(coll_args->args)) { - count *= size; + if (!tl_team->cuda_ring) { + return UCC_ERR_NOT_SUPPORTED; } - status = ucc_tl_ucp_get_schedule(tl_team, coll_args, &tl_schedule); - if (ucc_unlikely(UCC_OK != status)) { + task = ucc_tl_ucp_init_task(coll_args, team); + status = ucc_tl_ucp_reduce_scatter_ring_init_common(task); + if (status != UCC_OK) { + ucc_tl_ucp_put_task(task); return status; } - - schedule = &tl_schedule->super.super; - /* if count == size then we have 1 elem per rank, not enough - to split into 2 sets */ - n_subsets = (bidir && (count > size)) ? 2 : 1; - - if (tl_team->cfg.use_reordering) { - sbgp = ucc_topo_get_sbgp(tl_team->topo, UCC_SBGP_FULL_HOST_ORDERED); - s[0].myrank = sbgp->group_rank; - s[0].map = sbgp->map; - } else { - s[0].myrank = UCC_TL_TEAM_RANK(tl_team); - s[0].map.type = UCC_EP_MAP_FULL; - s[0].map.ep_num = UCC_TL_TEAM_SIZE(tl_team); - } - s[1].map = ucc_ep_map_create_reverse(UCC_TL_TEAM_SIZE(tl_team)); - s[1].myrank = ucc_ep_map_eval(s[1].map, s[0].myrank); - count_per_set = (count + n_subsets - 1) / n_subsets; - max_segcount = ucc_buffer_block_count(count_per_set, size, 0); - /* in flight we can have 2 sends from 2 differnt blocks and 1 recv: - need 3 * max_segcount of scratch per set */ - to_alloc_per_set = max_segcount * 3; - UCC_CHECK_GOTO(ucc_mc_alloc(&tl_schedule->scratch_mc_header, - to_alloc_per_set * dt_size * n_subsets, - mem_type), - out, status); - for (i = 0; i < n_subsets; i++) { - UCC_CHECK_GOTO(ucc_tl_ucp_reduce_scatter_ring_init_subset( - coll_args, team, &ctask, s, n_subsets, i, - PTR_OFFSET(tl_schedule->scratch_mc_header->addr, - to_alloc_per_set * i * dt_size), - max_segcount), - out_free, status); - ctask->n_deps = 1; - UCC_CHECK_GOTO(ucc_schedule_add_task(schedule, ctask), out_free, - status); - UCC_CHECK_GOTO(ucc_event_manager_subscribe( - &schedule->super, UCC_EVENT_SCHEDULE_STARTED, ctask, - ucc_task_start_handler), - out_free, status); - } - schedule->super.flags |= UCC_COLL_TASK_FLAG_EXECUTOR; - schedule->super.post = ucc_tl_ucp_reduce_scatter_ring_sched_post; - schedule->super.finalize = ucc_tl_ucp_reduce_scatter_ring_sched_finalize; - *task_h = &schedule->super; + *task_h = &task->super; return UCC_OK; - -out_free: - ucc_mc_free(tl_schedule->scratch_mc_header); -out: - ucc_tl_ucp_put_schedule(schedule); - return status; } diff --git a/src/components/tl/ucp/tl_ucp_coll.c b/src/components/tl/ucp/tl_ucp_coll.c index 413e71294d3..32eaa89447a 100644 --- a/src/components/tl/ucp/tl_ucp_coll.c +++ b/src/components/tl/ucp/tl_ucp_coll.c @@ -51,8 +51,8 @@ const ucc_tl_ucp_default_alg_desc_t .str_get_fn = NULL }, { - .select_str = UCC_TL_UCP_REDUCE_SCATTER_DEFAULT_ALG_SELECT_STR, - .str_get_fn = NULL + .select_str = NULL, + .str_get_fn = ucc_tl_ucp_reduce_scatter_score_str_get }, { .select_str = UCC_TL_UCP_REDUCE_SCATTERV_DEFAULT_ALG_SELECT_STR, diff --git a/src/components/tl/ucp/tl_ucp_task.h b/src/components/tl/ucp/tl_ucp_task.h index 19b137bbe79..37c77eea187 100644 --- a/src/components/tl/ucp/tl_ucp_task.h +++ b/src/components/tl/ucp/tl_ucp_task.h @@ -89,6 +89,7 @@ typedef struct ucc_tl_ucp_task { } reduce_scatter_kn; struct { void *scratch; + ucc_mc_buffer_header_t *scratch_mc_header; size_t max_block_count; ucc_ep_map_t inv_map; int n_frags; @@ -99,6 +100,7 @@ typedef struct ucc_tl_ucp_task { } reduce_scatter_ring; struct { void *scratch; + ucc_mc_buffer_header_t *scratch_mc_header; size_t max_block_count; ucc_ep_map_t inv_map; int n_frags;