Skip to content

Commit 3350c01

Browse files
committed
TEST: Add alltoall dynamic seg gtest
TEST: add mpi dyn seg alltoallv test
1 parent 23f06a2 commit 3350c01

10 files changed

Lines changed: 618 additions & 693 deletions

File tree

src/components/tl/ucp/alltoall/alltoall_onesided.c

Lines changed: 100 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,14 @@ ucc_status_t ucc_tl_ucp_alltoall_onesided_sched_finalize(ucc_coll_task_t *ctask)
7272

7373
ucc_status_t ucc_tl_ucp_alltoall_onesided_finalize(ucc_coll_task_t *coll_task)
7474
{
75-
ucc_status_t status;
75+
ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, ucc_tl_ucp_task_t);
76+
ucc_status_t status;
7677

78+
status = ucc_tl_ucp_coll_dynamic_segment_destroy(task);
79+
if (ucc_unlikely(UCC_OK != status)) {
80+
tl_error(UCC_TASK_LIB(coll_task),
81+
"failed to destroy dynamic segment local handles");
82+
}
7783
status = ucc_tl_ucp_coll_finalize(coll_task);
7884
if (ucc_unlikely(UCC_OK != status)) {
7985
tl_error(UCC_TASK_LIB(coll_task), "failed to finalize collective");
@@ -92,27 +98,43 @@ void ucc_tl_ucp_alltoall_onesided_get_progress(ucc_coll_task_t *ctask)
9298
ucc_rank_t gsize = UCC_TL_TEAM_SIZE(team);
9399
uint32_t ntokens = task->alltoall_onesided.tokens;
94100
int64_t npolls = task->alltoall_onesided.npolls;
95-
/* To resolve remote virtual addresses, the dst_memh is the one that must
96-
* have the rkey information. For this algorithm, we need to swap the
97-
* src and dst handles to operate correctly */
98-
ucc_mem_map_mem_h *dst_memh = TASK_ARGS(task).src_memh.global_memh;
99-
uint32_t *posted = &task->onesided.get_posted;
100-
uint32_t *completed = &task->onesided.get_completed;
101-
ucc_rank_t peer = (grank + *posted + 1) % gsize;
102-
ucc_mem_map_mem_h src_memh;
101+
/* For GET, we read from each peer's src buffer into our local dst buffer.
102+
* remote_rkeys is the per-rank array of src rkeys (what we GET from);
103+
* local_h is our own dst buffer registration (where data lands). */
104+
ucc_mem_map_mem_h *remote_rkeys = TASK_ARGS(task).src_memh.global_memh;
105+
uint32_t *posted = &task->onesided.get_posted;
106+
uint32_t *completed = &task->onesided.get_completed;
107+
ucc_rank_t peer = (grank + *posted + 1) % gsize;
108+
ucc_mem_map_mem_h local_h;
103109
size_t nelems;
110+
ucc_status_t status;
111+
112+
if (task->flags & UCC_TL_UCP_TASK_FLAG_USE_DYN_SEG) {
113+
status = ucc_tl_ucp_test_dynamic_segment(task);
114+
if (status == UCC_INPROGRESS) {
115+
return;
116+
}
117+
if (UCC_OK != status) {
118+
task->super.status = status;
119+
tl_error(UCC_TL_TEAM_LIB(team),
120+
"failed to exchange dynamic segments");
121+
return;
122+
}
123+
local_h = task->dynamic_segments.dst_local;
124+
remote_rkeys = (ucc_mem_map_mem_h *)task->dynamic_segments.src_global;
125+
} else {
126+
local_h = (TASK_ARGS(task).flags & UCC_COLL_ARGS_FLAG_DST_MEMH_GLOBAL)
127+
? TASK_ARGS(task).dst_memh.global_memh[grank]
128+
: TASK_ARGS(task).dst_memh.local_memh;
129+
}
104130

105131
nelems = TASK_ARGS(task).src.info.count;
106132
nelems = (nelems / gsize) * ucc_dt_size(TASK_ARGS(task).src.info.datatype);
107-
src_memh = (TASK_ARGS(task).flags & UCC_COLL_ARGS_FLAG_DST_MEMH_GLOBAL)
108-
? TASK_ARGS(task).dst_memh.global_memh[grank]
109-
: TASK_ARGS(task).dst_memh.local_memh;
110-
111133
for (; *posted < gsize; peer = (peer + 1) % gsize) {
112134
UCPCHECK_GOTO(ucc_tl_ucp_get_nb(PTR_OFFSET(dest, peer * nelems),
113135
PTR_OFFSET(src, grank * nelems),
114-
nelems, mtype, peer, src_memh, dst_memh,
115-
team, task),
136+
nelems, mtype, peer, local_h,
137+
remote_rkeys, team, task),
116138
task, out);
117139

118140
if (!alltoall_onesided_handle_completion(task, posted, completed,
@@ -123,7 +145,10 @@ void ucc_tl_ucp_alltoall_onesided_get_progress(ucc_coll_task_t *ctask)
123145

124146
alltoall_onesided_wait_completion(task, npolls);
125147
out:
126-
return;
148+
if (task->super.status != UCC_INPROGRESS &&
149+
(task->flags & UCC_TL_UCP_TASK_FLAG_USE_DYN_SEG)) {
150+
task->super.status = ucc_tl_ucp_coll_dynamic_segment_finalize(task);
151+
}
127152
}
128153

129154
void ucc_tl_ucp_alltoall_onesided_put_progress(ucc_coll_task_t *ctask)
@@ -143,12 +168,28 @@ void ucc_tl_ucp_alltoall_onesided_put_progress(ucc_coll_task_t *ctask)
143168
ucc_rank_t peer = (grank + *posted + 1) % gsize;
144169
ucc_mem_map_mem_h src_memh;
145170
size_t nelems;
171+
ucc_status_t status;
146172

147173
nelems = TASK_ARGS(task).src.info.count;
148174
nelems = (nelems / gsize) * ucc_dt_size(TASK_ARGS(task).src.info.datatype);
149-
src_memh = (TASK_ARGS(task).flags & UCC_COLL_ARGS_FLAG_SRC_MEMH_GLOBAL)
175+
if (task->flags & UCC_TL_UCP_TASK_FLAG_USE_DYN_SEG) {
176+
status = ucc_tl_ucp_test_dynamic_segment(task);
177+
if (status == UCC_INPROGRESS) {
178+
return;
179+
}
180+
if (UCC_OK != status) {
181+
task->super.status = status;
182+
tl_error(UCC_TL_TEAM_LIB(team),
183+
"failed to exchange dynamic segments");
184+
return;
185+
}
186+
src_memh = task->dynamic_segments.src_local;
187+
dst_memh = (ucc_mem_map_mem_h *)task->dynamic_segments.dst_global;
188+
} else {
189+
src_memh = (TASK_ARGS(task).flags & UCC_COLL_ARGS_FLAG_SRC_MEMH_GLOBAL)
150190
? TASK_ARGS(task).src_memh.global_memh[grank]
151191
: TASK_ARGS(task).src_memh.local_memh;
192+
}
152193

153194
for (; *posted < gsize; peer = (peer + 1) % gsize) {
154195
UCPCHECK_GOTO(
@@ -166,11 +207,12 @@ void ucc_tl_ucp_alltoall_onesided_put_progress(ucc_coll_task_t *ctask)
166207

167208
alltoall_onesided_wait_completion(task, npolls);
168209
out:
169-
return;
210+
if (task->super.status != UCC_INPROGRESS &&
211+
(task->flags & UCC_TL_UCP_TASK_FLAG_USE_DYN_SEG)) {
212+
task->super.status = ucc_tl_ucp_coll_dynamic_segment_finalize(task);
213+
}
170214
}
171215

172-
static ucc_status_t ucc_tl_ucp_alltoall_onesided_start_ops(ucc_tl_ucp_task_t *task);
173-
174216
ucc_status_t ucc_tl_ucp_alltoall_onesided_start(ucc_coll_task_t *ctask)
175217
{
176218
ucc_tl_ucp_task_t *task = ucc_derived_of(ctask, ucc_tl_ucp_task_t);
@@ -180,65 +222,16 @@ ucc_status_t ucc_tl_ucp_alltoall_onesided_start(ucc_coll_task_t *ctask)
180222
ucc_tl_ucp_task_reset(task, UCC_INPROGRESS);
181223
if (task->flags & UCC_TL_UCP_TASK_FLAG_USE_DYN_SEG) {
182224
status = ucc_tl_ucp_coll_dynamic_segment_exchange_nb(task);
183-
if (status == UCC_INPROGRESS) {
184-
return ucc_progress_queue_enqueue(UCC_TL_CORE_CTX(team)->pq, &task->super);
185-
}
186-
if (UCC_OK != status) {
225+
if (UCC_OK != status && UCC_INPROGRESS != status) {
187226
task->super.status = status;
227+
tl_error(UCC_TL_TEAM_LIB(team),
228+
"failed to exchange dynamic segments");
188229
return task->super.status;
189230
}
190231
}
191232

192233
/* Start the onesided operations */
193-
return ucc_tl_ucp_alltoall_onesided_start_ops(task);
194-
}
195-
196-
static ucc_status_t ucc_tl_ucp_alltoall_onesided_start_ops(ucc_tl_ucp_task_t *task)
197-
{
198-
ucc_tl_ucp_team_t *team = TASK_TEAM(task);
199-
ptrdiff_t src = (ptrdiff_t)TASK_ARGS(task).src.info.buffer;
200-
ptrdiff_t dest = (ptrdiff_t)TASK_ARGS(task).dst.info.buffer;
201-
size_t nelems = TASK_ARGS(task).src.info.count;
202-
ucc_rank_t grank = UCC_TL_TEAM_RANK(team);
203-
ucc_rank_t gsize = UCC_TL_TEAM_SIZE(team);
204-
ucc_rank_t start = (grank + 1) % gsize;
205-
long *pSync = TASK_ARGS(task).global_work_buffer;
206-
ucc_mem_map_mem_h src_memh = TASK_ARGS(task).src_memh.local_memh;
207-
ucc_mem_map_mem_h *dst_memh = TASK_ARGS(task).dst_memh.global_memh;
208-
ucc_rank_t peer;
209-
210-
211-
if (task->flags & UCC_TL_UCP_TASK_FLAG_USE_DYN_SEG) {
212-
status = ucc_tl_ucp_coll_dynamic_segment_exchange(task);
213-
if (UCC_OK != status) {
214-
task->super.status = status;
215-
return task->super.status;
216-
}
217-
src_memh = task->dynamic_segments.src_global[grank];
218-
dst_memh = (ucc_mem_map_mem_h *)task->dynamic_segments.dst_global;
219-
} else {
220-
if (TASK_ARGS(task).flags & UCC_COLL_ARGS_FLAG_SRC_MEMH_GLOBAL) {
221-
src_memh = TASK_ARGS(task).src_memh.global_memh[grank];
222-
}
223-
}
224-
225-
/* TODO: change when support for library-based work buffers is complete */
226-
nelems = (nelems / gsize) * ucc_dt_size(TASK_ARGS(task).src.info.datatype);
227-
dest = dest + grank * nelems;
228-
for (peer = start; task->onesided.put_posted < gsize;
229-
peer = (peer + 1) % gsize) {
230-
UCPCHECK_GOTO(ucc_tl_ucp_put_nb((void *)(src + peer * nelems),
231-
(void *)dest, nelems, peer, src_memh,
232-
dst_memh, team, task),
233-
task, out);
234-
UCPCHECK_GOTO(ucc_tl_ucp_atomic_inc(pSync, peer, dst_memh, team), task,
235-
out);
236-
}
237-
238-
/* Operations posted, return UCC_INPROGRESS to let progress function set flag */
239-
return UCC_INPROGRESS;
240-
out:
241-
return task->super.status;
234+
return ucc_progress_queue_enqueue(UCC_TL_CORE_CTX(team)->pq, &task->super);
242235
}
243236

244237
ucc_status_t ucc_tl_ucp_alltoall_onesided_init(ucc_base_coll_args_t *coll_args,
@@ -254,7 +247,7 @@ ucc_status_t ucc_tl_ucp_alltoall_onesided_init(ucc_base_coll_args_t *coll_args,
254247
};
255248
size_t perc_bw =
256249
UCC_TL_UCP_TEAM_LIB(tl_team)->cfg.alltoall_onesided_percent_bw;
257-
ucc_tl_ucp_alltoall_onesided_alg_t alg =
250+
ucc_tl_ucp_onesided_alg_type alg =
258251
UCC_TL_UCP_TEAM_LIB(tl_team)->cfg.alltoall_onesided_alg;
259252
ucc_tl_ucp_schedule_t *tl_schedule = NULL;
260253
ucc_rank_t group_size = 1;
@@ -294,12 +287,6 @@ ucc_status_t ucc_tl_ucp_alltoall_onesided_init(ucc_base_coll_args_t *coll_args,
294287
}
295288
}
296289

297-
status = ucc_tl_ucp_coll_dynamic_segment_init(&coll_args->args, task);
298-
if (UCC_OK != status) {
299-
tl_error(UCC_TL_TEAM_LIB(tl_team),
300-
"failed to initialize dynamic segments");
301-
return status;
302-
}
303290

304291
status = ucc_tl_ucp_get_schedule(tl_team, coll_args,
305292
(ucc_tl_ucp_schedule_t **)&tl_schedule);
@@ -322,12 +309,35 @@ ucc_status_t ucc_tl_ucp_alltoall_onesided_init(ucc_base_coll_args_t *coll_args,
322309
group_size = sbgp->group_size;
323310
}
324311

325-
task = ucc_tl_ucp_init_task(coll_args, team);
312+
task = ucc_tl_ucp_init_task(coll_args, team);
313+
if (ucc_unlikely(!task)) {
314+
status = UCC_ERR_NO_MEMORY;
315+
goto out;
316+
}
326317
task->super.finalize = ucc_tl_ucp_alltoall_onesided_finalize;
327318
a2a_task = &task->super;
328319

320+
/* initialize dynamic segments */
321+
if (alg == UCC_TL_UCP_ALLTOALL_ONESIDED_GET ||
322+
(alg == UCC_TL_UCP_ALLTOALL_ONESIDED_AUTO &&
323+
sbgp->group_size >= CONGESTION_THRESHOLD)) {
324+
alg = UCC_TL_UCP_ALLTOALL_ONESIDED_GET;
325+
} else if (alg == UCC_TL_UCP_ALLTOALL_ONESIDED_AUTO) {
326+
alg = UCC_TL_UCP_ALLTOALL_ONESIDED_PUT;
327+
}
328+
status = ucc_tl_ucp_coll_dynamic_segment_init(&coll_args->args, alg, task);
329+
if (UCC_OK != status) {
330+
if (status != UCC_ERR_NOT_SUPPORTED) {
331+
tl_error(UCC_TL_TEAM_LIB(tl_team),
332+
"failed to initialize dynamic segments");
333+
}
334+
ucc_tl_ucp_coll_finalize(&task->super);
335+
goto out;
336+
}
337+
329338
status = ucc_tl_ucp_coll_init(&barrier_coll_args, team, &barrier_task);
330339
if (status != UCC_OK) {
340+
task->super.finalize(&task->super);
331341
goto out;
332342
}
333343
if (perc_bw > 100) {
@@ -340,23 +350,25 @@ ucc_status_t ucc_tl_ucp_alltoall_onesided_init(ucc_base_coll_args_t *coll_args,
340350
nelems = nelems / UCC_TL_TEAM_SIZE(tl_team);
341351
param.field_mask = UCP_EP_PERF_PARAM_FIELD_MESSAGE_SIZE;
342352
attr.field_mask = UCP_EP_PERF_ATTR_FIELD_ESTIMATED_TIME;
343-
param.message_size = nelems * ucc_dt_size(TASK_ARGS(task).src.info.datatype);;
353+
param.message_size = nelems * ucc_dt_size(TASK_ARGS(task).src.info.datatype);
344354
ucc_tl_ucp_get_ep(
345355
tl_team, (UCC_TL_TEAM_RANK(tl_team) + 1) % UCC_TL_TEAM_SIZE(tl_team),
346356
&ep);
347357
ucp_ep_evaluate_perf(ep, &param, &attr);
348358

349-
rate = (1 / attr.estimated_time) * (double)(perc_bw / 100.0);
350-
ratio = (nelems > 0) ? nelems * group_size : 1;
351-
task->alltoall_onesided.tokens = rate / ratio;
359+
if (attr.estimated_time > 0) {
360+
rate = (1 / attr.estimated_time) * (double)(perc_bw / 100.0);
361+
ratio = (nelems > 0) ? nelems * group_size : 1;
362+
task->alltoall_onesided.tokens = rate / ratio;
363+
} else {
364+
task->alltoall_onesided.tokens = 0;
365+
}
352366
if (task->alltoall_onesided.tokens < 1) {
353367
task->alltoall_onesided.tokens = 1;
354368
}
355369
task->super.post = ucc_tl_ucp_alltoall_onesided_start;
356370
npolls = task->n_polls;
357-
if (alg == UCC_TL_UCP_ALLTOALL_ONESIDED_GET ||
358-
(alg == UCC_TL_UCP_ALLTOALL_ONESIDED_AUTO &&
359-
group_size >= CONGESTION_THRESHOLD)) {
371+
if (alg == UCC_TL_UCP_ALLTOALL_ONESIDED_GET) {
360372
npolls = nelems * ucc_dt_size(TASK_ARGS(task).src.info.datatype);
361373
if (npolls < task->n_polls) {
362374
npolls = task->n_polls;

0 commit comments

Comments
 (0)