@@ -72,8 +72,14 @@ ucc_status_t ucc_tl_ucp_alltoall_onesided_sched_finalize(ucc_coll_task_t *ctask)
7272
7373ucc_status_t ucc_tl_ucp_alltoall_onesided_finalize (ucc_coll_task_t * coll_task )
7474{
75- ucc_status_t status ;
75+ ucc_tl_ucp_task_t * task = ucc_derived_of (coll_task , ucc_tl_ucp_task_t );
76+ ucc_status_t status ;
7677
78+ status = ucc_tl_ucp_coll_dynamic_segment_destroy (task );
79+ if (ucc_unlikely (UCC_OK != status )) {
80+ tl_error (UCC_TASK_LIB (coll_task ),
81+ "failed to destroy dynamic segment local handles" );
82+ }
7783 status = ucc_tl_ucp_coll_finalize (coll_task );
7884 if (ucc_unlikely (UCC_OK != status )) {
7985 tl_error (UCC_TASK_LIB (coll_task ), "failed to finalize collective" );
@@ -92,27 +98,43 @@ void ucc_tl_ucp_alltoall_onesided_get_progress(ucc_coll_task_t *ctask)
9298 ucc_rank_t gsize = UCC_TL_TEAM_SIZE (team );
9399 uint32_t ntokens = task -> alltoall_onesided .tokens ;
94100 int64_t npolls = task -> alltoall_onesided .npolls ;
95- /* To resolve remote virtual addresses, the dst_memh is the one that must
96- * have the rkey information. For this algorithm, we need to swap the
97- * src and dst handles to operate correctly */
98- ucc_mem_map_mem_h * dst_memh = TASK_ARGS (task ).src_memh .global_memh ;
99- uint32_t * posted = & task -> onesided .get_posted ;
100- uint32_t * completed = & task -> onesided .get_completed ;
101- ucc_rank_t peer = (grank + * posted + 1 ) % gsize ;
102- ucc_mem_map_mem_h src_memh ;
101+ /* For GET, we read from each peer's src buffer into our local dst buffer.
102+ * remote_rkeys is the per-rank array of src rkeys (what we GET from);
103+ * local_h is our own dst buffer registration (where data lands). */
104+ ucc_mem_map_mem_h * remote_rkeys = TASK_ARGS (task ).src_memh .global_memh ;
105+ uint32_t * posted = & task -> onesided .get_posted ;
106+ uint32_t * completed = & task -> onesided .get_completed ;
107+ ucc_rank_t peer = (grank + * posted + 1 ) % gsize ;
108+ ucc_mem_map_mem_h local_h ;
103109 size_t nelems ;
110+ ucc_status_t status ;
111+
112+ if (task -> flags & UCC_TL_UCP_TASK_FLAG_USE_DYN_SEG ) {
113+ status = ucc_tl_ucp_test_dynamic_segment (task );
114+ if (status == UCC_INPROGRESS ) {
115+ return ;
116+ }
117+ if (UCC_OK != status ) {
118+ task -> super .status = status ;
119+ tl_error (UCC_TL_TEAM_LIB (team ),
120+ "failed to exchange dynamic segments" );
121+ return ;
122+ }
123+ local_h = task -> dynamic_segments .dst_local ;
124+ remote_rkeys = (ucc_mem_map_mem_h * )task -> dynamic_segments .src_global ;
125+ } else {
126+ local_h = (TASK_ARGS (task ).flags & UCC_COLL_ARGS_FLAG_DST_MEMH_GLOBAL )
127+ ? TASK_ARGS (task ).dst_memh .global_memh [grank ]
128+ : TASK_ARGS (task ).dst_memh .local_memh ;
129+ }
104130
105131 nelems = TASK_ARGS (task ).src .info .count ;
106132 nelems = (nelems / gsize ) * ucc_dt_size (TASK_ARGS (task ).src .info .datatype );
107- src_memh = (TASK_ARGS (task ).flags & UCC_COLL_ARGS_FLAG_DST_MEMH_GLOBAL )
108- ? TASK_ARGS (task ).dst_memh .global_memh [grank ]
109- : TASK_ARGS (task ).dst_memh .local_memh ;
110-
111133 for (; * posted < gsize ; peer = (peer + 1 ) % gsize ) {
112134 UCPCHECK_GOTO (ucc_tl_ucp_get_nb (PTR_OFFSET (dest , peer * nelems ),
113135 PTR_OFFSET (src , grank * nelems ),
114- nelems , mtype , peer , src_memh , dst_memh ,
115- team , task ),
136+ nelems , mtype , peer , local_h ,
137+ remote_rkeys , team , task ),
116138 task , out );
117139
118140 if (!alltoall_onesided_handle_completion (task , posted , completed ,
@@ -123,7 +145,10 @@ void ucc_tl_ucp_alltoall_onesided_get_progress(ucc_coll_task_t *ctask)
123145
124146 alltoall_onesided_wait_completion (task , npolls );
125147out :
126- return ;
148+ if (task -> super .status != UCC_INPROGRESS &&
149+ (task -> flags & UCC_TL_UCP_TASK_FLAG_USE_DYN_SEG )) {
150+ task -> super .status = ucc_tl_ucp_coll_dynamic_segment_finalize (task );
151+ }
127152}
128153
129154void ucc_tl_ucp_alltoall_onesided_put_progress (ucc_coll_task_t * ctask )
@@ -143,12 +168,28 @@ void ucc_tl_ucp_alltoall_onesided_put_progress(ucc_coll_task_t *ctask)
143168 ucc_rank_t peer = (grank + * posted + 1 ) % gsize ;
144169 ucc_mem_map_mem_h src_memh ;
145170 size_t nelems ;
171+ ucc_status_t status ;
146172
147173 nelems = TASK_ARGS (task ).src .info .count ;
148174 nelems = (nelems / gsize ) * ucc_dt_size (TASK_ARGS (task ).src .info .datatype );
149- src_memh = (TASK_ARGS (task ).flags & UCC_COLL_ARGS_FLAG_SRC_MEMH_GLOBAL )
175+ if (task -> flags & UCC_TL_UCP_TASK_FLAG_USE_DYN_SEG ) {
176+ status = ucc_tl_ucp_test_dynamic_segment (task );
177+ if (status == UCC_INPROGRESS ) {
178+ return ;
179+ }
180+ if (UCC_OK != status ) {
181+ task -> super .status = status ;
182+ tl_error (UCC_TL_TEAM_LIB (team ),
183+ "failed to exchange dynamic segments" );
184+ return ;
185+ }
186+ src_memh = task -> dynamic_segments .src_local ;
187+ dst_memh = (ucc_mem_map_mem_h * )task -> dynamic_segments .dst_global ;
188+ } else {
189+ src_memh = (TASK_ARGS (task ).flags & UCC_COLL_ARGS_FLAG_SRC_MEMH_GLOBAL )
150190 ? TASK_ARGS (task ).src_memh .global_memh [grank ]
151191 : TASK_ARGS (task ).src_memh .local_memh ;
192+ }
152193
153194 for (; * posted < gsize ; peer = (peer + 1 ) % gsize ) {
154195 UCPCHECK_GOTO (
@@ -166,11 +207,12 @@ void ucc_tl_ucp_alltoall_onesided_put_progress(ucc_coll_task_t *ctask)
166207
167208 alltoall_onesided_wait_completion (task , npolls );
168209out :
169- return ;
210+ if (task -> super .status != UCC_INPROGRESS &&
211+ (task -> flags & UCC_TL_UCP_TASK_FLAG_USE_DYN_SEG )) {
212+ task -> super .status = ucc_tl_ucp_coll_dynamic_segment_finalize (task );
213+ }
170214}
171215
172- static ucc_status_t ucc_tl_ucp_alltoall_onesided_start_ops (ucc_tl_ucp_task_t * task );
173-
174216ucc_status_t ucc_tl_ucp_alltoall_onesided_start (ucc_coll_task_t * ctask )
175217{
176218 ucc_tl_ucp_task_t * task = ucc_derived_of (ctask , ucc_tl_ucp_task_t );
@@ -180,65 +222,16 @@ ucc_status_t ucc_tl_ucp_alltoall_onesided_start(ucc_coll_task_t *ctask)
180222 ucc_tl_ucp_task_reset (task , UCC_INPROGRESS );
181223 if (task -> flags & UCC_TL_UCP_TASK_FLAG_USE_DYN_SEG ) {
182224 status = ucc_tl_ucp_coll_dynamic_segment_exchange_nb (task );
183- if (status == UCC_INPROGRESS ) {
184- return ucc_progress_queue_enqueue (UCC_TL_CORE_CTX (team )-> pq , & task -> super );
185- }
186- if (UCC_OK != status ) {
225+ if (UCC_OK != status && UCC_INPROGRESS != status ) {
187226 task -> super .status = status ;
227+ tl_error (UCC_TL_TEAM_LIB (team ),
228+ "failed to exchange dynamic segments" );
188229 return task -> super .status ;
189230 }
190231 }
191232
192233 /* Start the onesided operations */
193- return ucc_tl_ucp_alltoall_onesided_start_ops (task );
194- }
195-
196- static ucc_status_t ucc_tl_ucp_alltoall_onesided_start_ops (ucc_tl_ucp_task_t * task )
197- {
198- ucc_tl_ucp_team_t * team = TASK_TEAM (task );
199- ptrdiff_t src = (ptrdiff_t )TASK_ARGS (task ).src .info .buffer ;
200- ptrdiff_t dest = (ptrdiff_t )TASK_ARGS (task ).dst .info .buffer ;
201- size_t nelems = TASK_ARGS (task ).src .info .count ;
202- ucc_rank_t grank = UCC_TL_TEAM_RANK (team );
203- ucc_rank_t gsize = UCC_TL_TEAM_SIZE (team );
204- ucc_rank_t start = (grank + 1 ) % gsize ;
205- long * pSync = TASK_ARGS (task ).global_work_buffer ;
206- ucc_mem_map_mem_h src_memh = TASK_ARGS (task ).src_memh .local_memh ;
207- ucc_mem_map_mem_h * dst_memh = TASK_ARGS (task ).dst_memh .global_memh ;
208- ucc_rank_t peer ;
209-
210-
211- if (task -> flags & UCC_TL_UCP_TASK_FLAG_USE_DYN_SEG ) {
212- status = ucc_tl_ucp_coll_dynamic_segment_exchange (task );
213- if (UCC_OK != status ) {
214- task -> super .status = status ;
215- return task -> super .status ;
216- }
217- src_memh = task -> dynamic_segments .src_global [grank ];
218- dst_memh = (ucc_mem_map_mem_h * )task -> dynamic_segments .dst_global ;
219- } else {
220- if (TASK_ARGS (task ).flags & UCC_COLL_ARGS_FLAG_SRC_MEMH_GLOBAL ) {
221- src_memh = TASK_ARGS (task ).src_memh .global_memh [grank ];
222- }
223- }
224-
225- /* TODO: change when support for library-based work buffers is complete */
226- nelems = (nelems / gsize ) * ucc_dt_size (TASK_ARGS (task ).src .info .datatype );
227- dest = dest + grank * nelems ;
228- for (peer = start ; task -> onesided .put_posted < gsize ;
229- peer = (peer + 1 ) % gsize ) {
230- UCPCHECK_GOTO (ucc_tl_ucp_put_nb ((void * )(src + peer * nelems ),
231- (void * )dest , nelems , peer , src_memh ,
232- dst_memh , team , task ),
233- task , out );
234- UCPCHECK_GOTO (ucc_tl_ucp_atomic_inc (pSync , peer , dst_memh , team ), task ,
235- out );
236- }
237-
238- /* Operations posted, return UCC_INPROGRESS to let progress function set flag */
239- return UCC_INPROGRESS ;
240- out :
241- return task -> super .status ;
234+ return ucc_progress_queue_enqueue (UCC_TL_CORE_CTX (team )-> pq , & task -> super );
242235}
243236
244237ucc_status_t ucc_tl_ucp_alltoall_onesided_init (ucc_base_coll_args_t * coll_args ,
@@ -254,7 +247,7 @@ ucc_status_t ucc_tl_ucp_alltoall_onesided_init(ucc_base_coll_args_t *coll_args,
254247 };
255248 size_t perc_bw =
256249 UCC_TL_UCP_TEAM_LIB (tl_team )-> cfg .alltoall_onesided_percent_bw ;
257- ucc_tl_ucp_alltoall_onesided_alg_t alg =
250+ ucc_tl_ucp_onesided_alg_type alg =
258251 UCC_TL_UCP_TEAM_LIB (tl_team )-> cfg .alltoall_onesided_alg ;
259252 ucc_tl_ucp_schedule_t * tl_schedule = NULL ;
260253 ucc_rank_t group_size = 1 ;
@@ -294,12 +287,6 @@ ucc_status_t ucc_tl_ucp_alltoall_onesided_init(ucc_base_coll_args_t *coll_args,
294287 }
295288 }
296289
297- status = ucc_tl_ucp_coll_dynamic_segment_init (& coll_args -> args , task );
298- if (UCC_OK != status ) {
299- tl_error (UCC_TL_TEAM_LIB (tl_team ),
300- "failed to initialize dynamic segments" );
301- return status ;
302- }
303290
304291 status = ucc_tl_ucp_get_schedule (tl_team , coll_args ,
305292 (ucc_tl_ucp_schedule_t * * )& tl_schedule );
@@ -322,12 +309,35 @@ ucc_status_t ucc_tl_ucp_alltoall_onesided_init(ucc_base_coll_args_t *coll_args,
322309 group_size = sbgp -> group_size ;
323310 }
324311
325- task = ucc_tl_ucp_init_task (coll_args , team );
312+ task = ucc_tl_ucp_init_task (coll_args , team );
313+ if (ucc_unlikely (!task )) {
314+ status = UCC_ERR_NO_MEMORY ;
315+ goto out ;
316+ }
326317 task -> super .finalize = ucc_tl_ucp_alltoall_onesided_finalize ;
327318 a2a_task = & task -> super ;
328319
320+ /* initialize dynamic segments */
321+ if (alg == UCC_TL_UCP_ALLTOALL_ONESIDED_GET ||
322+ (alg == UCC_TL_UCP_ALLTOALL_ONESIDED_AUTO &&
323+ sbgp -> group_size >= CONGESTION_THRESHOLD )) {
324+ alg = UCC_TL_UCP_ALLTOALL_ONESIDED_GET ;
325+ } else if (alg == UCC_TL_UCP_ALLTOALL_ONESIDED_AUTO ) {
326+ alg = UCC_TL_UCP_ALLTOALL_ONESIDED_PUT ;
327+ }
328+ status = ucc_tl_ucp_coll_dynamic_segment_init (& coll_args -> args , alg , task );
329+ if (UCC_OK != status ) {
330+ if (status != UCC_ERR_NOT_SUPPORTED ) {
331+ tl_error (UCC_TL_TEAM_LIB (tl_team ),
332+ "failed to initialize dynamic segments" );
333+ }
334+ ucc_tl_ucp_coll_finalize (& task -> super );
335+ goto out ;
336+ }
337+
329338 status = ucc_tl_ucp_coll_init (& barrier_coll_args , team , & barrier_task );
330339 if (status != UCC_OK ) {
340+ task -> super .finalize (& task -> super );
331341 goto out ;
332342 }
333343 if (perc_bw > 100 ) {
@@ -340,23 +350,25 @@ ucc_status_t ucc_tl_ucp_alltoall_onesided_init(ucc_base_coll_args_t *coll_args,
340350 nelems = nelems / UCC_TL_TEAM_SIZE (tl_team );
341351 param .field_mask = UCP_EP_PERF_PARAM_FIELD_MESSAGE_SIZE ;
342352 attr .field_mask = UCP_EP_PERF_ATTR_FIELD_ESTIMATED_TIME ;
343- param .message_size = nelems * ucc_dt_size (TASK_ARGS (task ).src .info .datatype );;
353+ param .message_size = nelems * ucc_dt_size (TASK_ARGS (task ).src .info .datatype );
344354 ucc_tl_ucp_get_ep (
345355 tl_team , (UCC_TL_TEAM_RANK (tl_team ) + 1 ) % UCC_TL_TEAM_SIZE (tl_team ),
346356 & ep );
347357 ucp_ep_evaluate_perf (ep , & param , & attr );
348358
349- rate = (1 / attr .estimated_time ) * (double )(perc_bw / 100.0 );
350- ratio = (nelems > 0 ) ? nelems * group_size : 1 ;
351- task -> alltoall_onesided .tokens = rate / ratio ;
359+ if (attr .estimated_time > 0 ) {
360+ rate = (1 / attr .estimated_time ) * (double )(perc_bw / 100.0 );
361+ ratio = (nelems > 0 ) ? nelems * group_size : 1 ;
362+ task -> alltoall_onesided .tokens = rate / ratio ;
363+ } else {
364+ task -> alltoall_onesided .tokens = 0 ;
365+ }
352366 if (task -> alltoall_onesided .tokens < 1 ) {
353367 task -> alltoall_onesided .tokens = 1 ;
354368 }
355369 task -> super .post = ucc_tl_ucp_alltoall_onesided_start ;
356370 npolls = task -> n_polls ;
357- if (alg == UCC_TL_UCP_ALLTOALL_ONESIDED_GET ||
358- (alg == UCC_TL_UCP_ALLTOALL_ONESIDED_AUTO &&
359- group_size >= CONGESTION_THRESHOLD )) {
371+ if (alg == UCC_TL_UCP_ALLTOALL_ONESIDED_GET ) {
360372 npolls = nelems * ucc_dt_size (TASK_ARGS (task ).src .info .datatype );
361373 if (npolls < task -> n_polls ) {
362374 npolls = task -> n_polls ;
0 commit comments