Skip to content

Commit 5805f4f

Browse files
committed
Added comments
1 parent 316b241 commit 5805f4f

2 files changed

Lines changed: 55 additions & 33 deletions

File tree

pylops_mpi/basicoperators/Halo.py

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import math
2+
from typing import Any, Dict, Optional, Tuple, Union
23

34
import numpy as np
45
from mpi4py import MPI
@@ -8,7 +9,11 @@
89
from pylops_mpi.Distributed import DistributedMixIn
910

1011

11-
def halo_block_split(global_shape: tuple, comm, grid_shape: tuple = None) -> tuple:
12+
def halo_block_split(
13+
global_shape: tuple,
14+
comm: MPI.Comm,
15+
grid_shape: Optional[tuple] = None,
16+
) -> tuple:
1217
r"""Split a global array over a Cartesian process grid.
1318
1419
Compute the local slice owned by the calling rank when ``global_shape`` is
@@ -133,11 +138,11 @@ class MPIHalo(DistributedMixIn, MPILinearOperator):
133138
def __init__(
134139
self,
135140
dims: tuple,
136-
halo,
137-
proc_grid_shape: tuple = None,
141+
halo: Union[int, tuple],
142+
proc_grid_shape: Optional[tuple] = None,
138143
comm: MPI.Comm = MPI.COMM_WORLD,
139-
dtype=np.float64,
140-
):
144+
dtype: Any = np.float64,
145+
) -> None:
141146
self.global_dims = tuple(dims)
142147
self.ndim = len(dims)
143148

@@ -163,7 +168,8 @@ def __init__(
163168
)
164169
super().__init__(shape=self.shape, dtype=np.dtype(dtype), base_comm=comm)
165170

166-
def _parse_halo(self, h):
171+
def _parse_halo(self, h: Union[int, tuple]) -> tuple:
172+
"""Normalize halo input and trim halos at global boundaries."""
167173
if isinstance(h, (int, np.int64, np.int32)):
168174
halo = (h,) * (2 * self.ndim)
169175
trimmed = list(halo)
@@ -185,7 +191,8 @@ def _parse_halo(self, h):
185191
raise ValueError(f"Invalid halo length {len(h)} for ndim={self.ndim}")
186192
return halo
187193

188-
def _build_topo(self):
194+
def _build_topo(self) -> Tuple[MPI.Comm, Dict[Tuple[str, int], int]]:
195+
"""Create the Cartesian communicator and map neighboring ranks keyed by ``(direction, axis)``."""
189196
cart_comm = self.comm.Create_cart(
190197
self.proc_grid_shape,
191198
periods=[False] * self.ndim,
@@ -198,7 +205,8 @@ def _build_topo(self):
198205
neigh[("+", ax)] = after
199206
return cart_comm, neigh
200207

201-
def _calc_local_dims(self):
208+
def _calc_local_dims(self) -> tuple:
209+
"""Compute this rank's local block shape before halo padding."""
202210
rank = self.cart_comm.Get_rank()
203211
coords = self.cart_comm.Get_coords(rank)
204212
local = []
@@ -211,14 +219,16 @@ def _calc_local_dims(self):
211219
local.append(end - start)
212220
return tuple(local)
213221

214-
def _calc_local_extent(self):
222+
def _calc_local_extent(self) -> tuple:
223+
"""Compute this rank's local block shape after halo padding."""
215224
ext = []
216225
for ax in range(self.ndim):
217226
minus_halo, plus_halo = self.halo[2 * ax], self.halo[2 * ax + 1]
218227
ext.append(self.local_dims[ax] + minus_halo + plus_halo)
219228
return tuple(ext)
220229

221-
def _exchange_along_axis(self, ncp, arr, axis, before, after, engine):
230+
def _exchange_along_axis(self, ncp: Any, arr: Any, axis: int, before: int, after: int, engine: str) -> None:
231+
"""Exchange boundary/halo slices with neighboring ranks along one axis."""
222232
minus_nbr, plus_nbr = self.neigh[("-", axis)], self.neigh[("+", axis)]
223233
# slice definitions
224234
slicer = [slice(None)] * self.ndim
@@ -259,7 +269,7 @@ def _exchange_along_axis(self, ncp, arr, axis, before, after, engine):
259269
)
260270
arr[tuple(rcv_s)] = rcv
261271

