@@ -72,8 +72,14 @@ ucc_status_t ucc_tl_ucp_alltoall_onesided_sched_finalize(ucc_coll_task_t *ctask)
7272
7373ucc_status_t ucc_tl_ucp_alltoall_onesided_finalize (ucc_coll_task_t * coll_task )
7474{
75- ucc_status_t status ;
75+ ucc_tl_ucp_task_t * task = ucc_derived_of (coll_task , ucc_tl_ucp_task_t );
76+ ucc_status_t status ;
7677
78+ status = ucc_tl_ucp_coll_dynamic_segment_destroy (task );
79+ if (ucc_unlikely (UCC_OK != status )) {
80+ tl_error (UCC_TASK_LIB (coll_task ),
81+ "failed to destroy dynamic segment local handles" );
82+ }
7783 status = ucc_tl_ucp_coll_finalize (coll_task );
7884 if (ucc_unlikely (UCC_OK != status )) {
7985 tl_error (UCC_TASK_LIB (coll_task ), "failed to finalize collective" );
@@ -92,27 +98,43 @@ void ucc_tl_ucp_alltoall_onesided_get_progress(ucc_coll_task_t *ctask)
9298 ucc_rank_t gsize = UCC_TL_TEAM_SIZE (team );
9399 uint32_t ntokens = task -> alltoall_onesided .tokens ;
94100 int64_t npolls = task -> alltoall_onesided .npolls ;
95- /* To resolve remote virtual addresses, the dst_memh is the one that must
96- * have the rkey information. For this algorithm, we need to swap the
97- * src and dst handles to operate correctly */
98- ucc_mem_map_mem_h * dst_memh = TASK_ARGS (task ).src_memh .global_memh ;
99- uint32_t * posted = & task -> onesided .get_posted ;
100- uint32_t * completed = & task -> onesided .get_completed ;
101- ucc_rank_t peer = (grank + * posted + 1 ) % gsize ;
102- ucc_mem_map_mem_h src_memh ;
101+ /* For GET, we read from each peer's src buffer into our local dst buffer.
102+ * remote_rkeys is the per-rank array of src rkeys (what we GET from);
103+ * local_h is our own dst buffer registration (where data lands). */
104+ ucc_mem_map_mem_h * remote_rkeys = TASK_ARGS (task ).src_memh .global_memh ;
105+ uint32_t * posted = & task -> onesided .get_posted ;
106+ uint32_t * completed = & task -> onesided .get_completed ;
107+ ucc_rank_t peer = (grank + * posted + 1 ) % gsize ;
108+ ucc_mem_map_mem_h local_h ;
103109 size_t nelems ;
110+ ucc_status_t status ;
111+
112+ if (task -> flags & UCC_TL_UCP_TASK_FLAG_USE_DYN_SEG ) {
113+ status = ucc_tl_ucp_test_dynamic_segment (task );
114+ if (status == UCC_INPROGRESS ) {
115+ return ;
116+ }
117+ if (UCC_OK != status ) {
118+ task -> super .status = status ;
119+ tl_error (UCC_TL_TEAM_LIB (team ),
120+ "failed to exchange dynamic segments" );
121+ return ;
122+ }
123+ local_h = task -> dynamic_segments .dst_local ;
124+ remote_rkeys = (ucc_mem_map_mem_h * )task -> dynamic_segments .src_global ;
125+ } else {
126+ local_h = (TASK_ARGS (task ).flags & UCC_COLL_ARGS_FLAG_DST_MEMH_GLOBAL )
127+ ? TASK_ARGS (task ).dst_memh .global_memh [grank ]
128+ : TASK_ARGS (task ).dst_memh .local_memh ;
129+ }
104130
105131 nelems = TASK_ARGS (task ).src .info .count ;
106132 nelems = (nelems / gsize ) * ucc_dt_size (TASK_ARGS (task ).src .info .datatype );
107- src_memh = (TASK_ARGS (task ).flags & UCC_COLL_ARGS_FLAG_DST_MEMH_GLOBAL )
108- ? TASK_ARGS (task ).dst_memh .global_memh [grank ]
109- : TASK_ARGS (task ).dst_memh .local_memh ;
110-
111133 for (; * posted < gsize ; peer = (peer + 1 ) % gsize ) {
112134 UCPCHECK_GOTO (ucc_tl_ucp_get_nb (PTR_OFFSET (dest , peer * nelems ),
113135 PTR_OFFSET (src , grank * nelems ),
114- nelems , mtype , peer , src_memh , dst_memh ,
115- team , task ),
136+ nelems , mtype , peer , local_h ,
137+ remote_rkeys , team , task ),
116138 task , out );
117139
118140 if (!alltoall_onesided_handle_completion (task , posted , completed ,
@@ -122,8 +144,13 @@ void ucc_tl_ucp_alltoall_onesided_get_progress(ucc_coll_task_t *ctask)
122144 }
123145
124146 alltoall_onesided_wait_completion (task , npolls );
125- out :
126- return ;
147+ out : {
148+ ucc_status_t st = task -> super .status ;
149+ if (st != UCC_INPROGRESS &&
150+ (task -> flags & UCC_TL_UCP_TASK_FLAG_USE_DYN_SEG )) {
151+ ucc_status_t fin = ucc_tl_ucp_coll_dynamic_segment_finalize (task );
152+ task -> super .status = (st != UCC_OK ) ? st : fin ;
153+ }
127154}
128155
129156void ucc_tl_ucp_alltoall_onesided_put_progress (ucc_coll_task_t * ctask )
@@ -143,12 +170,28 @@ void ucc_tl_ucp_alltoall_onesided_put_progress(ucc_coll_task_t *ctask)
143170 ucc_rank_t peer = (grank + * posted + 1 ) % gsize ;
144171 ucc_mem_map_mem_h src_memh ;
145172 size_t nelems ;
173+ ucc_status_t status ;
146174
147175 nelems = TASK_ARGS (task ).src .info .count ;
148176 nelems = (nelems / gsize ) * ucc_dt_size (TASK_ARGS (task ).src .info .datatype );
149- src_memh = (TASK_ARGS (task ).flags & UCC_COLL_ARGS_FLAG_SRC_MEMH_GLOBAL )
177+ if (task -> flags & UCC_TL_UCP_TASK_FLAG_USE_DYN_SEG ) {
178+ status = ucc_tl_ucp_test_dynamic_segment (task );
179+ if (status == UCC_INPROGRESS ) {
180+ return ;
181+ }
182+ if (UCC_OK != status ) {
183+ task -> super .status = status ;
184+ tl_error (UCC_TL_TEAM_LIB (team ),
185+ "failed to exchange dynamic segments" );
186+ return ;
187+ }
188+ src_memh = task -> dynamic_segments .src_local ;
189+ dst_memh = (ucc_mem_map_mem_h * )task -> dynamic_segments .dst_global ;
190+ } else {
191+ src_memh = (TASK_ARGS (task ).flags & UCC_COLL_ARGS_FLAG_SRC_MEMH_GLOBAL )
150192 ? TASK_ARGS (task ).src_memh .global_memh [grank ]
151193 : TASK_ARGS (task ).src_memh .local_memh ;
194+ }
152195
153196 for (; * posted < gsize ; peer = (peer + 1 ) % gsize ) {
154197 UCPCHECK_GOTO (
@@ -165,12 +208,15 @@ void ucc_tl_ucp_alltoall_onesided_put_progress(ucc_coll_task_t *ctask)
165208 }
166209
167210 alltoall_onesided_wait_completion (task , npolls );
168- out :
169- return ;
211+ out : {
212+ ucc_status_t st = task -> super .status ;
213+ if (st != UCC_INPROGRESS &&
214+ (task -> flags & UCC_TL_UCP_TASK_FLAG_USE_DYN_SEG )) {
215+ ucc_status_t fin = ucc_tl_ucp_coll_dynamic_segment_finalize (task );
216+ task -> super .status = (st != UCC_OK ) ? st : fin ;
217+ }
170218}
171219
172- static ucc_status_t ucc_tl_ucp_alltoall_onesided_start_ops (ucc_tl_ucp_task_t * task );
173-
174220ucc_status_t ucc_tl_ucp_alltoall_onesided_start (ucc_coll_task_t * ctask )
175221{
176222 ucc_tl_ucp_task_t * task = ucc_derived_of (ctask , ucc_tl_ucp_task_t );
@@ -180,65 +226,16 @@ ucc_status_t ucc_tl_ucp_alltoall_onesided_start(ucc_coll_task_t *ctask)
180226 ucc_tl_ucp_task_reset (task , UCC_INPROGRESS );
181227 if (task -> flags & UCC_TL_UCP_TASK_FLAG_USE_DYN_SEG ) {
182228 status = ucc_tl_ucp_coll_dynamic_segment_exchange_nb (task );
183- if (status == UCC_INPROGRESS ) {
184- return ucc_progress_queue_enqueue (UCC_TL_CORE_CTX (team )-> pq , & task -> super );
185- }
186- if (UCC_OK != status ) {
229+ if (UCC_OK != status && UCC_INPROGRESS != status ) {
187230 task -> super .status = status ;
231+ tl_error (UCC_TL_TEAM_LIB (team ),
232+ "failed to exchange dynamic segments" );
188233 return task -> super .status ;
189234 }
190235 }
191236
192237 /* Start the onesided operations */
193- return ucc_tl_ucp_alltoall_onesided_start_ops (task );
194- }
195-
196- static ucc_status_t ucc_tl_ucp_alltoall_onesided_start_ops (ucc_tl_ucp_task_t * task )
197- {
198- ucc_tl_ucp_team_t * team = TASK_TEAM (task );
199- ptrdiff_t src = (ptrdiff_t )TASK_ARGS (task ).src .info .buffer ;
200- ptrdiff_t dest = (ptrdiff_t )TASK_ARGS (task ).dst .info .buffer ;
201- size_t nelems = TASK_ARGS (task ).src .info .count ;
202- ucc_rank_t grank = UCC_TL_TEAM_RANK (team );
203- ucc_rank_t gsize = UCC_TL_TEAM_SIZE (team );
204- ucc_rank_t start = (grank + 1 ) % gsize ;
205- long * pSync = TASK_ARGS (task ).global_work_buffer ;
206- ucc_mem_map_mem_h src_memh = TASK_ARGS (task ).src_memh .local_memh ;
207- ucc_mem_map_mem_h * dst_memh = TASK_ARGS (task ).dst_memh .global_memh ;
208- ucc_rank_t peer ;
209-
210-
211- if (task -> flags & UCC_TL_UCP_TASK_FLAG_USE_DYN_SEG ) {
212- status = ucc_tl_ucp_coll_dynamic_segment_exchange (task );
213- if (UCC_OK != status ) {
214- task -> super .status = status ;
215- return task -> super .status ;
216- }
217- src_memh = task -> dynamic_segments .src_global [grank ];
218- dst_memh = (ucc_mem_map_mem_h * )task -> dynamic_segments .dst_global ;
219- } else {
220- if (TASK_ARGS (task ).flags & UCC_COLL_ARGS_FLAG_SRC_MEMH_GLOBAL ) {
221- src_memh = TASK_ARGS (task ).src_memh .global_memh [grank ];
222- }
223- }
224-
225- /* TODO: change when support for library-based work buffers is complete */
226- nelems = (nelems / gsize ) * ucc_dt_size (TASK_ARGS (task ).src .info .datatype );
227- dest = dest + grank * nelems ;
228- for (peer = start ; task -> onesided .put_posted < gsize ;
229- peer = (peer + 1 ) % gsize ) {
230- UCPCHECK_GOTO (ucc_tl_ucp_put_nb ((void * )(src + peer * nelems ),
231- (void * )dest , nelems , peer , src_memh ,
232- dst_memh , team , task ),
233- task , out );
234- UCPCHECK_GOTO (ucc_tl_ucp_atomic_inc (pSync , peer , dst_memh , team ), task ,
235- out );
236- }
237-
238- /* Operations posted, return UCC_INPROGRESS to let progress function set flag */
239- return UCC_INPROGRESS ;
240- out :
241- return task -> super .status ;
238+ return ucc_progress_queue_enqueue (UCC_TL_CORE_CTX (team )-> pq , & task -> super );
242239}
243240
244241ucc_status_t ucc_tl_ucp_alltoall_onesided_init (ucc_base_coll_args_t * coll_args ,
@@ -254,7 +251,7 @@ ucc_status_t ucc_tl_ucp_alltoall_onesided_init(ucc_base_coll_args_t *coll_args,
254251 };
255252 size_t perc_bw =
256253 UCC_TL_UCP_TEAM_LIB (tl_team )-> cfg .alltoall_onesided_percent_bw ;
257- ucc_tl_ucp_alltoall_onesided_alg_t alg =
254+ ucc_tl_ucp_onesided_alg_type alg =
258255 UCC_TL_UCP_TEAM_LIB (tl_team )-> cfg .alltoall_onesided_alg ;
259256 ucc_tl_ucp_schedule_t * tl_schedule = NULL ;
260257 ucc_rank_t group_size = 1 ;
@@ -294,12 +291,6 @@ ucc_status_t ucc_tl_ucp_alltoall_onesided_init(ucc_base_coll_args_t *coll_args,
294291 }
295292 }
296293
297- status = ucc_tl_ucp_coll_dynamic_segment_init (& coll_args -> args , task );
298- if (UCC_OK != status ) {
299- tl_error (UCC_TL_TEAM_LIB (tl_team ),
300- "failed to initialize dynamic segments" );
301- return status ;
302- }
303294
304295 status = ucc_tl_ucp_get_schedule (tl_team , coll_args ,
305296 (ucc_tl_ucp_schedule_t * * )& tl_schedule );
@@ -322,12 +313,35 @@ ucc_status_t ucc_tl_ucp_alltoall_onesided_init(ucc_base_coll_args_t *coll_args,
322313 group_size = sbgp -> group_size ;
323314 }
324315
325- task = ucc_tl_ucp_init_task (coll_args , team );
316+ task = ucc_tl_ucp_init_task (coll_args , team );
317+ if (ucc_unlikely (!task )) {
318+ status = UCC_ERR_NO_MEMORY ;
319+ goto out ;
320+ }
326321 task -> super .finalize = ucc_tl_ucp_alltoall_onesided_finalize ;
327322 a2a_task = & task -> super ;
328323
324+ /* initialize dynamic segments */
325+ if (alg == UCC_TL_UCP_ALLTOALL_ONESIDED_GET ||
326+ (alg == UCC_TL_UCP_ALLTOALL_ONESIDED_AUTO &&
327+ sbgp -> group_size >= CONGESTION_THRESHOLD )) {
328+ alg = UCC_TL_UCP_ALLTOALL_ONESIDED_GET ;
329+ } else if (alg == UCC_TL_UCP_ALLTOALL_ONESIDED_AUTO ) {
330+ alg = UCC_TL_UCP_ALLTOALL_ONESIDED_PUT ;
331+ }
332+ status = ucc_tl_ucp_coll_dynamic_segment_init (& coll_args -> args , alg , task );
333+ if (UCC_OK != status ) {
334+ if (status != UCC_ERR_NOT_SUPPORTED ) {
335+ tl_error (UCC_TL_TEAM_LIB (tl_team ),
336+ "failed to initialize dynamic segments" );
337+ }
338+ ucc_tl_ucp_coll_finalize (& task -> super );
339+ goto out ;
340+ }
341+
329342 status = ucc_tl_ucp_coll_init (& barrier_coll_args , team , & barrier_task );
330343 if (status != UCC_OK ) {
344+ task -> super .finalize (& task -> super );
331345 goto out ;
332346 }
333347 if (perc_bw > 100 ) {
@@ -340,23 +354,25 @@ ucc_status_t ucc_tl_ucp_alltoall_onesided_init(ucc_base_coll_args_t *coll_args,
340354 nelems = nelems / UCC_TL_TEAM_SIZE (tl_team );
341355 param .field_mask = UCP_EP_PERF_PARAM_FIELD_MESSAGE_SIZE ;
342356 attr .field_mask = UCP_EP_PERF_ATTR_FIELD_ESTIMATED_TIME ;
343- param .message_size = nelems * ucc_dt_size (TASK_ARGS (task ).src .info .datatype );;
357+ param .message_size = nelems * ucc_dt_size (TASK_ARGS (task ).src .info .datatype );
344358 ucc_tl_ucp_get_ep (
345359 tl_team , (UCC_TL_TEAM_RANK (tl_team ) + 1 ) % UCC_TL_TEAM_SIZE (tl_team ),
346360 & ep );
347361 ucp_ep_evaluate_perf (ep , & param , & attr );
348362
349- rate = (1 / attr .estimated_time ) * (double )(perc_bw / 100.0 );
350- ratio = (nelems > 0 ) ? nelems * group_size : 1 ;
351- task -> alltoall_onesided .tokens = rate / ratio ;
363+ if (attr .estimated_time > 0 ) {
364+ rate = (1 / attr .estimated_time ) * (double )(perc_bw / 100.0 );
365+ ratio = (nelems > 0 ) ? nelems * group_size : 1 ;
366+ task -> alltoall_onesided .tokens = rate / ratio ;
367+ } else {
368+ task -> alltoall_onesided .tokens = 0 ;
369+ }
352370 if (task -> alltoall_onesided .tokens < 1 ) {
353371 task -> alltoall_onesided .tokens = 1 ;
354372 }
355373 task -> super .post = ucc_tl_ucp_alltoall_onesided_start ;
356374 npolls = task -> n_polls ;
357- if (alg == UCC_TL_UCP_ALLTOALL_ONESIDED_GET ||
358- (alg == UCC_TL_UCP_ALLTOALL_ONESIDED_AUTO &&
359- group_size >= CONGESTION_THRESHOLD )) {
375+ if (alg == UCC_TL_UCP_ALLTOALL_ONESIDED_GET ) {
360376 npolls = nelems * ucc_dt_size (TASK_ARGS (task ).src .info .datatype );
361377 if (npolls < task -> n_polls ) {
362378 npolls = task -> n_polls ;
0 commit comments