Skip to content

Commit e1b1eb0

Browse files
Merge pull request #186 from PyLops/allgatherv
MPI_Allgatherv in `mpi_allgather` function
2 parents d856231 + fe05882 commit e1b1eb0

6 files changed

Lines changed: 159 additions & 97 deletions

File tree

pylops_mpi/Distributed.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,16 @@
33
from mpi4py import MPI
44
from pylops.utils import NDArray
55
from pylops.utils import deps as pylops_deps # avoid namespace crashes with pylops_mpi.utils
6-
from pylops_mpi.utils._mpi import (mpi_allreduce, mpi_allgather, mpi_bcast, mpi_send, mpi_recv, mpi_sendrecv,
7-
_prepare_allgather_inputs, _unroll_allgather_recv)
6+
from pylops_mpi.utils._mpi import mpi_allreduce, mpi_allgather, mpi_bcast, mpi_send, mpi_recv, mpi_sendrecv
7+
from pylops_mpi.utils._common import _unroll_allgather_recv
88
from pylops_mpi.utils import deps
99

1010
cupy_message = pylops_deps.cupy_import("the DistributedArray module")
1111
nccl_message = deps.nccl_import("the DistributedArray module")
1212

1313
if nccl_message is None and cupy_message is None:
1414
from pylops_mpi.utils._nccl import (
15-
nccl_allgather, nccl_allreduce, nccl_bcast, nccl_send, nccl_recv, nccl_sendrecv
15+
nccl_allgather, nccl_allreduce, nccl_bcast, nccl_send, nccl_recv, nccl_sendrecv, _prepare_allgather_inputs_nccl
1616
)
1717
from cupy.cuda.nccl import NcclCommunicator
1818
else:
@@ -32,7 +32,6 @@ class DistributedMixIn:
3232
MPI installation is not available).
3333
3434
"""
35-
3635
def _allreduce(self,
3736
base_comm: MPI.Comm,
3837
base_comm_nccl: NcclCommunicatorType,
@@ -145,7 +144,7 @@ def _allgather(self,
145144
return nccl_allgather(base_comm_nccl, send_buf, recv_buf)
146145
else:
147146
send_shapes = base_comm.allgather(send_buf.shape)
148-
(padded_send, padded_recv) = _prepare_allgather_inputs(send_buf, send_shapes, engine="cupy")
147+
(padded_send, padded_recv) = _prepare_allgather_inputs_nccl(send_buf, send_shapes, engine="cupy")
149148
raw_recv = nccl_allgather(base_comm_nccl, padded_send, recv_buf if recv_buf else padded_recv)
150149
return _unroll_allgather_recv(raw_recv, padded_send.shape, send_shapes)
151150
else:
@@ -187,7 +186,7 @@ def _allgather_subcomm(self,
187186
return nccl_allgather(sub_comm, send_buf, recv_buf)
188187
else:
189188
send_shapes = sub_comm._allgather_subcomm(send_buf.shape)
190-
(padded_send, padded_recv) = _prepare_allgather_inputs(send_buf, send_shapes, engine="cupy")
189+
(padded_send, padded_recv) = _prepare_allgather_inputs_nccl(send_buf, send_shapes, engine="cupy")
191190
raw_recv = nccl_allgather(sub_comm, padded_send, recv_buf if recv_buf else padded_recv)
192191
return _unroll_allgather_recv(raw_recv, padded_send.shape, send_shapes)
193192
else:

pylops_mpi/signalprocessing/Fredholm1.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ def _matvec(self, x: DistributedArray) -> DistributedArray:
132132
engine=y.engine)).ravel()
133133
return y
134134

135-
def _rmatvec(self, x: NDArray) -> NDArray:
135+
def _rmatvec(self, x: DistributedArray) -> DistributedArray:
136136
ncp = get_module(x.engine)
137137
if x.partition not in [Partition.BROADCAST, Partition.UNSAFE_BROADCAST]:
138138
raise ValueError(f"x should have partition={Partition.BROADCAST},{Partition.UNSAFE_BROADCAST}"

pylops_mpi/utils/_common.py

Lines changed: 40 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -1,89 +1,58 @@
11
__all__ = [
2-
"_prepare_allgather_inputs",
32
"_unroll_allgather_recv"
43
]
54

6-
75
import numpy as np
8-
from pylops.utils.backend import get_module
96

107

11-
# TODO: return type annotation for both cupy and numpy
12-
def _prepare_allgather_inputs(send_buf, send_buf_shapes, engine):
13-
r""" Prepare send_buf and recv_buf for NCCL allgather (nccl_allgather)
8+
def _unroll_allgather_recv(recv_buf, buffer_chunk_shape, send_buf_shapes, displs=None) -> list:
9+
r"""Unroll recv_buf after Buffered Allgather (MPI and NCCL)
1410
15-
Buffered Allgather (MPI and NCCL) requires the sending buffer to have the same size for every device.
16-
Therefore, padding is required when the array is not evenly partitioned across
17-
all the ranks. The padding is applied such that the each dimension of the sending buffers
18-
is equal to the max size of that dimension across all ranks.
11+
Depending on the provided parameters, the function:
12+
- uses ``displs`` and element counts to extract variable-sized chunks.
13+
- removes padding and reshapes each chunk using ``buffer_chunk_shape``.
1914
20-
Similarly, each receiver buffer (recv_buf) is created with size equal to :math:n_rank \cdot send_buf.size
21-
22-
Parameters
23-
----------
24-
send_buf : :obj: `numpy.ndarray` or `cupy.ndarray` or array-like
25-
The data buffer from the local GPU to be sent for allgather.
26-
send_buf_shapes: :obj:`list`
27-
A list of shapes for each GPU send_buf (used to calculate padding size)
28-
engine : :obj:`str`
29-
Engine used to store array (``numpy`` or ``cupy``)
30-
31-
Returns
32-
-------
33-
send_buf: :obj:`cupy.ndarray`
34-
A buffer containing the data and padded elements to be sent by this rank.
35-
recv_buf : :obj:`cupy.ndarray`
36-
An empty, padded buffer to gather data from all GPUs.
37-
"""
38-
ncp = get_module(engine)
39-
sizes_each_dim = list(zip(*send_buf_shapes))
40-
send_shape = tuple(map(max, sizes_each_dim))
41-
pad_size = [
42-
(0, s_shape - l_shape) for s_shape, l_shape in zip(send_shape, send_buf.shape)
43-
]
44-
45-
send_buf = ncp.pad(
46-
send_buf, pad_size, mode="constant", constant_values=0
47-
)
48-
49-
ndev = len(send_buf_shapes)
50-
recv_buf = ncp.zeros(ndev * send_buf.size, dtype=send_buf.dtype)
51-
52-
return send_buf, recv_buf
53-
54-
55-
def _unroll_allgather_recv(recv_buf, padded_send_buf_shape, send_buf_shapes) -> list:
56-
r"""Unrolll recv_buf after Buffered Allgather (MPI and NCCL)
57-
58-
Remove the padded elements in recv_buff, extract an individual array from each device and return them as a list of arrays
59-
Each GPU may send array with a different shape, so the return type has to be a list of array
60-
instead of the concatenated array.
15+
Each rank may send an array with a different shape, so the return type is a list of arrays
16+
instead of a concatenated array.
6117
6218
Parameters
6319
----------
6420
recv_buf: :obj:`cupy.ndarray` or array-like
65-
The data buffer returned from nccl_allgather call
66-
padded_send_buf_shape: :obj:`tuple`:int
67-
The size of send_buf after padding used in nccl_allgather
21+
The data buffer returned from the allgather call
6822
send_buf_shapes: :obj:`list`
69-
A list of original shapes for each GPU send_buf prior to padding
70-
23+
A list of original shapes of each rank's send_buf before any padding.
24+
buffer_chunk_shape : tuple
25+
Shape of each rank’s data as stored in ``recv_buf``. This should match
26+
the layout used during allgather: use the padded send buffer shape when
27+
padding is applied (e.g., NCCL), or the original send buffer shape when
28+
no padding is used.
29+
displs : list, optional
30+
Starting offsets in recv_buf for each rank's data, used when chunks have
31+
variable sizes (e.g., mpi_allgather with displacements).
7132
Returns
7233
-------
73-
chunks: :obj:`list`
74-
A list of `cupy.ndarray` from each GPU with the padded element removed
34+
chunks : list of ndarray
35+
List of arrays (NumPy or CuPy, matching the array type of ``recv_buf``), one per rank,
36+
reconstructed to their original shapes with any padding removed.
7537
"""
7638
ndev = len(send_buf_shapes)
77-
# extract an individual array from each device
78-
chunk_size = np.prod(padded_send_buf_shape)
79-
chunks = [
80-
recv_buf[i * chunk_size:(i + 1) * chunk_size] for i in range(ndev)
81-
]
82-
83-
# Remove padding from each array: the padded value may appear somewhere
84-
# in the middle of the flat array and thus the reshape and slicing for each dimension is required
85-
for i in range(ndev):
86-
slicing = tuple(slice(0, end) for end in send_buf_shapes[i])
87-
chunks[i] = chunks[i].reshape(padded_send_buf_shape)[slicing]
88-
89-
return chunks
39+
if displs is not None:
40+
recvcounts = [int(np.prod(shape)) for shape in send_buf_shapes]
41+
# Slice recv_buf using displacements and then reconstruct the original-shaped chunk.
42+
return [
43+
recv_buf[displs[i]:displs[i] + recvcounts[i]].reshape(send_buf_shapes[i])
44+
for i in range(ndev)
45+
]
46+
else:
47+
# extract an individual array from each device
48+
chunk_size = np.prod(buffer_chunk_shape)
49+
chunks = [
50+
recv_buf[i * chunk_size:(i + 1) * chunk_size]
51+
for i in range(ndev)
52+
]
53+
# Remove padding from each array: the padded value may appear somewhere
54+
# in the middle of the flat array and thus the reshape and slicing for each dimension is required
55+
for i in range(ndev):
56+
slicing = tuple(slice(0, end) for end in send_buf_shapes[i])
57+
chunks[i] = chunks[i].reshape(buffer_chunk_shape)[slicing]
58+
return chunks

