Skip to content

Commit 2d9a7b5

Browse files
committed
TL/UCP: add support for onesided dynamic segments
1 parent e447fe7 commit 2d9a7b5

13 files changed

Lines changed: 681 additions & 134 deletions

File tree

src/components/tl/ucp/alltoall/alltoall.c

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -72,29 +72,22 @@ ucc_status_t ucc_tl_ucp_alltoall_pairwise_init(ucc_base_coll_args_t *coll_args,
7272
}
7373

7474
ucc_status_t ucc_tl_ucp_alltoall_onesided_init(ucc_base_coll_args_t *coll_args,
75-
ucc_base_team_t *team,
76-
ucc_coll_task_t **task_h)
75+
ucc_base_team_t *team,
76+
ucc_coll_task_t **task_h)
7777
{
78-
ucc_tl_ucp_team_t *tl_team = ucc_derived_of(team, ucc_tl_ucp_team_t);
79-
ucc_tl_ucp_task_t *task;
80-
ucc_status_t status;
78+
ucc_tl_ucp_team_t *tl_team = ucc_derived_of(team, ucc_tl_ucp_team_t);
79+
ucc_status_t status = UCC_OK;
80+
ucc_tl_ucp_task_t *task;
8181

8282
ALLTOALL_TASK_CHECK(coll_args->args, tl_team);
8383

84+
/* memory handles do not cover the work buffer, so verify here that a global work buffer was provided */
8485
if (!(coll_args->args.mask & UCC_COLL_ARGS_FIELD_GLOBAL_WORK_BUFFER)) {
8586
tl_error(UCC_TL_TEAM_LIB(tl_team),
8687
"global work buffer not provided nor associated with team");
8788
status = UCC_ERR_NOT_SUPPORTED;
8889
goto out;
8990
}
90-
if (coll_args->args.mask & UCC_COLL_ARGS_FIELD_FLAGS) {
91-
if (!(coll_args->args.flags & UCC_COLL_ARGS_FLAG_MEM_MAPPED_BUFFERS)) {
92-
tl_error(UCC_TL_TEAM_LIB(tl_team),
93-
"non memory mapped buffers are not supported");
94-
status = UCC_ERR_NOT_SUPPORTED;
95-
goto out;
96-
}
97-
}
9891
if (!(coll_args->args.mask & UCC_COLL_ARGS_FIELD_MEM_MAP_SRC_MEMH)) {
9992
coll_args->args.src_memh.global_memh = NULL;
10093
}
@@ -103,7 +96,8 @@ ucc_status_t ucc_tl_ucp_alltoall_onesided_init(ucc_base_coll_args_t *coll_args,
10396
} else {
10497
if (!(coll_args->args.flags & UCC_COLL_ARGS_FLAG_DST_MEMH_GLOBAL)) {
10598
tl_error(UCC_TL_TEAM_LIB(tl_team),
106-
"onesided alltoall requires global memory handles for dst buffers");
99+
"onesided alltoall requires global memory handles for dst "
100+
"buffers");
107101
status = UCC_ERR_INVALID_PARAM;
108102
goto out;
109103
}
@@ -113,7 +107,12 @@ ucc_status_t ucc_tl_ucp_alltoall_onesided_init(ucc_base_coll_args_t *coll_args,
113107
*task_h = &task->super;
114108
task->super.post = ucc_tl_ucp_alltoall_onesided_start;
115109
task->super.progress = ucc_tl_ucp_alltoall_onesided_progress;
116-
status = UCC_OK;
110+
111+
status = ucc_tl_ucp_coll_dynamic_segment_init(&coll_args->args, task);
112+
if (UCC_OK != status) {
113+
tl_error(UCC_TL_TEAM_LIB(tl_team),
114+
"failed to initialize dynamic segments");
115+
}
117116
out:
118117
return status;
119118
}

src/components/tl/ucp/alltoall/alltoall_onesided.c

Lines changed: 37 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -9,40 +9,53 @@
99
#include "alltoall.h"
1010
#include "core/ucc_progress_queue.h"
1111
#include "utils/ucc_math.h"
12+
#include "tl_ucp_coll.h"
1213
#include "tl_ucp_sendrecv.h"
1314

1415
void ucc_tl_ucp_alltoall_onesided_progress(ucc_coll_task_t *ctask);
1516

1617
ucc_status_t ucc_tl_ucp_alltoall_onesided_start(ucc_coll_task_t *ctask)
1718
{
18-
ucc_tl_ucp_task_t *task = ucc_derived_of(ctask, ucc_tl_ucp_task_t);
19-
ucc_tl_ucp_team_t *team = TASK_TEAM(task);
20-
ptrdiff_t src = (ptrdiff_t)TASK_ARGS(task).src.info.buffer;
21-
ptrdiff_t dest = (ptrdiff_t)TASK_ARGS(task).dst.info.buffer;
22-
size_t nelems = TASK_ARGS(task).src.info.count;
23-
ucc_rank_t grank = UCC_TL_TEAM_RANK(team);
24-
ucc_rank_t gsize = UCC_TL_TEAM_SIZE(team);
25-
ucc_rank_t start = (grank + 1) % gsize;
26-
long *pSync = TASK_ARGS(task).global_work_buffer;
27-
ucc_mem_map_mem_h src_memh = TASK_ARGS(task).src_memh.local_memh;
28-
ucc_mem_map_mem_h *dst_memh = TASK_ARGS(task).dst_memh.global_memh;
29-
ucc_rank_t peer;
19+
ucc_tl_ucp_task_t *task = ucc_derived_of(ctask, ucc_tl_ucp_task_t);
20+
ucc_tl_ucp_team_t *team = TASK_TEAM(task);
21+
ptrdiff_t src = (ptrdiff_t)TASK_ARGS(task).src.info.buffer;
22+
ptrdiff_t dest = (ptrdiff_t)TASK_ARGS(task).dst.info.buffer;
23+
size_t nelems = TASK_ARGS(task).src.info.count;
24+
ucc_rank_t grank = UCC_TL_TEAM_RANK(team);
25+
ucc_rank_t gsize = UCC_TL_TEAM_SIZE(team);
26+
ucc_rank_t start = (grank + 1) % gsize;
27+
long *pSync = TASK_ARGS(task).global_work_buffer;
28+
ucc_mem_map_mem_h src_memh = TASK_ARGS(task).src_memh.local_memh;
29+
ucc_mem_map_mem_h *dst_memh_g = TASK_ARGS(task).dst_memh.global_memh;
30+
ucc_rank_t peer;
31+
ucc_status_t status;
3032

31-
if (TASK_ARGS(task).flags & UCC_COLL_ARGS_FLAG_SRC_MEMH_GLOBAL) {
32-
src_memh = TASK_ARGS(task).src_memh.global_memh[grank];
33+
ucc_tl_ucp_task_reset(task, UCC_INPROGRESS);
34+
if (task->flags & UCC_TL_UCP_TASK_FLAG_USE_DYN_SEG) {
35+
status = ucc_tl_ucp_coll_dynamic_segment_exchange(task);
36+
if (UCC_OK != status) {
37+
task->super.status = status;
38+
goto out;
39+
}
40+
src_memh = task->dynamic_segments.src_global[grank];
41+
dst_memh_g = (ucc_mem_map_mem_h *)task->dynamic_segments.dst_global;
42+
} else {
43+
if (TASK_ARGS(task).flags & UCC_COLL_ARGS_FLAG_SRC_MEMH_GLOBAL) {
44+
src_memh = TASK_ARGS(task).src_memh.global_memh[grank];
45+
}
3346
}
3447

35-
ucc_tl_ucp_task_reset(task, UCC_INPROGRESS);
3648
/* TODO: change when support for library-based work buffers is complete */
3749
nelems = (nelems / gsize) * ucc_dt_size(TASK_ARGS(task).src.info.datatype);
3850
dest = dest + grank * nelems;
39-
for (peer = start; task->onesided.put_posted < gsize; peer = (peer + 1) % gsize) {
40-
UCPCHECK_GOTO(ucc_tl_ucp_put_nb(
41-
(void *)(src + peer * nelems), (void *)dest, nelems,
42-
peer, src_memh, dst_memh, team, task),
43-
task, out);
44-
UCPCHECK_GOTO(ucc_tl_ucp_atomic_inc(pSync, peer, dst_memh, team),
45-
task, out);
51+
for (peer = start; task->onesided.put_posted < gsize;
52+
peer = (peer + 1) % gsize) {
53+
UCPCHECK_GOTO(ucc_tl_ucp_put_nb((void *)(src + peer * nelems),
54+
(void *)dest, nelems, peer, src_memh,
55+
dst_memh_g, team, task),
56+
task, out);
57+
UCPCHECK_GOTO(ucc_tl_ucp_atomic_inc(pSync, peer, dst_memh_g, team), task,
58+
out);
4659
}
4760
return ucc_progress_queue_enqueue(UCC_TL_CORE_CTX(team)->pq, &task->super);
4861
out:
@@ -54,12 +67,12 @@ void ucc_tl_ucp_alltoall_onesided_progress(ucc_coll_task_t *ctask)
5467
ucc_tl_ucp_task_t *task = ucc_derived_of(ctask, ucc_tl_ucp_task_t);
5568
ucc_tl_ucp_team_t *team = TASK_TEAM(task);
5669
ucc_rank_t gsize = UCC_TL_TEAM_SIZE(team);
57-
long * pSync = TASK_ARGS(task).global_work_buffer;
70+
long *pSync = TASK_ARGS(task).global_work_buffer;
5871

5972
if (ucc_tl_ucp_test_onesided(task, gsize) == UCC_INPROGRESS) {
6073
return;
6174
}
6275

6376
pSync[0] = 0;
64-
task->super.status = UCC_OK;
77+
task->super.status = ucc_tl_ucp_coll_dynamic_segment_finalize(task);
6578
}

src/components/tl/ucp/alltoallv/alltoallv_onesided.c

Lines changed: 45 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -13,23 +13,35 @@
1313

1414
ucc_status_t ucc_tl_ucp_alltoallv_onesided_start(ucc_coll_task_t *ctask)
1515
{
16-
ucc_tl_ucp_task_t *task = ucc_derived_of(ctask, ucc_tl_ucp_task_t);
17-
ucc_tl_ucp_team_t *team = TASK_TEAM(task);
18-
ptrdiff_t src = (ptrdiff_t)TASK_ARGS(task).src.info_v.buffer;
19-
ptrdiff_t dest = (ptrdiff_t)TASK_ARGS(task).dst.info_v.buffer;
20-
ucc_rank_t grank = UCC_TL_TEAM_RANK(team);
21-
ucc_rank_t gsize = UCC_TL_TEAM_SIZE(team);
22-
long *pSync = TASK_ARGS(task).global_work_buffer;
23-
ucc_aint_t *s_disp = TASK_ARGS(task).src.info_v.displacements;
24-
ucc_aint_t *d_disp = TASK_ARGS(task).dst.info_v.displacements;
25-
size_t sdt_size = ucc_dt_size(TASK_ARGS(task).src.info_v.datatype);
26-
size_t rdt_size = ucc_dt_size(TASK_ARGS(task).dst.info_v.datatype);
27-
ucc_mem_map_mem_h src_memh = TASK_ARGS(task).src_memh.local_memh;
28-
ucc_mem_map_mem_h *dst_memh = TASK_ARGS(task).dst_memh.global_memh;
29-
ucc_rank_t peer;
30-
size_t sd_disp, dd_disp, data_size;
16+
ucc_tl_ucp_task_t *task = ucc_derived_of(ctask, ucc_tl_ucp_task_t);
17+
ucc_tl_ucp_team_t *team = TASK_TEAM(task);
18+
ptrdiff_t src = (ptrdiff_t)TASK_ARGS(task).src.info_v.buffer;
19+
ptrdiff_t dest = (ptrdiff_t)TASK_ARGS(task).dst.info_v.buffer;
20+
ucc_rank_t grank = UCC_TL_TEAM_RANK(team);
21+
ucc_rank_t gsize = UCC_TL_TEAM_SIZE(team);
22+
long *pSync = TASK_ARGS(task).global_work_buffer;
23+
ucc_aint_t *s_disp = TASK_ARGS(task).src.info_v.displacements;
24+
ucc_aint_t *d_disp = TASK_ARGS(task).dst.info_v.displacements;
25+
size_t sdt_size = ucc_dt_size(TASK_ARGS(task).src.info_v.datatype);
26+
size_t rdt_size = ucc_dt_size(TASK_ARGS(task).dst.info_v.datatype);
27+
ucc_mem_map_mem_h src_memh = TASK_ARGS(task).src_memh.local_memh;
28+
ucc_mem_map_mem_h *dst_memh = TASK_ARGS(task).dst_memh.global_memh;
29+
//ucc_mem_map_memh_t *src_memh_g = NULL;
30+
//ucc_mem_map_memh_t *dst_memh_g = NULL;
31+
ucc_rank_t peer;
32+
ucc_status_t status;
33+
size_t sd_disp, dd_disp, data_size;
3134

3235
ucc_tl_ucp_task_reset(task, UCC_INPROGRESS);
36+
status = ucc_tl_ucp_coll_dynamic_segment_exchange(task);
37+
if (UCC_OK != status) {
38+
task->super.status = status;
39+
goto out;
40+
}
41+
42+
if (TASK_ARGS(task).flags & UCC_COLL_ARGS_FLAG_SRC_MEMH_GLOBAL) {
43+
src_memh = TASK_ARGS(task).src_memh.global_memh[grank];
44+
}
3345

3446
/* perform a put to each member peer using the peer's index in the
3547
* destination displacement. */
@@ -42,18 +54,16 @@ ucc_status_t ucc_tl_ucp_alltoallv_onesided_start(ucc_coll_task_t *ctask)
4254
ucc_coll_args_get_displacement(&TASK_ARGS(task), d_disp, peer) *
4355
rdt_size;
4456
data_size =
45-
ucc_coll_args_get_count(
46-
&TASK_ARGS(task), TASK_ARGS(task).src.info_v.counts, peer) *
57+
ucc_coll_args_get_count(&TASK_ARGS(task),
58+
TASK_ARGS(task).src.info_v.counts, peer) *
4759
sdt_size;
4860

4961
UCPCHECK_GOTO(ucc_tl_ucp_put_nb(PTR_OFFSET(src, sd_disp),
50-
PTR_OFFSET(dest, dd_disp),
51-
data_size, peer, src_memh,
52-
dst_memh, team, task),
53-
task, out);
54-
UCPCHECK_GOTO(ucc_tl_ucp_atomic_inc(pSync, peer,
55-
dst_memh, team),
62+
PTR_OFFSET(dest, dd_disp), data_size,
63+
peer, src_memh, dst_memh, team, task),
5664
task, out);
65+
UCPCHECK_GOTO(ucc_tl_ucp_atomic_inc(pSync, peer, dst_memh, team), task,
66+
out);
5767
}
5868
return ucc_progress_queue_enqueue(UCC_TL_CORE_CTX(team)->pq, &task->super);
5969
out:
@@ -73,15 +83,16 @@ void ucc_tl_ucp_alltoallv_onesided_progress(ucc_coll_task_t *ctask)
7383

7484
pSync[0] = 0;
7585
task->super.status = UCC_OK;
86+
ucc_tl_ucp_coll_dynamic_segment_finalize(task);
7687
}
7788

7889
ucc_status_t ucc_tl_ucp_alltoallv_onesided_init(ucc_base_coll_args_t *coll_args,
7990
ucc_base_team_t *team,
8091
ucc_coll_task_t **task_h)
8192
{
82-
ucc_tl_ucp_team_t *tl_team = ucc_derived_of(team, ucc_tl_ucp_team_t);
83-
ucc_tl_ucp_task_t *task;
84-
ucc_status_t status;
93+
ucc_tl_ucp_team_t *tl_team = ucc_derived_of(team, ucc_tl_ucp_team_t);
94+
ucc_status_t status = UCC_OK;
95+
ucc_tl_ucp_task_t *task;
8596

8697
ALLTOALLV_TASK_CHECK(coll_args->args, tl_team);
8798
if (!(coll_args->args.mask & UCC_COLL_ARGS_FIELD_GLOBAL_WORK_BUFFER)) {
@@ -90,14 +101,6 @@ ucc_status_t ucc_tl_ucp_alltoallv_onesided_init(ucc_base_coll_args_t *coll_args,
90101
status = UCC_ERR_NOT_SUPPORTED;
91102
goto out;
92103
}
93-
if (coll_args->args.mask & UCC_COLL_ARGS_FIELD_FLAGS) {
94-
if (!(coll_args->args.flags & UCC_COLL_ARGS_FLAG_MEM_MAPPED_BUFFERS)) {
95-
tl_error(UCC_TL_TEAM_LIB(tl_team),
96-
"non memory mapped buffers are not supported");
97-
status = UCC_ERR_NOT_SUPPORTED;
98-
goto out;
99-
}
100-
}
101104
if (!(coll_args->args.mask & UCC_COLL_ARGS_FIELD_MEM_MAP_SRC_MEMH)) {
102105
coll_args->args.src_memh.global_memh = NULL;
103106
}
@@ -109,7 +112,14 @@ ucc_status_t ucc_tl_ucp_alltoallv_onesided_init(ucc_base_coll_args_t *coll_args,
109112
*task_h = &task->super;
110113
task->super.post = ucc_tl_ucp_alltoallv_onesided_start;
111114
task->super.progress = ucc_tl_ucp_alltoallv_onesided_progress;
112-
status = UCC_OK;
115+
116+
status = ucc_tl_ucp_coll_dynamic_segment_init(
117+
&coll_args->args, task);
118+
if (UCC_OK != status) {
119+
tl_error(UCC_TL_TEAM_LIB(tl_team),
120+
"failed to initialize dynamic segments");
121+
goto out;
122+
}
113123
out:
114124
return status;
115125
}

src/components/tl/ucp/tl_ucp.c

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -291,10 +291,10 @@ static ucs_config_field_t ucc_tl_ucp_context_config_table[] = {
291291
ucc_offsetof(ucc_tl_ucp_context_config_t, memtype_copy_enable),
292292
UCC_CONFIG_TYPE_BOOL},
293293

294-
{"EXPORTED_MEMORY_HANDLE", "n",
295-
"If set to yes, initialize UCP context with the exported memory handle "
296-
"feature, which is useful for offload devices such as a DPU. Otherwise "
297-
"disable the use of this feature.",
294+
{"EXPORTED_MEMORY_HANDLE", "0",
295+
"If set to 1, initialize UCP context with the exported memory handle "
296+
"feature, which is useful for offload devices such as a DPU. Set to 0 "
297+
"to disable this feature (default is 0).",
298298
ucc_offsetof(ucc_tl_ucp_context_config_t, exported_memory_handle),
299299
UCC_CONFIG_TYPE_BOOL},
300300

src/components/tl/ucp/tl_ucp.h

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -157,8 +157,6 @@ typedef struct ucc_tl_ucp_team {
157157
ucc_status_t status;
158158
uint32_t seq_num;
159159
ucc_tl_ucp_task_t *preconnect_task;
160-
void * va_base[MAX_NR_SEGMENTS];
161-
size_t base_length[MAX_NR_SEGMENTS];
162160
ucc_tl_ucp_worker_t * worker;
163161
ucc_tl_ucp_team_config_t cfg;
164162
const char * tuning_str;
@@ -296,4 +294,19 @@ void ucc_tl_ucp_pre_register_mem(ucc_tl_ucp_team_t *team, void *addr,
296294
ucc_status_t ucc_tl_ucp_ctx_remote_populate(ucc_tl_ucp_context_t *ctx,
297295
ucc_mem_map_params_t map,
298296
ucc_team_oob_coll_t oob);
297+
298+
ucc_status_t ucc_tl_ucp_mem_map(const ucc_base_context_t *context,
299+
ucc_mem_map_mode_t mode,
300+
ucc_mem_map_memh_t *memh,
301+
ucc_mem_map_tl_t *tl_h);
302+
303+
ucc_status_t ucc_tl_ucp_memh_pack(const ucc_base_context_t *context,
304+
ucc_mem_map_mode_t mode,
305+
ucc_mem_map_tl_t *tl_h,
306+
void **pack_buffer);
307+
308+
ucc_status_t ucc_tl_ucp_mem_unmap(const ucc_base_context_t *context,
309+
ucc_mem_map_mode_t mode,
310+
ucc_mem_map_tl_t *memh);
311+
299312
#endif

0 commit comments

Comments
 (0)