@@ -47,13 +47,13 @@ def mpi_allgather(base_comm: MPI.Comm,
         send_shapes = base_comm.allgather(send_buf.shape)
         (padded_send, padded_recv) = _prepare_allgather_inputs(send_buf, send_shapes, engine=engine)
         recv_buffer_to_use = recv_buf if recv_buf else padded_recv
-        _mpi_calls(base_comm.Allgather, padded_send, recv_buffer_to_use, engine=engine)
+        _mpi_calls(base_comm, "Allgather", padded_send, recv_buffer_to_use, engine=engine)
         return _unroll_allgather_recv(recv_buffer_to_use, padded_send.shape, send_shapes)
     else:
         # CuPy with non-CUDA-aware MPI
         if recv_buf is None:
-            return _mpi_calls(base_comm.allgather, send_buf)
-        _mpi_calls(base_comm.Allgather, send_buf, recv_buf)
+            return _mpi_calls(base_comm, "allgather", send_buf)
+        _mpi_calls(base_comm, "Allgather", send_buf, recv_buf)
         return recv_buf


@@ -92,14 +92,14 @@ def mpi_allreduce(base_comm: MPI.Comm,
     if deps.cuda_aware_mpi_enabled or engine == "numpy":
         ncp = get_module(engine)
         recv_buf = ncp.zeros(send_buf.size, dtype=send_buf.dtype)
-        _mpi_calls(base_comm.Allreduce, send_buf, recv_buf, op, engine=engine)
+        _mpi_calls(base_comm, "Allreduce", send_buf, recv_buf, op, engine=engine)
         return recv_buf
     else:
         # CuPy with non-CUDA-aware MPI
         if recv_buf is None:
-            return _mpi_calls(base_comm.allreduce, send_buf, op, engine=engine)
+            return _mpi_calls(base_comm, "allreduce", send_buf, op, engine=engine)
         # For MIN and MAX which require recv_buf
-        _mpi_calls(base_comm.Allreduce, send_buf, recv_buf, op, engine=engine)
+        _mpi_calls(base_comm, "Allreduce", send_buf, recv_buf, op, engine=engine)
         return recv_buf


@@ -131,11 +131,11 @@ def mpi_bcast(base_comm: MPI.Comm,

     """
     if deps.cuda_aware_mpi_enabled or engine == "numpy":
-        _mpi_calls(base_comm.Bcast, send_buf, engine=engine, root=root)
+        _mpi_calls(base_comm, "Bcast", send_buf, engine=engine, root=root)
         return send_buf
     # CuPy with non-CUDA-aware MPI: use object broadcast
     value = send_buf if base_comm.Get_rank() == root else None
-    return _mpi_calls(base_comm.bcast, value, engine=engine, root=root)
+    return _mpi_calls(base_comm, "bcast", value, engine=engine, root=root)


 def mpi_send(base_comm: MPI.Comm,
@@ -171,10 +171,10 @@ def mpi_send(base_comm: MPI.Comm,
         mpi_type = MPI._typedict[send_buf.dtype.char]
         if count is None:
             count = send_buf.size
-        _mpi_calls(base_comm.Send, [send_buf, count, mpi_type], engine=engine, dest=dest, tag=tag)
+        _mpi_calls(base_comm, "Send", [send_buf, count, mpi_type], engine=engine, dest=dest, tag=tag)
     else:
         # Uses CuPy without CUDA-aware MPI
-        _mpi_calls(base_comm.send, send_buf, dest, tag, engine=engine)
+        _mpi_calls(base_comm, "send", send_buf, dest, tag, engine=engine)


 def mpi_recv(base_comm: MPI.Comm,
@@ -219,23 +219,25 @@ def mpi_recv(base_comm: MPI.Comm,
         # dimension or shape-related integers are send/recv
         recv_buf = ncp.zeros(count, dtype=ncp.int32)
         mpi_type = MPI._typedict[recv_buf.dtype.char]
-        _mpi_calls(base_comm.Recv, [recv_buf, recv_buf.size, mpi_type], engine=engine, source=source, tag=tag)
+        _mpi_calls(base_comm, "Recv", [recv_buf, recv_buf.size, mpi_type], engine=engine, source=source, tag=tag)
     else:
         # Uses CuPy without CUDA-aware MPI
-        recv_buf = _mpi_calls(base_comm.recv, engine=engine, source=source, tag=tag)
+        recv_buf = _mpi_calls(base_comm, "recv", engine=engine, source=source, tag=tag)
     return recv_buf


-def _mpi_calls(call, *args, engine: Optional[str] = "numpy", **kwargs):
+def _mpi_calls(comm: MPI.Comm, func: str, *args, engine: Optional[str] = "numpy", **kwargs):
     """MPI Calls
     Wrapper around MPI comm calls with optional GPU synchronization for CuPy arrays.

     Parameters
     ----------
-    call: :obj:`MPI.Comm`
-        MPI Communicator function
+    comm: :obj:`MPI.Comm`
+        MPI Communicator
+    func: :obj:`str`
+        Name of the MPI function to call.
     args
-        Arguments passed to the MPI call.
+        Arguments to pass to the function.
     engine: :obj:`str`, optional
         Engine used to store array (``numpy`` or ``cupy``)
     kwargs
@@ -246,6 +248,7 @@ def _mpi_calls(call, *args, engine: Optional[str] = "numpy", **kwargs):
         Result of the MPI call
     """
     if engine == "cupy":
-        cp = get_module(engine)
-        cp.cuda.runtime.deviceSynchronize()
-    return call(*args, **kwargs)
+        ncp = get_module(engine)
+        ncp.cuda.Device().synchronize()
+    mpi_func = getattr(comm, func)
+    return mpi_func(*args, **kwargs)
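
For reference, the refactor replaces the bound-method argument with the communicator plus the method name, which `_mpi_calls` resolves via `getattr` before invoking. A minimal NumPy-only sketch of that calling convention (assuming mpi4py is installed; the CuPy synchronization branch is omitted, and the script name below is just an example):

    import numpy as np
    from mpi4py import MPI

    def _mpi_calls(comm, func, *args, engine="numpy", **kwargs):
        # Resolve the MPI method by name on the communicator and invoke it;
        # the real wrapper synchronizes the GPU first when engine == "cupy".
        return getattr(comm, func)(*args, **kwargs)

    comm = MPI.COMM_WORLD
    send_buf = np.full(4, comm.Get_rank(), dtype=np.float64)
    recv_buf = np.zeros(4, dtype=np.float64)
    # Buffer-based Allreduce dispatched by name, as in the hunks above
    _mpi_calls(comm, "Allreduce", send_buf, recv_buf, MPI.SUM)
    print(comm.Get_rank(), recv_buf)

Run with e.g. `mpiexec -n 4 python example.py`; every rank should print an array filled with 6.0 (the sum of ranks 0 + 1 + 2 + 3).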