|
1 | 1 | __all__ = [ |
2 | | - "_prepare_allgather_inputs", |
3 | 2 | "_unroll_allgather_recv" |
4 | 3 | ] |
5 | 4 |
|
6 | | - |
7 | 5 | import numpy as np |
8 | | -from pylops.utils.backend import get_module |
9 | 6 |
|
10 | 7 |
|
11 | | -# TODO: return type annotation for both cupy and numpy |
12 | | -def _prepare_allgather_inputs(send_buf, send_buf_shapes, engine): |
13 | | - r""" Prepare send_buf and recv_buf for NCCL allgather (nccl_allgather) |
| 8 | +def _unroll_allgather_recv(recv_buf, buffer_chunk_shape, send_buf_shapes, displs=None) -> list: |
| 9 | + r"""Unroll recv_buf after Buffered Allgather (MPI and NCCL) |
14 | 10 |
|
15 | | - Buffered Allgather (MPI and NCCL) requires the sending buffer to have the same size for every device. |
16 | | - Therefore, padding is required when the array is not evenly partitioned across |
17 | | - all the ranks. The padding is applied such that the each dimension of the sending buffers |
18 | | - is equal to the max size of that dimension across all ranks. |
| 11 | + Depending on the provided parameters, the function: |
| 12 | + - uses ``displs`` and element counts to extract variable-sized chunks. |
| 13 | + - removes padding and reshapes each chunk using ``buffer_chunk_shape``. |
19 | 14 |
|
20 | | - Similarly, each receiver buffer (recv_buf) is created with size equal to :math:n_rank \cdot send_buf.size |
21 | | -
|
22 | | - Parameters |
23 | | - ---------- |
24 | | - send_buf : :obj: `numpy.ndarray` or `cupy.ndarray` or array-like |
25 | | - The data buffer from the local GPU to be sent for allgather. |
26 | | - send_buf_shapes: :obj:`list` |
27 | | - A list of shapes for each GPU send_buf (used to calculate padding size) |
28 | | - engine : :obj:`str` |
29 | | - Engine used to store array (``numpy`` or ``cupy``) |
30 | | -
|
31 | | - Returns |
32 | | - ------- |
33 | | - send_buf: :obj:`cupy.ndarray` |
34 | | - A buffer containing the data and padded elements to be sent by this rank. |
35 | | - recv_buf : :obj:`cupy.ndarray` |
36 | | - An empty, padded buffer to gather data from all GPUs. |
37 | | - """ |
38 | | - ncp = get_module(engine) |
39 | | - sizes_each_dim = list(zip(*send_buf_shapes)) |
40 | | - send_shape = tuple(map(max, sizes_each_dim)) |
41 | | - pad_size = [ |
42 | | - (0, s_shape - l_shape) for s_shape, l_shape in zip(send_shape, send_buf.shape) |
43 | | - ] |
44 | | - |
45 | | - send_buf = ncp.pad( |
46 | | - send_buf, pad_size, mode="constant", constant_values=0 |
47 | | - ) |
48 | | - |
49 | | - ndev = len(send_buf_shapes) |
50 | | - recv_buf = ncp.zeros(ndev * send_buf.size, dtype=send_buf.dtype) |
51 | | - |
52 | | - return send_buf, recv_buf |
53 | | - |
54 | | - |
55 | | -def _unroll_allgather_recv(recv_buf, padded_send_buf_shape, send_buf_shapes) -> list: |
56 | | - r"""Unrolll recv_buf after Buffered Allgather (MPI and NCCL) |
57 | | -
|
58 | | - Remove the padded elements in recv_buff, extract an individual array from each device and return them as a list of arrays |
59 | | - Each GPU may send array with a different shape, so the return type has to be a list of array |
60 | | - instead of the concatenated array. |
| 15 | + Each rank may send an array with a different shape, so the return type is a list of arrays |
| 16 | + instead of a concatenated array. |
61 | 17 |
|
62 | 18 | Parameters |
63 | 19 | ---------- |
64 | 20 | recv_buf: :obj:`cupy.ndarray` or array-like |
65 | | - The data buffer returned from nccl_allgather call |
66 | | - padded_send_buf_shape: :obj:`tuple`:int |
67 | | - The size of send_buf after padding used in nccl_allgather |
| 21 | + The data buffer returned from the allgather call |
68 | 22 | send_buf_shapes: :obj:`list` |
69 | | - A list of original shapes for each GPU send_buf prior to padding |
70 | | -
|
| 23 | + A list of original shapes of each rank's send_buf before any padding. |
| 24 | + buffer_chunk_shape : tuple |
| 25 | + Shape of each rank’s data as stored in ``recv_buf``. This should match |
| 26 | + the layout used during allgather: use the padded send buffer shape when |
| 27 | + padding is applied (e.g., NCCL), or the original send buffer shape when |
| 28 | + no padding is used. |
| 29 | + displs : list, optional |
| 30 | + Starting offsets in recv_buf for each rank's data, used when chunks have |
| 31 | + variable sizes (e.g., mpi_allgather with displacements). |
71 | 32 | Returns |
72 | 33 | ------- |
73 | | - chunks: :obj:`list` |
74 | | - A list of `cupy.ndarray` from each GPU with the padded element removed |
| 34 | + chunks : list of ndarray |
| 35 | + List of arrays (NumPy or CuPy, matching the type of ``recv_buf``), one per rank, |
| 36 | + reconstructed to their original shapes with any padding removed. |
75 | 37 | """ |
76 | 38 | ndev = len(send_buf_shapes) |
77 | | - # extract an individual array from each device |
78 | | - chunk_size = np.prod(padded_send_buf_shape) |
79 | | - chunks = [ |
80 | | - recv_buf[i * chunk_size:(i + 1) * chunk_size] for i in range(ndev) |
81 | | - ] |
82 | | - |
83 | | - # Remove padding from each array: the padded value may appear somewhere |
84 | | - # in the middle of the flat array and thus the reshape and slicing for each dimension is required |
85 | | - for i in range(ndev): |
86 | | - slicing = tuple(slice(0, end) for end in send_buf_shapes[i]) |
87 | | - chunks[i] = chunks[i].reshape(padded_send_buf_shape)[slicing] |
88 | | - |
89 | | - return chunks |
| 39 | + if displs is not None: |
| 40 | + recvcounts = [int(np.prod(shape)) for shape in send_buf_shapes] |
| 41 | + # Slice recv_buf using displacements and then reconstruct the original-shaped chunk. |
| 42 | + return [ |
| 43 | + recv_buf[displs[i]:displs[i] + recvcounts[i]].reshape(send_buf_shapes[i]) |
| 44 | + for i in range(ndev) |
| 45 | + ] |
| 46 | + else: |
| 47 | + # extract an individual array from each device |
| 48 | + chunk_size = np.prod(buffer_chunk_shape) |
| 49 | + chunks = [ |
| 50 | + recv_buf[i * chunk_size:(i + 1) * chunk_size] |
| 51 | + for i in range(ndev) |
| 52 | + ] |
| 53 | + # Remove padding from each array: the padded value may appear somewhere |
| 54 | + # in the middle of the flat array and thus the reshape and slicing for each dimension is required |
| 55 | + for i in range(ndev): |
| 56 | + slicing = tuple(slice(0, end) for end in send_buf_shapes[i]) |
| 57 | + chunks[i] = chunks[i].reshape(buffer_chunk_shape)[slicing] |
| 58 | + return chunks |
0 commit comments