Skip to content

Commit 75eed01

Browse files
authored
Merge pull request #182 from PyLops/cuy-sync
Add CuPy synchronization
2 parents 8b19ed3 + bf0c427 commit 75eed01

1 file changed

Lines changed: 40 additions & 12 deletions

File tree

pylops_mpi/utils/_mpi.py

Lines changed: 40 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -47,13 +47,13 @@ def mpi_allgather(base_comm: MPI.Comm,
4747
send_shapes = base_comm.allgather(send_buf.shape)
4848
(padded_send, padded_recv) = _prepare_allgather_inputs(send_buf, send_shapes, engine=engine)
4949
recv_buffer_to_use = recv_buf if recv_buf else padded_recv
50-
base_comm.Allgather(padded_send, recv_buffer_to_use)
50+
_mpi_calls(base_comm, "Allgather", padded_send, recv_buffer_to_use, engine=engine)
5151
return _unroll_allgather_recv(recv_buffer_to_use, padded_send.shape, send_shapes)
5252
else:
5353
# CuPy with non-CUDA-aware MPI
5454
if recv_buf is None:
55-
return base_comm.allgather(send_buf)
56-
base_comm.Allgather(send_buf, recv_buf)
55+
return _mpi_calls(base_comm, "allgather", send_buf)
56+
_mpi_calls(base_comm, "Allgather", send_buf, recv_buf)
5757
return recv_buf
5858

5959

@@ -92,14 +92,14 @@ def mpi_allreduce(base_comm: MPI.Comm,
9292
if deps.cuda_aware_mpi_enabled or engine == "numpy":
9393
ncp = get_module(engine)
9494
recv_buf = ncp.zeros(send_buf.size, dtype=send_buf.dtype)
95-
base_comm.Allreduce(send_buf, recv_buf, op)
95+
_mpi_calls(base_comm, "Allreduce", send_buf, recv_buf, op, engine=engine)
9696
return recv_buf
9797
else:
9898
# CuPy with non-CUDA-aware MPI
9999
if recv_buf is None:
100-
return base_comm.allreduce(send_buf, op)
100+
return _mpi_calls(base_comm, "allreduce", send_buf, op, engine=engine)
101101
# For MIN and MAX which require recv_buf
102-
base_comm.Allreduce(send_buf, recv_buf, op)
102+
_mpi_calls(base_comm, "Allreduce", send_buf, recv_buf, op, engine=engine)
103103
return recv_buf
104104

105105

@@ -131,11 +131,11 @@ def mpi_bcast(base_comm: MPI.Comm,
131131
132132
"""
133133
if deps.cuda_aware_mpi_enabled or engine == "numpy":
134-
base_comm.Bcast(send_buf, root=root)
134+
_mpi_calls(base_comm, "Bcast", send_buf, engine=engine, root=root)
135135
return send_buf
136136
# CuPy with non-CUDA-aware MPI: use object broadcast
137137
value = send_buf if base_comm.Get_rank() == root else None
138-
return base_comm.bcast(value, root=root)
138+
return _mpi_calls(base_comm, "bcast", value, engine=engine, root=root)
139139

140140

141141
def mpi_send(base_comm: MPI.Comm,
@@ -171,10 +171,10 @@ def mpi_send(base_comm: MPI.Comm,
171171
mpi_type = MPI._typedict[send_buf.dtype.char]
172172
if count is None:
173173
count = send_buf.size
174-
base_comm.Send([send_buf, count, mpi_type], dest=dest, tag=tag)
174+
_mpi_calls(base_comm, "Send", [send_buf, count, mpi_type], engine=engine, dest=dest, tag=tag)
175175
else:
176176
# Uses CuPy without CUDA-aware MPI
177-
base_comm.send(send_buf, dest, tag)
177+
_mpi_calls(base_comm, "send", send_buf, dest, tag, engine=engine)
178178

179179

180180
def mpi_recv(base_comm: MPI.Comm,
@@ -219,8 +219,36 @@ def mpi_recv(base_comm: MPI.Comm,
219219
# dimension or shape-related integers are send/recv
220220
recv_buf = ncp.zeros(count, dtype=ncp.int32)
221221
mpi_type = MPI._typedict[recv_buf.dtype.char]
222-
base_comm.Recv([recv_buf, recv_buf.size, mpi_type], source=source, tag=tag)
222+
_mpi_calls(base_comm, "Recv", [recv_buf, recv_buf.size, mpi_type], engine=engine, source=source, tag=tag)
223223
else:
224224
# Uses CuPy without CUDA-aware MPI
225-
recv_buf = base_comm.recv(source=source, tag=tag)
225+
recv_buf = _mpi_calls(base_comm, "recv", engine=engine, source=source, tag=tag)
226226
return recv_buf
227+
228+
229+
def _mpi_calls(comm: MPI.Comm, func: str, *args, engine: Optional[str] = "numpy", **kwargs):
    """MPI Calls

    Dispatch an MPI communicator method by name, synchronizing the current
    GPU device first when CuPy buffers are handed to a CUDA-aware MPI.

    Parameters
    ----------
    comm: :obj:`MPI.Comm`
        MPI Communicator
    func
        MPI Function to call.
    args
        Arguments to pass to the function.
    engine: :obj:`str`, optional
        Engine used to store array (``numpy`` or ``cupy``)
    kwargs
        Keyword arguments passed to the MPI call.

    Returns
    -------
    Result of the MPI call
    """
    # With CUDA-aware MPI, the device must be idle before communication
    # starts, otherwise MPI may read buffers a kernel is still writing.
    sync_needed = engine == "cupy" and deps.cuda_aware_mpi_enabled
    if sync_needed:
        get_module(engine).cuda.Device().synchronize()
    return getattr(comm, func)(*args, **kwargs)

0 commit comments

Comments
 (0)