66
77#include "config.h"
88#include "ucc_context.h"
9+ #include "components/topo/ucc_topo.h"
910#include "components/cl/ucc_cl.h"
1011#include "components/tl/ucc_tl.h"
1112#include "utils/ucc_malloc.h"
@@ -583,6 +584,77 @@ ucc_status_t ucc_core_addr_exchange(ucc_context_t *context, ucc_oob_coll_t *oob,
583584 return UCC_OK ;
584585}
585586
587+ ucc_status_t ucc_core_ctx_id_exchange (ucc_context_t * context , ucc_oob_coll_t * oob ,
588+ ucc_addr_storage_t * addr_storage )
589+ {
590+ ucc_status_t status ;
591+ ucc_rank_t i ;
592+
593+ poll :
594+ if (addr_storage -> oob_req ) {
595+ status = oob -> req_test (addr_storage -> oob_req );
596+ if (status < 0 ) {
597+ oob -> req_free (addr_storage -> oob_req );
598+ ucc_error ("oob req test failed during topo addr exchange" );
599+ return status ;
600+ } else if (UCC_INPROGRESS == status ) {
601+ return status ;
602+ }
603+ oob -> req_free (addr_storage -> oob_req );
604+ addr_storage -> oob_req = NULL ;
605+ }
606+ if (0 == addr_storage -> addr_len ) {
607+ if (NULL == addr_storage -> storage ) {
608+ addr_storage -> size = oob -> n_oob_eps ;
609+
610+ addr_storage -> storage = (ucc_context_id_t * )ucc_malloc (
611+ addr_storage -> size * sizeof (ucc_context_id_t ), "ctx_ids_storage" );
612+ if (!addr_storage -> storage ) {
613+ ucc_error (
614+ "failed to allocate %zd bytes for ctx_ids storage" ,
615+ addr_storage -> size * sizeof (ucc_context_id_t ));
616+ return UCC_ERR_NO_MEMORY ;
617+ }
618+ addr_storage -> addr_len = sizeof (ucc_context_id_t );
619+
620+ status = oob -> allgather (& context -> id , addr_storage -> storage ,
621+ sizeof (ucc_context_id_t ), oob -> coll_info ,
622+ & addr_storage -> oob_req );
623+ if (UCC_OK != status ) {
624+ ucc_free (addr_storage -> storage );
625+ ucc_error ("failed to start oob allgather for ctx_ids" );
626+ return status ;
627+ }
628+ goto poll ;
629+ }
630+ }
631+ ucc_assert (addr_storage -> storage != NULL );
632+ if (addr_storage -> addr_len == 0 ) {
633+ ucc_free (addr_storage -> storage );
634+ addr_storage -> storage = NULL ;
635+ return UCC_OK ;
636+ }
637+
638+ {
639+ ucc_context_id_t * ctx_ids = (ucc_context_id_t * )addr_storage -> storage ;
640+ ucc_rank_t r = UCC_RANK_MAX ;
641+
642+ for (i = 0 ; i < addr_storage -> size ; i ++ ) {
643+ if (UCC_CTX_ID_EQUAL (context -> id , ctx_ids [i ])) {
644+ if (r != UCC_RANK_MAX ) {
645+ ucc_error ("ctx_id collision: %d %d" , r , i );
646+ return UCC_ERR_NO_MESSAGE ;
647+ }
648+ r = i ;
649+ }
650+ }
651+
652+ addr_storage -> flags = 0 ;
653+ addr_storage -> rank = r ;
654+ }
655+ return UCC_OK ;
656+ }
657+
586658static void remove_tl_ctx_from_array (ucc_tl_context_t * * array , unsigned * size ,
587659 ucc_tl_context_t * tl_ctx )
588660{
@@ -654,6 +726,60 @@ ucc_status_t ucc_context_create_proc_info(ucc_lib_h lib,
654726 ucc_check_wait_for_debugger (ctx -> rank );
655727#endif
656728 }
729+
730+ ctx -> id .pi = * proc_info ;
731+ ctx -> id .seq_num = ucc_atomic_fadd32 (& ucc_context_seq_num , 1 );
732+
733+ if (config -> node_local_id == UCC_ULUNITS_AUTO ) {
734+ if (params -> mask & UCC_CONTEXT_PARAM_FIELD_OOB && params -> oob .n_oob_eps > 1 ) {
735+ do {
736+ /* UCC context create is blocking fn, so we can wait here for the
737+ completion of addr exchange */
738+ status = ucc_core_ctx_id_exchange (ctx , & ctx -> params .oob ,
739+ & ctx -> addr_storage );
740+ if (status < 0 ) {
741+ ucc_error ("failed to exchange addresses during context "
742+ "creation with status: %s" ,
743+ ucc_status_string (status ));
744+ goto error_ctx_create ;
745+ }
746+ } while (status == UCC_INPROGRESS );
747+ status = ucc_context_topo_init (& ctx -> addr_storage , & ctx -> topo );
748+ if (UCC_OK != status ) {
749+ ucc_free (ctx -> addr_storage .storage );
750+ ucc_error ("failed to init ctx topo" );
751+ goto error_ctx_create ;
752+ }
753+ ucc_assert (ctx -> addr_storage .rank == params -> oob .oob_ep );
754+
755+ if (ctx -> topo ) {
756+ ucc_subset_t set ;
757+ ucc_topo_t * topo = NULL ;
758+
759+ memset (& set .map , 0 , sizeof (ucc_ep_map_t ));
760+ set .map .type = UCC_EP_MAP_FULL ;
761+ set .myrank = params -> oob .oob_ep ;
762+ set .map .ep_num = params -> oob .n_oob_eps ;
763+
764+ status = ucc_topo_init (set , ctx -> topo , & topo );
765+ if (UCC_OK != status ) {
766+ ucc_warn ("failed to init topo for computing local rank" );
767+ } else {
768+ b_params .node_local_id = ucc_topo_node_local_rank (topo );
769+ ucc_topo_cleanup (topo );
770+ }
771+ }
772+
773+ /* clean up addr_storage */
774+ ucc_free (ctx -> addr_storage .storage );
775+ ctx -> addr_storage .storage = NULL ;
776+ ctx -> addr_storage .addr_len = 0 ;
777+ ctx -> addr_storage .size = 0 ;
778+ ctx -> addr_storage .rank = UCC_RANK_MAX ;
779+ ctx -> addr_storage .flags = 0 ;
780+ }
781+ }
782+
657783 status = ucc_create_tl_contexts (ctx , config , b_params );
658784 if (UCC_OK != status ) {
659785 /* only critical error could have happened - bail */
@@ -733,8 +859,7 @@ ucc_status_t ucc_context_create_proc_info(ucc_lib_h lib,
733859 ucc_error ("failed to init progress queue for context %p" , ctx );
734860 goto error_ctx_create ;
735861 }
736- ctx -> id .pi = * proc_info ;
737- ctx -> id .seq_num = ucc_atomic_fadd32 (& ucc_context_seq_num , 1 );
862+
738863 if (params -> mask & UCC_CONTEXT_PARAM_FIELD_OOB &&
739864 params -> oob .n_oob_eps > 1 ) {
740865 do {
@@ -750,7 +875,7 @@ ucc_status_t ucc_context_create_proc_info(ucc_lib_h lib,
750875 }
751876 } while (status == UCC_INPROGRESS );
752877
753- if (topo_required ) {
878+ if (topo_required && ! ctx -> topo ) {
754879 /* At least one available CL context reported it needs topo info */
755880 status = ucc_context_topo_init (& ctx -> addr_storage , & ctx -> topo );
756881 if (UCC_OK != status ) {
0 commit comments