@@ -47,13 +47,13 @@ def mpi_allgather(base_comm: MPI.Comm,
4747 send_shapes = base_comm .allgather (send_buf .shape )
4848 (padded_send , padded_recv ) = _prepare_allgather_inputs (send_buf , send_shapes , engine = engine )
4949 recv_buffer_to_use = recv_buf if recv_buf else padded_recv
50- base_comm . Allgather ( padded_send , recv_buffer_to_use )
50+ _mpi_calls ( base_comm , " Allgather" , padded_send , recv_buffer_to_use , engine = engine )
5151 return _unroll_allgather_recv (recv_buffer_to_use , padded_send .shape , send_shapes )
5252 else :
5353 # CuPy with non-CUDA-aware MPI
5454 if recv_buf is None :
55- return base_comm . allgather ( send_buf )
56- base_comm . Allgather ( send_buf , recv_buf )
55+ return _mpi_calls ( base_comm , " allgather" , send_buf )
56+ _mpi_calls ( base_comm , " Allgather" , send_buf , recv_buf )
5757 return recv_buf
5858
5959
@@ -92,14 +92,14 @@ def mpi_allreduce(base_comm: MPI.Comm,
9292 if deps .cuda_aware_mpi_enabled or engine == "numpy" :
9393 ncp = get_module (engine )
9494 recv_buf = ncp .zeros (send_buf .size , dtype = send_buf .dtype )
95- base_comm . Allreduce ( send_buf , recv_buf , op )
95+ _mpi_calls ( base_comm , " Allreduce" , send_buf , recv_buf , op , engine = engine )
9696 return recv_buf
9797 else :
9898 # CuPy with non-CUDA-aware MPI
9999 if recv_buf is None :
100- return base_comm . allreduce ( send_buf , op )
100+ return _mpi_calls ( base_comm , " allreduce" , send_buf , op , engine = engine )
101101 # For MIN and MAX which require recv_buf
102- base_comm . Allreduce ( send_buf , recv_buf , op )
102+ _mpi_calls ( base_comm , " Allreduce" , send_buf , recv_buf , op , engine = engine )
103103 return recv_buf
104104
105105
@@ -131,11 +131,11 @@ def mpi_bcast(base_comm: MPI.Comm,
131131
132132 """
133133 if deps .cuda_aware_mpi_enabled or engine == "numpy" :
134- base_comm . Bcast ( send_buf , root = root )
134+ _mpi_calls ( base_comm , " Bcast" , send_buf , engine = engine , root = root )
135135 return send_buf
136136 # CuPy with non-CUDA-aware MPI: use object broadcast
137137 value = send_buf if base_comm .Get_rank () == root else None
138- return base_comm . bcast ( value , root = root )
138+ return _mpi_calls ( base_comm , " bcast" , value , engine = engine , root = root )
139139
140140
141141def mpi_send (base_comm : MPI .Comm ,
@@ -171,10 +171,10 @@ def mpi_send(base_comm: MPI.Comm,
171171 mpi_type = MPI ._typedict [send_buf .dtype .char ]
172172 if count is None :
173173 count = send_buf .size
174- base_comm . Send ( [send_buf , count , mpi_type ], dest = dest , tag = tag )
174+ _mpi_calls ( base_comm , " Send" , [send_buf , count , mpi_type ], engine = engine , dest = dest , tag = tag )
175175 else :
176176 # Uses CuPy without CUDA-aware MPI
177- base_comm . send ( send_buf , dest , tag )
177+ _mpi_calls ( base_comm , " send" , send_buf , dest , tag , engine = engine )
178178
179179
180180def mpi_recv (base_comm : MPI .Comm ,
@@ -219,8 +219,36 @@ def mpi_recv(base_comm: MPI.Comm,
219219 # dimension or shape-related integers are send/recv
220220 recv_buf = ncp .zeros (count , dtype = ncp .int32 )
221221 mpi_type = MPI ._typedict [recv_buf .dtype .char ]
222- base_comm . Recv ( [recv_buf , recv_buf .size , mpi_type ], source = source , tag = tag )
222+ _mpi_calls ( base_comm , " Recv" , [recv_buf , recv_buf .size , mpi_type ], engine = engine , source = source , tag = tag )
223223 else :
224224 # Uses CuPy without CUDA-aware MPI
225- recv_buf = base_comm . recv ( source = source , tag = tag )
225+ recv_buf = _mpi_calls ( base_comm , " recv" , engine = engine , source = source , tag = tag )
226226 return recv_buf
227+
228+
def _mpi_calls(comm: MPI.Comm, func: str, *args, engine: str = "numpy", **kwargs):
    """Dispatch an MPI communicator method by name, syncing the GPU first when needed.

    Thin wrapper around ``comm.<func>(*args, **kwargs)``. When the arrays live
    on a CuPy device and CUDA-aware MPI is enabled, the current CUDA device is
    synchronized before the call so that any kernels still writing the send
    buffer have completed before MPI reads the device pointer.

    Parameters
    ----------
    comm : :obj:`MPI.Comm`
        MPI Communicator.
    func : :obj:`str`
        Name of the MPI method to call (e.g. ``"Allreduce"`` or ``"bcast"``).
    args
        Positional arguments forwarded to the MPI call.
    engine : :obj:`str`, optional
        Engine used to store the array (``"numpy"`` or ``"cupy"``).
    kwargs
        Keyword arguments forwarded to the MPI call.

    Returns
    -------
    Result of the underlying MPI call.
    """
    if engine == "cupy" and deps.cuda_aware_mpi_enabled:
        # CUDA-aware MPI reads device memory directly; make sure pending
        # kernels that produced the buffers have finished first.
        ncp = get_module(engine)
        ncp.cuda.Device().synchronize()
    mpi_func = getattr(comm, func)
    return mpi_func(*args, **kwargs)
0 commit comments