Skip to content

Commit ce512fd

Browse files
committed
TEST: Add alltoall dynamic seg gtest
TEST: add mpi dyn seg alltoallv test
1 parent 7c5ae59 commit ce512fd

10 files changed

Lines changed: 635 additions & 695 deletions

File tree

src/components/tl/ucp/alltoall/alltoall_onesided.c

Lines changed: 106 additions & 90 deletions
Original file line number | Diff line number | Diff line change
@@ -72,8 +72,14 @@ ucc_status_t ucc_tl_ucp_alltoall_onesided_sched_finalize(ucc_coll_task_t *ctask)
7272

7373
ucc_status_t ucc_tl_ucp_alltoall_onesided_finalize(ucc_coll_task_t *coll_task)
7474
{
75-
ucc_status_t status;
75+
ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, ucc_tl_ucp_task_t);
76+
ucc_status_t status;
7677

78+
status = ucc_tl_ucp_coll_dynamic_segment_destroy(task);
79+
if (ucc_unlikely(UCC_OK != status)) {
80+
tl_error(UCC_TASK_LIB(coll_task),
81+
"failed to destroy dynamic segment local handles");
82+
}
7783
status = ucc_tl_ucp_coll_finalize(coll_task);
7884
if (ucc_unlikely(UCC_OK != status)) {
7985
tl_error(UCC_TASK_LIB(coll_task), "failed to finalize collective");
@@ -92,27 +98,43 @@ void ucc_tl_ucp_alltoall_onesided_get_progress(ucc_coll_task_t *ctask)
9298
ucc_rank_t gsize = UCC_TL_TEAM_SIZE(team);
9399
uint32_t ntokens = task->alltoall_onesided.tokens;
94100
int64_t npolls = task->alltoall_onesided.npolls;
95-
/* To resolve remote virtual addresses, the dst_memh is the one that must
96-
* have the rkey information. For this algorithm, we need to swap the
97-
* src and dst handles to operate correctly */
98-
ucc_mem_map_mem_h *dst_memh = TASK_ARGS(task).src_memh.global_memh;
99-
uint32_t *posted = &task->onesided.get_posted;
100-
uint32_t *completed = &task->onesided.get_completed;
101-
ucc_rank_t peer = (grank + *posted + 1) % gsize;
102-
ucc_mem_map_mem_h src_memh;
101+
/* For GET, we read from each peer's src buffer into our local dst buffer.
102+
* remote_rkeys is the per-rank array of src rkeys (what we GET from);
103+
* local_h is our own dst buffer registration (where data lands). */
104+
ucc_mem_map_mem_h *remote_rkeys = TASK_ARGS(task).src_memh.global_memh;
105+
uint32_t *posted = &task->onesided.get_posted;
106+
uint32_t *completed = &task->onesided.get_completed;
107+
ucc_rank_t peer = (grank + *posted + 1) % gsize;
108+
ucc_mem_map_mem_h local_h;
103109
size_t nelems;
110+
ucc_status_t status;
111+
112+
if (task->flags & UCC_TL_UCP_TASK_FLAG_USE_DYN_SEG) {
113+
status = ucc_tl_ucp_test_dynamic_segment(task);
114+
if (status == UCC_INPROGRESS) {
115+
return;
116+
}
117+
if (UCC_OK != status) {
118+
task->super.status = status;
119+
tl_error(UCC_TL_TEAM_LIB(team),
120+
"failed to exchange dynamic segments");
121+
return;
122+
}
123+
local_h = task->dynamic_segments.dst_local;
124+
remote_rkeys = (ucc_mem_map_mem_h *)task->dynamic_segments.src_global;
125+
} else {
126+
local_h = (TASK_ARGS(task).flags & UCC_COLL_ARGS_FLAG_DST_MEMH_GLOBAL)
127+
? TASK_ARGS(task).dst_memh.global_memh[grank]
128+
: TASK_ARGS(task).dst_memh.local_memh;
129+
}
104130

105131
nelems = TASK_ARGS(task).src.info.count;
106132
nelems = (nelems / gsize) * ucc_dt_size(TASK_ARGS(task).src.info.datatype);
107-
src_memh = (TASK_ARGS(task).flags & UCC_COLL_ARGS_FLAG_DST_MEMH_GLOBAL)
108-
? TASK_ARGS(task).dst_memh.global_memh[grank]
109-
: TASK_ARGS(task).dst_memh.local_memh;
110-
111133
for (; *posted < gsize; peer = (peer + 1) % gsize) {
112134
UCPCHECK_GOTO(ucc_tl_ucp_get_nb(PTR_OFFSET(dest, peer * nelems),
113135
PTR_OFFSET(src, grank * nelems),
114-
nelems, mtype, peer, src_memh, dst_memh,
115-
team, task),
136+
nelems, mtype, peer, local_h,
137+
remote_rkeys, team, task),
116138
task, out);
117139

118140
if (!alltoall_onesided_handle_completion(task, posted, completed,
@@ -122,8 +144,13 @@ void ucc_tl_ucp_alltoall_onesided_get_progress(ucc_coll_task_t *ctask)
122144
}
123145

124146
alltoall_onesided_wait_completion(task, npolls);
125-
out:
126-
return;
147+
out: {
148+
ucc_status_t st = task->super.status;
149+
if (st != UCC_INPROGRESS &&
150+
(task->flags & UCC_TL_UCP_TASK_FLAG_USE_DYN_SEG)) {
151+
ucc_status_t fin = ucc_tl_ucp_coll_dynamic_segment_finalize(task);
152+
task->super.status = (st != UCC_OK) ? st : fin;
153+
}
127154
}
128155

129156
void ucc_tl_ucp_alltoall_onesided_put_progress(ucc_coll_task_t *ctask)
@@ -143,12 +170,28 @@ void ucc_tl_ucp_alltoall_onesided_put_progress(ucc_coll_task_t *ctask)
143170
ucc_rank_t peer = (grank + *posted + 1) % gsize;
144171
ucc_mem_map_mem_h src_memh;
145172
size_t nelems;
173+
ucc_status_t status;
146174

147175
nelems = TASK_ARGS(task).src.info.count;
148176
nelems = (nelems / gsize) * ucc_dt_size(TASK_ARGS(task).src.info.datatype);
149-
src_memh = (TASK_ARGS(task).flags & UCC_COLL_ARGS_FLAG_SRC_MEMH_GLOBAL)
177+
if (task->flags & UCC_TL_UCP_TASK_FLAG_USE_DYN_SEG) {
178+
status = ucc_tl_ucp_test_dynamic_segment(task);
179+
if (status == UCC_INPROGRESS) {
180+
return;
181+
}
182+
if (UCC_OK != status) {
183+
task->super.status = status;
184+
tl_error(UCC_TL_TEAM_LIB(team),
185+
"failed to exchange dynamic segments");
186+
return;
187+
}
188+
src_memh = task->dynamic_segments.src_local;
189+
dst_memh = (ucc_mem_map_mem_h *)task->dynamic_segments.dst_global;
190+
} else {
191+
src_memh = (TASK_ARGS(task).flags & UCC_COLL_ARGS_FLAG_SRC_MEMH_GLOBAL)
150192
? TASK_ARGS(task).src_memh.global_memh[grank]
151193
: TASK_ARGS(task).src_memh.local_memh;
194+
}
152195

153196
for (; *posted < gsize; peer = (peer + 1) % gsize) {
154197
UCPCHECK_GOTO(
@@ -165,12 +208,15 @@ void ucc_tl_ucp_alltoall_onesided_put_progress(ucc_coll_task_t *ctask)
165208
}
166209

167210
alltoall_onesided_wait_completion(task, npolls);
168-
out:
169-
return;
211+
out: {
212+
ucc_status_t st = task->super.status;
213+
if (st != UCC_INPROGRESS &&
214+
(task->flags & UCC_TL_UCP_TASK_FLAG_USE_DYN_SEG)) {
215+
ucc_status_t fin = ucc_tl_ucp_coll_dynamic_segment_finalize(task);
216+
task->super.status = (st != UCC_OK) ? st : fin;
217+
}
170218
}
171219

172-
static ucc_status_t ucc_tl_ucp_alltoall_onesided_start_ops(ucc_tl_ucp_task_t *task);
173-
174220
ucc_status_t ucc_tl_ucp_alltoall_onesided_start(ucc_coll_task_t *ctask)
175221
{
176222
ucc_tl_ucp_task_t *task = ucc_derived_of(ctask, ucc_tl_ucp_task_t);
@@ -180,65 +226,16 @@ ucc_status_t ucc_tl_ucp_alltoall_onesided_start(ucc_coll_task_t *ctask)
180226
ucc_tl_ucp_task_reset(task, UCC_INPROGRESS);
181227
if (task->flags & UCC_TL_UCP_TASK_FLAG_USE_DYN_SEG) {
182228
status = ucc_tl_ucp_coll_dynamic_segment_exchange_nb(task);
183-
if (status == UCC_INPROGRESS) {
184-
return ucc_progress_queue_enqueue(UCC_TL_CORE_CTX(team)->pq, &task->super);
185-
}
186-
if (UCC_OK != status) {
229+
if (UCC_OK != status && UCC_INPROGRESS != status) {
187230
task->super.status = status;
231+
tl_error(UCC_TL_TEAM_LIB(team),
232+
"failed to exchange dynamic segments");
188233
return task->super.status;
189234
}
190235
}
191236

192237
/* Start the onesided operations */
193-
return ucc_tl_ucp_alltoall_onesided_start_ops(task);
194-
}
195-
196-
static ucc_status_t ucc_tl_ucp_alltoall_onesided_start_ops(ucc_tl_ucp_task_t *task)
197-
{
198-
ucc_tl_ucp_team_t *team = TASK_TEAM(task);
199-
ptrdiff_t src = (ptrdiff_t)TASK_ARGS(task).src.info.buffer;
200-
ptrdiff_t dest = (ptrdiff_t)TASK_ARGS(task).dst.info.buffer;
201-
size_t nelems = TASK_ARGS(task).src.info.count;
202-
ucc_rank_t grank = UCC_TL_TEAM_RANK(team);
203-
ucc_rank_t gsize = UCC_TL_TEAM_SIZE(team);
204-
ucc_rank_t start = (grank + 1) % gsize;
205-
long *pSync = TASK_ARGS(task).global_work_buffer;
206-
ucc_mem_map_mem_h src_memh = TASK_ARGS(task).src_memh.local_memh;
207-
ucc_mem_map_mem_h *dst_memh = TASK_ARGS(task).dst_memh.global_memh;
208-
ucc_rank_t peer;
209-
210-
211-
if (task->flags & UCC_TL_UCP_TASK_FLAG_USE_DYN_SEG) {
212-
status = ucc_tl_ucp_coll_dynamic_segment_exchange(task);
213-
if (UCC_OK != status) {
214-
task->super.status = status;
215-
return task->super.status;
216-
}
217-
src_memh = task->dynamic_segments.src_global[grank];
218-
dst_memh = (ucc_mem_map_mem_h *)task->dynamic_segments.dst_global;
219-
} else {
220-
if (TASK_ARGS(task).flags & UCC_COLL_ARGS_FLAG_SRC_MEMH_GLOBAL) {
221-
src_memh = TASK_ARGS(task).src_memh.global_memh[grank];
222-
}
223-
}
224-
225-
/* TODO: change when support for library-based work buffers is complete */
226-
nelems = (nelems / gsize) * ucc_dt_size(TASK_ARGS(task).src.info.datatype);
227-
dest = dest + grank * nelems;
228-
for (peer = start; task->onesided.put_posted < gsize;
229-
peer = (peer + 1) % gsize) {
230-
UCPCHECK_GOTO(ucc_tl_ucp_put_nb((void *)(src + peer * nelems),
231-
(void *)dest, nelems, peer, src_memh,
232-
dst_memh, team, task),
233-
task, out);
234-
UCPCHECK_GOTO(ucc_tl_ucp_atomic_inc(pSync, peer, dst_memh, team), task,
235-
out);
236-
}
237-
238-
/* Operations posted, return UCC_INPROGRESS to let progress function set flag */
239-
return UCC_INPROGRESS;
240-
out:
241-
return task->super.status;
238+
return ucc_progress_queue_enqueue(UCC_TL_CORE_CTX(team)->pq, &task->super);
242239
}
243240

244241
ucc_status_t ucc_tl_ucp_alltoall_onesided_init(ucc_base_coll_args_t *coll_args,
@@ -254,7 +251,7 @@ ucc_status_t ucc_tl_ucp_alltoall_onesided_init(ucc_base_coll_args_t *coll_args,
254251
};
255252
size_t perc_bw =
256253
UCC_TL_UCP_TEAM_LIB(tl_team)->cfg.alltoall_onesided_percent_bw;
257-
ucc_tl_ucp_alltoall_onesided_alg_t alg =
254+
ucc_tl_ucp_onesided_alg_type alg =
258255
UCC_TL_UCP_TEAM_LIB(tl_team)->cfg.alltoall_onesided_alg;
259256
ucc_tl_ucp_schedule_t *tl_schedule = NULL;
260257
ucc_rank_t group_size = 1;
@@ -294,12 +291,6 @@ ucc_status_t ucc_tl_ucp_alltoall_onesided_init(ucc_base_coll_args_t *coll_args,
294291
}
295292
}
296293

297-
status = ucc_tl_ucp_coll_dynamic_segment_init(&coll_args->args, task);
298-
if (UCC_OK != status) {
299-
tl_error(UCC_TL_TEAM_LIB(tl_team),
300-
"failed to initialize dynamic segments");
301-
return status;
302-
}
303294

304295
status = ucc_tl_ucp_get_schedule(tl_team, coll_args,
305296
(ucc_tl_ucp_schedule_t **)&tl_schedule);
@@ -322,12 +313,35 @@ ucc_status_t ucc_tl_ucp_alltoall_onesided_init(ucc_base_coll_args_t *coll_args,
322313
group_size = sbgp->group_size;
323314
}
324315

325-
task = ucc_tl_ucp_init_task(coll_args, team);
316+
task = ucc_tl_ucp_init_task(coll_args, team);
317+
if (ucc_unlikely(!task)) {
318+
status = UCC_ERR_NO_MEMORY;
319+
goto out;
320+
}
326321
task->super.finalize = ucc_tl_ucp_alltoall_onesided_finalize;
327322
a2a_task = &task->super;
328323

324+
/* initialize dynamic segments */
325+
if (alg == UCC_TL_UCP_ALLTOALL_ONESIDED_GET ||
326+
(alg == UCC_TL_UCP_ALLTOALL_ONESIDED_AUTO &&
327+
sbgp->group_size >= CONGESTION_THRESHOLD)) {
328+
alg = UCC_TL_UCP_ALLTOALL_ONESIDED_GET;
329+
} else if (alg == UCC_TL_UCP_ALLTOALL_ONESIDED_AUTO) {
330+
alg = UCC_TL_UCP_ALLTOALL_ONESIDED_PUT;
331+
}
332+
status = ucc_tl_ucp_coll_dynamic_segment_init(&coll_args->args, alg, task);
333+
if (UCC_OK != status) {
334+
if (status != UCC_ERR_NOT_SUPPORTED) {
335+
tl_error(UCC_TL_TEAM_LIB(tl_team),
336+
"failed to initialize dynamic segments");
337+
}
338+
ucc_tl_ucp_coll_finalize(&task->super);
339+
goto out;
340+
}
341+
329342
status = ucc_tl_ucp_coll_init(&barrier_coll_args, team, &barrier_task);
330343
if (status != UCC_OK) {
344+
task->super.finalize(&task->super);
331345
goto out;
332346
}
333347
if (perc_bw > 100) {
@@ -340,23 +354,25 @@ ucc_status_t ucc_tl_ucp_alltoall_onesided_init(ucc_base_coll_args_t *coll_args,
340354
nelems = nelems / UCC_TL_TEAM_SIZE(tl_team);
341355
param.field_mask = UCP_EP_PERF_PARAM_FIELD_MESSAGE_SIZE;
342356
attr.field_mask = UCP_EP_PERF_ATTR_FIELD_ESTIMATED_TIME;
343-
param.message_size = nelems * ucc_dt_size(TASK_ARGS(task).src.info.datatype);;
357+
param.message_size = nelems * ucc_dt_size(TASK_ARGS(task).src.info.datatype);
344358
ucc_tl_ucp_get_ep(
345359
tl_team, (UCC_TL_TEAM_RANK(tl_team) + 1) % UCC_TL_TEAM_SIZE(tl_team),
346360
&ep);
347361
ucp_ep_evaluate_perf(ep, &param, &attr);
348362

349-
rate = (1 / attr.estimated_time) * (double)(perc_bw / 100.0);
350-
ratio = (nelems > 0) ? nelems * group_size : 1;
351-
task->alltoall_onesided.tokens = rate / ratio;
363+
if (attr.estimated_time > 0) {
364+
rate = (1 / attr.estimated_time) * (double)(perc_bw / 100.0);
365+
ratio = (nelems > 0) ? nelems * group_size : 1;
366+
task->alltoall_onesided.tokens = rate / ratio;
367+
} else {
368+
task->alltoall_onesided.tokens = 0;
369+
}
352370
if (task->alltoall_onesided.tokens < 1) {
353371
task->alltoall_onesided.tokens = 1;
354372
}
355373
task->super.post = ucc_tl_ucp_alltoall_onesided_start;
356374
npolls = task->n_polls;
357-
if (alg == UCC_TL_UCP_ALLTOALL_ONESIDED_GET ||
358-
(alg == UCC_TL_UCP_ALLTOALL_ONESIDED_AUTO &&
359-
group_size >= CONGESTION_THRESHOLD)) {
375+
if (alg == UCC_TL_UCP_ALLTOALL_ONESIDED_GET) {
360376
npolls = nelems * ucc_dt_size(TASK_ARGS(task).src.info.datatype);
361377
if (npolls < task->n_polls) {
362378
npolls = task->n_polls;

0 commit comments

Comments
 (0)