pylops_mpi/utils/_mpi.py

Lines changed: 62 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,27 +4,33 @@
44
"mpi_bcast",
55
"mpi_send",
66
"mpi_recv",
7-
"mpi_sendrecv"
7+
"mpi_sendrecv",
8+
"_prepare_allgather_inputs_mpi"
89
]
910

10-
from typing import Optional
11+
from typing import List, Optional
12+
import numpy as np
1113

1214
from mpi4py import MPI
1315
from pylops.utils import NDArray
1416
from pylops.utils.backend import get_module
1517
from pylops_mpi.utils import deps
16-
from pylops_mpi.utils._common import _prepare_allgather_inputs, _unroll_allgather_recv
18+
from pylops_mpi.utils._common import _unroll_allgather_recv
1719

1820

1921
def mpi_allgather(base_comm: MPI.Comm,
2022
send_buf: NDArray,
2123
recv_buf: Optional[NDArray] = None,
2224
engine: str = "numpy",
23-
) -> NDArray:
24-
"""MPI_Allallgather/allallgather
25+
) -> List[NDArray]:
26+
"""MPI_Allgather/allgather
2527
26-
Dispatch allgather routine based on type of input and availability of
27-
CUDA-Aware MPI
28+
Dispatch the appropriate allgather routine based on buffer sizes and
29+
CUDA-aware MPI availability.
30+
31+
If all ranks provide buffers of equal size, the standard `Allgather`
32+
collective is used. Otherwise, `Allgatherv` is invoked to handle
33+
variable-sized buffers.
2834
2935
Parameters
3036
----------
@@ -40,16 +46,19 @@ def mpi_allgather(base_comm: MPI.Comm,
4046
4147
Returns
4248
-------
43-
recv_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`
44-
A buffer containing the gathered data from all ranks.
49+
recv_buf : :obj:`list`
50+
A list of arrays containing the gathered data from all ranks.
4551
4652
"""
4753
if deps.cuda_aware_mpi_enabled or engine == "numpy":
4854
send_shapes = base_comm.allgather(send_buf.shape)
49-
(padded_send, padded_recv) = _prepare_allgather_inputs(send_buf, send_shapes, engine=engine)
50-
recv_buffer_to_use = recv_buf if recv_buf else padded_recv
51-
_mpi_calls(base_comm, "Allgather", padded_send, recv_buffer_to_use, engine=engine)
52-
return _unroll_allgather_recv(recv_buffer_to_use, padded_send.shape, send_shapes)
55+
send_buf, recv_buf, displs, recvcounts = _prepare_allgather_inputs_mpi(send_buf, send_shapes, engine)
56+
if len(set(send_shapes)) == 1:
57+
_mpi_calls(base_comm, "Allgather", send_buf, recv_buf, engine=engine)
58+
else:
59+
_mpi_calls(base_comm, "Allgatherv", send_buf,
60+
[recv_buf, recvcounts, displs, MPI._typedict[send_buf.dtype.char]], engine=engine)
61+
return _unroll_allgather_recv(recv_buf, send_buf.shape, send_shapes, displs)
5362
else:
5463
# CuPy with non-CUDA-aware MPI
5564
if recv_buf is None:
@@ -293,3 +302,43 @@ def _mpi_calls(comm: MPI.Comm, func: str, *args, engine: Optional[str] = "numpy"
293302
ncp.cuda.Device().synchronize()
294303
mpi_func = getattr(comm, func)
295304
return mpi_func(*args, **kwargs)
305+
306+
307+
def _prepare_allgather_inputs_mpi(send_buf, send_buf_shapes, engine):
308+
r"""Prepare send_buf and recv_buf for MPI allgather (mpi_allgather)
309+
310+
Buffered Allgather (MPI) supports both uniform and variable-sized data across ranks. Unlike NCCL, padding is
311+
not required when array sizes differ. Instead, displacements are used to correctly place each rank’s data
312+
within the received buffer. The function ensures that the send_buf is contiguous.
313+
314+
Parameters
315+
----------
316+
send_buf : :obj: `numpy.ndarray` or `cupy.ndarray` or array-like
317+
The data buffer to be sent for allgather.
318+
send_buf_shapes: :obj:`list`
319+
A list of shapes for each send_buf (used to calculate receive counts and displacements)
320+
engine : :obj:`str`
321+
Engine used to store array (``numpy`` or ``cupy``)
322+
323+
Returns
324+
-------
325+
send_buf: :obj: `numpy.ndarray` or `cupy.ndarray` or array-like
326+
A contiguous buffer containing the data to be sent by this rank (no padding is applied).
327+
recv_buf : :obj: `numpy.ndarray` or `cupy.ndarray` or array-like
328+
A buffer to gather data from all ranks.
329+
displs : list, optional
330+
Starting offsets in recv_buf for each rank's data, used when chunks have
331+
variable sizes (``None`` when all ranks send buffers of the same shape)
332+
recvcounts: :obj:`list`
333+
A list of element counts from all ranks, where each entry corresponds to one rank.
334+
"""
335+
ncp = get_module(engine)
336+
recvcounts = [int(np.prod(shape)) for shape in send_buf_shapes]
337+
recv_buf = ncp.zeros(sum(recvcounts), dtype=send_buf.dtype)
338+
if len(set(send_buf_shapes)) == 1:
339+
displs = None
340+
else:
341+
displs = [0]
342+
for i in range(1, len(recvcounts)):
343+
displs.append(displs[i - 1] + recvcounts[i - 1])
344+
return ncp.ascontiguousarray(send_buf), recv_buf, displs, recvcounts

pylops_mpi/utils/_nccl.py

Lines changed: 48 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,17 @@
88
"nccl_asarray",
99
"nccl_send",
1010
"nccl_recv",
11-
"nccl_sendrecv"
11+
"nccl_sendrecv",
12+
"_prepare_allgather_inputs_nccl"
1213
]
1314

1415
from enum import IntEnum
1516
from mpi4py import MPI
1617
import os
1718
import cupy as cp
1819
import cupy.cuda.nccl as nccl
19-
from pylops_mpi.utils._common import _prepare_allgather_inputs, _unroll_allgather_recv
20+
from pylops.utils.backend import get_module
21+
from pylops_mpi.utils._common import _unroll_allgather_recv
2022

2123
cupy_to_nccl_dtype = {
2224
"float32": nccl.NCCL_FLOAT32,
@@ -282,7 +284,7 @@ def nccl_asarray(nccl_comm, local_array, local_shapes, axis) -> cp.ndarray:
282284
Global array gathered from all GPUs and concatenated along `axis`.
283285
"""
284286

285-
send_buf, recv_buf = _prepare_allgather_inputs(local_array, local_shapes, engine="cupy")
287+
send_buf, recv_buf = _prepare_allgather_inputs_nccl(local_array, local_shapes, engine="cupy")
286288
nccl_allgather(nccl_comm, send_buf, recv_buf)
287289
chunks = _unroll_allgather_recv(recv_buf, send_buf.shape, local_shapes)
288290

@@ -356,3 +358,46 @@ def nccl_sendrecv(nccl_comm, sendbuf, dest, recvbuf, source):
356358
nccl_send(nccl_comm, sendbuf, dest, sendbuf.size)
357359
nccl_recv(nccl_comm, recvbuf, source)
358360
nccl.groupEnd()
361+
362+
363+
def _prepare_allgather_inputs_nccl(send_buf, send_buf_shapes, engine):
364+
r""" Prepare send_buf and recv_buf for NCCL allgather (nccl_allgather)
365+
366+
Buffered Allgather (NCCL) requires the sending buffer to have the same size for every device.
367+
Therefore, padding is required when the array is not evenly partitioned across
368+
all the ranks. The padding is applied such that each dimension of the sending buffers
369+
is equal to the max size of that dimension across all ranks.
370+
371+
Similarly, each receiver buffer (recv_buf) is created with size equal to :math:`n_{rank} \cdot \text{send\_buf.size}`
372+
373+
Parameters
374+
----------
375+
send_buf : :obj: `numpy.ndarray` or `cupy.ndarray` or array-like
376+
The data buffer from the local GPU to be sent for allgather.
377+
send_buf_shapes: :obj:`list`
378+
A list of shapes for each GPU send_buf (used to calculate padding size)
379+
engine : :obj:`str`
380+
Engine used to store array (``numpy`` or ``cupy``)
381+
382+
Returns
383+
-------
384+
send_buf: :obj:`cupy.ndarray`
385+
A buffer containing the data and padded elements to be sent by this rank.
386+
recv_buf : :obj:`cupy.ndarray`
387+
An empty, padded buffer to gather data from all GPUs.
388+
"""
389+
ncp = get_module(engine)
390+
sizes_each_dim = list(zip(*send_buf_shapes))
391+
send_shape = tuple(map(max, sizes_each_dim))
392+
pad_size = [
393+
(0, s_shape - l_shape) for s_shape, l_shape in zip(send_shape, send_buf.shape)
394+
]
395+
396+
send_buf = ncp.pad(
397+
send_buf, pad_size, mode="constant", constant_values=0
398+
)
399+
400+
ndev = len(send_buf_shapes)
401+
recv_buf = ncp.zeros(ndev * send_buf.size, dtype=send_buf.dtype)
402+
403+
return send_buf, recv_buf

tests_nccl/test_ncclutils_nccl.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88
from numpy.testing import assert_allclose
99
import pytest
1010

11-
from pylops_mpi.utils._nccl import initialize_nccl_comm, nccl_allgather
12-
from pylops_mpi.utils._common import _prepare_allgather_inputs, _unroll_allgather_recv
11+
from pylops_mpi.utils._nccl import initialize_nccl_comm, nccl_allgather, _prepare_allgather_inputs_nccl
12+
from pylops_mpi.utils._common import _unroll_allgather_recv
1313
from pylops_mpi.utils.deps import nccl_enabled
1414

1515
np.random.seed(42)
@@ -90,7 +90,7 @@ def test_allgather_differentsize_withrecbuf(par):
9090

9191
# Gathered array
9292
send_shapes = MPI.COMM_WORLD.allgather(local_array.shape)
93-
send_buf, recv_buf = _prepare_allgather_inputs(local_array, send_shapes, engine="cupy")
93+
send_buf, recv_buf = _prepare_allgather_inputs_nccl(local_array, send_shapes, engine="cupy")
9494
recv_buf = nccl_allgather(nccl_comm, send_buf, recv_buf)
9595
chunks = _unroll_allgather_recv(recv_buf, send_buf.shape, send_shapes)
9696
gathered_array = cp.concatenate(chunks)

0 commit comments

Comments
 (0)