Skip to content

Commit c4f6048

Browse files
UCC/CORE: Added automation to loading local rank from topo if not provided by user
1 parent f4be45d commit c4f6048

4 files changed

Lines changed: 144 additions & 6 deletions

File tree

src/components/topo/ucc_topo.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,16 @@ static inline ucc_rank_t ucc_topo_nnodes(ucc_topo_t *topo)
258258
return sbgp->group_size;
259259
}
260260

261+
/* Returns the group_rank recorded in the UCC_SBGP_NODE subgroup of @topo.
   When that subgroup is reported as UCC_SBGP_NOT_EXISTS the result is 0. */
static inline ucc_rank_t ucc_topo_node_local_rank(ucc_topo_t *topo)
{
    ucc_sbgp_t *node_sbgp = ucc_topo_get_sbgp(topo, UCC_SBGP_NODE);

    return (node_sbgp->status == UCC_SBGP_NOT_EXISTS) ? 0
                                                      : node_sbgp->group_rank;
}
270+
261271
/* Returns node leaders array - array that maps each rank to the TEAM RANK that
262272
is the leader of that rank's node. Also returns per-node leaders array - array
263273
mapping node_id to the TEAM RANK of that node's leader */

src/core/ucc_context.c

Lines changed: 128 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
#include "config.h"
88
#include "ucc_context.h"
9+
#include "components/topo/ucc_topo.h"
910
#include "components/cl/ucc_cl.h"
1011
#include "components/tl/ucc_tl.h"
1112
#include "utils/ucc_malloc.h"
@@ -583,6 +584,77 @@ ucc_status_t ucc_core_addr_exchange(ucc_context_t *context, ucc_oob_coll_t *oob,
583584
return UCC_OK;
584585
}
585586

