Skip to content

Commit 7ed2933

Browse files
committed
Fix docs and comment complex case in test
1 parent 57b6602 commit 7ed2933

5 files changed

Lines changed: 46 additions & 18 deletions

File tree

pylops_mpi/Distributed.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, NewType, Optional, Union
1+
from typing import Any, NewType, Optional
22

33
from mpi4py import MPI
44
from pylops.utils import NDArray
@@ -198,12 +198,25 @@ def _bcast(self,
198198
root: int = 0,
199199
engine: str = "numpy",
200200
) -> NDArray:
201-
"""BCast operation.
201+
"""Broadcast operation
202202
203-
Notes
204-
-----
205-
Any root-only assignment (e.g., setting a value prior to broadcast) must
206-
be done outside this method.
203+
Parameters
204+
----------
205+
base_comm : :obj:`MPI.Comm`
206+
Base MPI Communicator.
207+
base_comm_nccl : :obj:`cupy.cuda.nccl.NcclCommunicator`
208+
NCCL Communicator.
209+
send_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`
210+
A buffer containing the data to be broadcast from the root rank.
211+
root : :obj:`int`, optional
212+
The rank of the process that holds the source data.
213+
engine : :obj:`str`, optional
214+
Engine used to store array (``numpy`` or ``cupy``)
215+
216+
Returns
217+
-------
218+
send_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`
219+
The buffer containing the broadcasted data.
207220
"""
208221
if deps.nccl_enabled and base_comm_nccl is not None:
209222
nccl_bcast(base_comm_nccl, send_buf, root=root)

pylops_mpi/DistributedArray.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,7 @@ def __setitem__(self, index, value):
205205
"""
206206
if self.partition is Partition.BROADCAST:
207207
view = self.local_array[index]
208-
if self.rank == 0:
208+
if self.rank == 0:
209209
view[...] = value
210210
view = self._bcast(self.base_comm, self.base_comm_nccl,
211211
view, root=0, engine=self.engine)

pylops_mpi/utils/_mpi.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
"mpi_recv",
77
]
88

9-
from typing import Optional, Union
9+
from typing import Optional
1010

1111
from mpi4py import MPI
1212
from pylops.utils import NDArray
@@ -113,10 +113,22 @@ def mpi_bcast(base_comm: MPI.Comm,
113113
Dispatch bcast routine based on type of input and availability of
114114
CUDA-Aware MPI.
115115
116-
Notes
117-
-----
118-
Any root-only assignment (e.g., setting a value prior to broadcast) must be
119-
done outside this function.
116+
Parameters
117+
----------
118+
base_comm : :obj:`MPI.Comm`
119+
Base MPI Communicator.
120+
send_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`
121+
The data buffer to be broadcasted to the other ranks from the broadcasting root rank.
122+
root : :obj:`int`, optional
123+
The rank of the broadcasting process.
124+
engine : :obj:`str`, optional
125+
Engine used to store array (``numpy`` or ``cupy``)
126+
127+
Returns
128+
-------
129+
send_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`
130+
The buffer containing the broadcasted data.
131+
120132
"""
121133
if deps.cuda_aware_mpi_enabled or engine == "numpy":
122134
base_comm.Bcast(send_buf, root=root)

pylops_mpi/utils/_nccl.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -240,11 +240,14 @@ def nccl_allreduce(nccl_comm, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM) ->
240240

241241
def nccl_bcast(nccl_comm, send_buf, root: int = 0) -> None:
242242
"""NCCL equivalent of MPI_Bcast for an array buffer.
243-
244-
Notes
245-
-----
246-
Any root-only assignment (e.g., setting a value prior to broadcast) must be
247-
done outside this function.
243+
Parameters
244+
----------
245+
nccl_comm : :obj:`cupy.cuda.nccl.NcclCommunicator`
246+
The NCCL communicator used for collective communication.
247+
send_buf: :obj:`numpy.ndarray` or :obj:`cupy.ndarray`
248+
The data buffer to be broadcasted to the other ranks from the broadcasting root rank.
249+
root: :obj:`int`
250+
The rank of the broadcasting process.
248251
"""
249252
send_buf = send_buf if isinstance(send_buf, cp.ndarray) else cp.asarray(send_buf)
250253
nccl_comm.bcast(

tests/test_matrixmult.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
pytest.param(37, 37, 37, "float64", id="f32_37_37_37"),
4242
pytest.param(50, 30, 40, "float64", id="f64_50_30_40"),
4343
# temporarely removed as sometimes crashed CI... to be investigated
44-
pytest.param(22, 20, 16, "complex64", id="c64_22_20_16"),
44+
# pytest.param(22, 20, 16, "complex64", id="c64_22_20_16"),
4545
pytest.param(3, 4, 5, "float32", id="f32_3_4_5"),
4646
pytest.param(1, 2, 1, "float64", id="f64_1_2_1",),
4747
pytest.param(2, 1, 3, "float32", id="f32_2_1_3",),

0 commit comments

Comments
 (0)