262-
def _matvec(self, x):
272+
def _matvec(self, x: DistributedArray) -> DistributedArray:
263273
ncp = get_module(x.engine)
264274
if x.partition != Partition.SCATTER:
265275
raise ValueError(
@@ -295,7 +305,7 @@ def _matvec(self, x):
295305
y[:] = halo_arr.ravel()
296306
return y
297307

298-
def _rmatvec(self, x):
308+
def _rmatvec(self, x: DistributedArray) -> DistributedArray:
299309
if x.partition != Partition.SCATTER:
300310
raise ValueError(
301311
f"x should have partition={Partition.SCATTER} Got {x.partition} instead..."

pylops_mpi/utils/_common.py

Lines changed: 33 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -17,30 +17,32 @@ def _float_scalar(value) -> float:
1717

1818
# TODO: return type annotation for both cupy and numpy
1919
def _prepare_allgather_inputs(send_buf, send_buf_shapes, engine):
20-
r""" Prepare send_buf and recv_buf for NCCL allgather (nccl_allgather)
20+
r"""Prepare send_buf and recv_buf for buffered allgather
2121
22-
Buffered Allgather (MPI and NCCL) requires the sending buffer to have the same size for every device.
22+
Buffered Allgather (MPI and NCCL) requires the sending buffer to have the
23+
same size for every rank/device.
2324
Therefore, padding is required when the array is not evenly partitioned across
24-
all the ranks. The padding is applied such that the each dimension of the sending buffers
25+
all the ranks. The padding is applied such that each dimension of the sending buffers
2526
is equal to the max size of that dimension across all ranks.
2627
27-
Similarly, each receiver buffer (recv_buf) is created with size equal to :math:n_rank \cdot send_buf.size
28+
Similarly, each receiver buffer (recv_buf) is created with size equal to
29+
:math:`n_{rank} \cdot \text{send\_buf.size}`
2830
2931
Parameters
3032
----------
31-
send_buf : :obj: `numpy.ndarray` or `cupy.ndarray` or array-like
32-
The data buffer from the local GPU to be sent for allgather.
33-
send_buf_shapes: :obj:`list`
34-
A list of shapes for each GPU send_buf (used to calculate padding size)
33+
send_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray` or array-like
34+
The data buffer from the local rank/device to be sent for allgather.
35+
send_buf_shapes : :obj:`list`
36+
A list of shapes for each rank/device send_buf (used to calculate padding size)
3537
engine : :obj:`str`
3638
Engine used to store array (``numpy`` or ``cupy``)
3739
3840
Returns
3941
-------
40-
send_buf: :obj:`cupy.ndarray`
42+
send_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`
4143
A buffer containing the data and padded elements to be sent by this rank.
42-
recv_buf : :obj:`cupy.ndarray`
43-
An empty, padded buffer to gather data from all GPUs.
44+
recv_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`
45+
An empty, padded buffer to gather data from all ranks.
4446
"""
4547
ncp = get_module(engine)
4648
sizes_each_dim = list(zip(*send_buf_shapes))
@@ -60,25 +62,35 @@ def _prepare_allgather_inputs(send_buf, send_buf_shapes, engine):
6062

6163

6264
def _unroll_allgather_recv(recv_buf, padded_send_buf_shape, send_buf_shapes, displs=None) -> list:
63-
r"""Unrolll recv_buf after Buffered Allgather (MPI and NCCL)
65+
r"""Unroll recv_buf after Buffered Allgather (MPI and NCCL)
66+
67+
Depending on the provided parameters, the function:
68+
- uses ``displs`` and element counts to extract variable-sized chunks.
69+
- removes padding and reshapes each chunk using ``padded_send_buf_shape``.
6470
65-
Remove the padded elements in recv_buff, extract an individual array from each device and return them as a list of arrays
66-
Each GPU may send array with a different shape, so the return type has to be a list of array
67-
instead of the concatenated array.
71+
Each rank may send an array with a different shape, so the return type is a list of arrays
72+
instead of a concatenated array.
6873
6974
Parameters
7075
----------
7176
recv_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray` or array-like
72-
The data buffer returned from nccl_allgather call
73-
padded_send_buf_shape: :obj:`tuple`:int
74-
The size of send_buf after padding used in nccl_allgather
77+
The data buffer returned from the allgather call
7578
send_buf_shapes: :obj:`list`
76-
A list of original shapes for each GPU send_buf prior to padding
79+
A list of original shapes of each rank's send_buf before any padding.
80+
padded_send_buf_shape : tuple
81+
Shape of each rank's data as stored in ``recv_buf``. This should match
82+
the layout used during allgather: use the padded send buffer shape when
83+
padding is applied (e.g., NCCL), or the original send buffer shape when
84+
no padding is used.
85+
displs : list, optional
86+
Starting offsets in recv_buf for each rank's data, used when chunks have
87+
variable sizes (e.g., mpi_allgather with displacements).
7788
7889
Returns
7990
-------
80-
chunks: :obj:`list`
81-
A list of `cupy.ndarray` from each GPU with the padded element removed
91+
chunks : list of ndarray
92+
List of arrays (NumPy or CuPy, depending on ``engine``), one per rank,
93+
reconstructed to their original shapes with any padding removed.
8294
"""
8395
ndev = len(send_buf_shapes)
8496
if displs is not None:

0 commit comments

Comments
 (0)