587+
ucc_status_t ucc_core_ctx_id_exchange(ucc_context_t *context, ucc_oob_coll_t *oob,
588+
ucc_addr_storage_t *addr_storage)
589+
{
590+
ucc_status_t status;
591+
ucc_rank_t i;
592+
593+
poll:
594+
if (addr_storage->oob_req) {
595+
status = oob->req_test(addr_storage->oob_req);
596+
if (status < 0) {
597+
oob->req_free(addr_storage->oob_req);
598+
ucc_error("oob req test failed during topo addr exchange");
599+
return status;
600+
} else if (UCC_INPROGRESS == status) {
601+
return status;
602+
}
603+
oob->req_free(addr_storage->oob_req);
604+
addr_storage->oob_req = NULL;
605+
}
606+
if (0 == addr_storage->addr_len) {
607+
if (NULL == addr_storage->storage) {
608+
addr_storage->size = oob->n_oob_eps;
609+
610+
addr_storage->storage = (ucc_context_id_t *)ucc_malloc(
611+
addr_storage->size * sizeof(ucc_context_id_t), "ctx_ids_storage");
612+
if (!addr_storage->storage) {
613+
ucc_error(
614+
"failed to allocate %zd bytes for ctx_ids storage",
615+
addr_storage->size * sizeof(ucc_context_id_t));
616+
return UCC_ERR_NO_MEMORY;
617+
}
618+
addr_storage->addr_len = sizeof(ucc_context_id_t);
619+
620+
status = oob->allgather(&context->id, addr_storage->storage,
621+
sizeof(ucc_context_id_t), oob->coll_info,
622+
&addr_storage->oob_req);
623+
if (UCC_OK != status) {
624+
ucc_free(addr_storage->storage);
625+
ucc_error("failed to start oob allgather for ctx_ids");
626+
return status;
627+
}
628+
goto poll;
629+
}
630+
}
631+
ucc_assert(addr_storage->storage != NULL);
632+
if (addr_storage->addr_len == 0 ) {
633+
ucc_free(addr_storage->storage);
634+
addr_storage->storage = NULL;
635+
return UCC_OK;
636+
}
637+
638+
{
639+
ucc_context_id_t *ctx_ids = (ucc_context_id_t *)addr_storage->storage;
640+
ucc_rank_t r = UCC_RANK_MAX;
641+
642+
for (i = 0; i < addr_storage->size; i++) {
643+
if (UCC_CTX_ID_EQUAL(context->id, ctx_ids[i])) {
644+
if (r != UCC_RANK_MAX) {
645+
ucc_error("ctx_id collision: %d %d", r, i);
646+
return UCC_ERR_NO_MESSAGE;
647+
}
648+
r = i;
649+
}
650+
}
651+
652+
addr_storage->flags = 0;
653+
addr_storage->rank = r;
654+
}
655+
return UCC_OK;
656+
}
657+
586658
static void remove_tl_ctx_from_array(ucc_tl_context_t **array, unsigned *size,
587659
ucc_tl_context_t *tl_ctx)
588660
{
@@ -654,6 +726,60 @@ ucc_status_t ucc_context_create_proc_info(ucc_lib_h lib,
654726
ucc_check_wait_for_debugger(ctx->rank);
655727
#endif
656728
}
729+
730+
ctx->id.pi = *proc_info;
731+
ctx->id.seq_num = ucc_atomic_fadd32(&ucc_context_seq_num, 1);
732+
733+
if (config->node_local_id == UCC_ULUNITS_AUTO) {
734+
if (params->mask & UCC_CONTEXT_PARAM_FIELD_OOB && params->oob.n_oob_eps > 1) {
735+
do {
736+
/* UCC context create is a blocking function, so we can wait here for the
737+
completion of addr exchange */
738+
status = ucc_core_ctx_id_exchange(ctx, &ctx->params.oob,
739+
&ctx->addr_storage);
740+
if (status < 0) {
741+
ucc_error("failed to exchange addresses during context "
742+
"creation with status: %s",
743+
ucc_status_string(status));
744+
goto error_ctx_create;
745+
}
746+
} while (status == UCC_INPROGRESS);
747+
status = ucc_context_topo_init(&ctx->addr_storage, &ctx->topo);
748+
if (UCC_OK != status) {
749+
ucc_free(ctx->addr_storage.storage);
750+
ucc_error("failed to init ctx topo");
751+
goto error_ctx_create;
752+
}
753+
ucc_assert(ctx->addr_storage.rank == params->oob.oob_ep);
754+
755+
if (ctx->topo) {
756+
ucc_subset_t set;
757+
ucc_topo_t *topo = NULL;
758+
759+
memset(&set.map, 0, sizeof(ucc_ep_map_t));
760+
set.map.type = UCC_EP_MAP_FULL;
761+
set.myrank = params->oob.oob_ep;
762+
set.map.ep_num = params->oob.n_oob_eps;
763+
764+
status = ucc_topo_init(set, ctx->topo, &topo);
765+
if (UCC_OK != status) {
766+
ucc_warn("failed to init topo for computing local rank");
767+
} else {
768+
b_params.node_local_id = ucc_topo_node_local_rank(topo);
769+
ucc_topo_cleanup(topo);
770+
}
771+
}
772+
773+
/* clean up addr_storage */
774+
ucc_free(ctx->addr_storage.storage);
775+
ctx->addr_storage.storage = NULL;
776+
ctx->addr_storage.addr_len = 0;
777+
ctx->addr_storage.size = 0;
778+
ctx->addr_storage.rank = UCC_RANK_MAX;
779+
ctx->addr_storage.flags = 0;
780+
}
781+
}
782+
657783
status = ucc_create_tl_contexts(ctx, config, b_params);
658784
if (UCC_OK != status) {
659785
/* only critical error could have happened - bail */
@@ -733,8 +859,7 @@ ucc_status_t ucc_context_create_proc_info(ucc_lib_h lib,
733859
ucc_error("failed to init progress queue for context %p", ctx);
734860
goto error_ctx_create;
735861
}
736-
ctx->id.pi = *proc_info;
737-
ctx->id.seq_num = ucc_atomic_fadd32(&ucc_context_seq_num, 1);
862+
738863
if (params->mask & UCC_CONTEXT_PARAM_FIELD_OOB &&
739864
params->oob.n_oob_eps > 1) {
740865
do {
@@ -750,7 +875,7 @@ ucc_status_t ucc_context_create_proc_info(ucc_lib_h lib,
750875
}
751876
} while (status == UCC_INPROGRESS);
752877

753-
if (topo_required) {
878+
if (topo_required && !ctx->topo) {
754879
/* At least one available CL context reported it needs topo info */
755880
status = ucc_context_topo_init(&ctx->addr_storage, &ctx->topo);
756881
if (UCC_OK != status) {

src/core/ucc_context.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,12 @@ ucc_status_t ucc_context_progress_deregister(ucc_context_t *ctx,
152152
ucc_status_t ucc_core_addr_exchange(ucc_context_t *context, ucc_oob_coll_t *oob,
153153
ucc_addr_storage_t *addr_storage);
154154

155+
/* Performs the context id exchange within the process group defined by OOB.
156+
This function is used to exchange the context ids between the processes in order
157+
to find the local rank.
158+
*/
159+
ucc_status_t ucc_core_ctx_id_exchange(ucc_context_t *context, ucc_oob_coll_t *oob,
160+
ucc_addr_storage_t *addr_storage);
155161
/* UCC context packed address layout:
156162
--------------------------------------------------------------------------
157163
|n_components|id0|offset0|id1|offset1|..|idN|offsetN|data0|data1|..|dataN|

tools/perf/ucc_pt_comm.cc

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -167,9 +167,6 @@ ucc_status_t ucc_pt_comm::init()
167167
cfg_mod = std::to_string(bootstrap->get_ppn());
168168
UCCCHECK_GOTO(ucc_context_config_modify(ctx_config, NULL,
169169
"ESTIMATED_NUM_PPN", cfg_mod.c_str()), free_ctx_config, st);
170-
cfg_mod = std::to_string(bootstrap->get_local_rank());
171-
UCCCHECK_GOTO(ucc_context_config_modify(ctx_config, NULL,
172-
"NODE_LOCAL_ID", cfg_mod.c_str()), free_ctx_config, st);
173170
std::memset(&ctx_params, 0, sizeof(ucc_context_params_t));
174171
ctx_params.mask = UCC_CONTEXT_PARAM_FIELD_TYPE |
175172
UCC_CONTEXT_PARAM_FIELD_OOB |

0 commit comments

Comments
 (0)