Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 16 additions & 20 deletions pylops_mpi/Distributed.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, NewType, Optional, Union
from typing import Any, NewType, Optional

from mpi4py import MPI
from pylops.utils import NDArray
Expand Down Expand Up @@ -194,38 +194,34 @@ def _allgather_subcomm(self,
def _bcast(self,
base_comm: MPI.Comm,
base_comm_nccl: NcclCommunicatorType,
rank : int,
local_array: NDArray,
index: int,
value: Union[int, NDArray],
send_buf: NDArray,
root: int = 0,
engine: str = "numpy",
) -> None:
"""BCast operation
) -> NDArray:
"""Broadcast operation

Parameters
----------
base_comm : :obj:`MPI.Comm`
Base MPI Communicator.
base_comm_nccl : :obj:`cupy.cuda.nccl.NcclCommunicator`
NCCL Communicator.
rank : :obj:`int`
Rank.
local_array : :obj:`numpy.ndarray`
Local array to be broadcast.
index : :obj:`int` or :obj:`slice`
Represents the index positions where a value needs to be assigned.
value : :obj:`int` or :obj:`numpy.ndarray`
Represents the value that will be assigned to the local array at
the specified index positions.
send_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`
A buffer containing the data to be broadcast from the root rank.
root : :obj:`int`, optional
The rank of the process that holds the source data.
engine : :obj:`str`, optional
Engine used to store array (``numpy`` or ``cupy``)

Returns
-------
send_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`
The buffer containing the broadcasted data.
"""
if deps.nccl_enabled and base_comm_nccl is not None:
nccl_bcast(base_comm_nccl, local_array, index, value)
else:
mpi_bcast(base_comm, rank, local_array, index, value,
engine=engine)
nccl_bcast(base_comm_nccl, send_buf, root=root)
return send_buf
return mpi_bcast(base_comm, send_buf, root=root, engine=engine)

def _send(self,
base_comm: MPI.Comm,
Expand Down
9 changes: 6 additions & 3 deletions pylops_mpi/DistributedArray.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,9 +204,12 @@ def __setitem__(self, index, value):
the specified index positions.
"""
if self.partition is Partition.BROADCAST:
self._bcast(self.base_comm, self.base_comm_nccl,
self.rank, self.local_array,
index, value, engine=self.engine)
view = self.local_array[index]
if self.rank == 0:
view[...] = value
view = self._bcast(self.base_comm, self.base_comm_nccl,
view, root=0, engine=self.engine)
self.local_array[index] = view
else:
self.local_array[index] = value

Expand Down
122 changes: 87 additions & 35 deletions pylops_mpi/basicoperators/MatrixMult.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import math
import numpy as np
from mpi4py import MPI
from typing import Tuple, Literal
from typing import Tuple, Literal, Optional, Any

from pylops.utils.backend import get_module
from pylops.utils.typing import DTypeLike, NDArray
Expand All @@ -17,6 +17,8 @@
MPILinearOperator,
Partition
)
from pylops_mpi.Distributed import DistributedMixIn
from pylops_mpi.DistributedArray import subcomm_split


def active_grid_comm(base_comm: MPI.Comm, N: int, M: int):
Expand Down Expand Up @@ -159,7 +161,8 @@ def block_gather(x: DistributedArray, orig_shape: Tuple[int, int], comm: MPI.Com
if p_prime * p_prime != comm.Get_size():
raise RuntimeError(f"Communicator size must be a perfect square, got {comm.Get_size()!r}")

all_blks = comm.allgather(x.local_array)
comm_nccl = x.base_comm_nccl if comm == x.base_comm else None
all_blks = x._allgather(comm, comm_nccl, x.local_array, engine=x.engine)
nr, nc = orig_shape
br, bc = math.ceil(nr / p_prime), math.ceil(nc / p_prime)
C = ncp.zeros((nr, nc), dtype=all_blks[0].dtype)
Expand All @@ -172,7 +175,7 @@ def block_gather(x: DistributedArray, orig_shape: Tuple[int, int], comm: MPI.Com
return C


class _MPIBlockMatrixMult(MPILinearOperator):
class _MPIBlockMatrixMult(DistributedMixIn, MPILinearOperator):
r"""MPI Blocked Matrix multiplication

Implement distributed matrix-matrix multiplication between a matrix
Expand Down Expand Up @@ -281,7 +284,10 @@ def __init__(
saveAt: bool = False,
base_comm: MPI.Comm = MPI.COMM_WORLD,
dtype: DTypeLike = "float64",
base_comm_nccl: Optional[Any] = None,
) -> None:
if base_comm_nccl is not None and base_comm is not MPI.COMM_WORLD:
raise ValueError("base_comm_nccl requires base_comm=MPI.COMM_WORLD")
rank = base_comm.Get_rank()
size = base_comm.Get_size()

Expand All @@ -295,8 +301,17 @@ def __init__(
self._row_id = rank // self._P_prime

self.base_comm = base_comm
self.base_comm_nccl = base_comm_nccl
self._row_comm = base_comm.Split(color=self._row_id, key=self._col_id)
self._col_comm = base_comm.Split(color=self._col_id, key=self._row_id)
if base_comm_nccl is not None:
mask_row = [r // self._P_prime for r in range(size)]
mask_col = [r % self._P_prime for r in range(size)]
self._row_comm_nccl = subcomm_split(mask_row, base_comm_nccl)
self._col_comm_nccl = subcomm_split(mask_col, base_comm_nccl)
else:
self._row_comm_nccl = None
self._col_comm_nccl = None

self.A = A.astype(np.dtype(dtype))
if saveAt:
Expand All @@ -316,7 +331,7 @@ def __init__(
self._col_end = min(self.M, self._col_start + block_cols)

self._local_ncols = max(0, self._col_end - self._col_start)
self._rank_col_lens = self.base_comm.allgather(self._local_ncols)
self._rank_col_lens = self._allgather(self.base_comm, None, self._local_ncols)
total_ncols = np.sum(self._rank_col_lens)

self.dims = (self.K, total_ncols)
Expand All @@ -336,17 +351,21 @@ def _matvec(self, x: DistributedArray) -> DistributedArray:
partition=Partition.SCATTER,
engine=x.engine,
dtype=output_dtype,
base_comm=self.base_comm
base_comm=x.base_comm,
base_comm_nccl=x.base_comm_nccl
)

my_own_cols = self._rank_col_lens[self.rank]
x_arr = x.local_array.reshape((self.dims[0], my_own_cols))
X_local = x_arr.astype(output_dtype)
Y_local = ncp.vstack(
self._row_comm.allgather(
ncp.matmul(self.A, X_local)
)
row_comm_nccl = self._row_comm_nccl if x.engine == "cupy" else None
Y_tiles = self._allgather(
self._row_comm,
row_comm_nccl,
ncp.matmul(self.A, X_local),
engine=x.engine,
)
Y_local = ncp.vstack(Y_tiles)
y[:] = Y_local.flatten()
return y

Expand Down Expand Up @@ -374,19 +393,27 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray:
partition=Partition.SCATTER,
engine=x.engine,
dtype=output_dtype,
base_comm=self.base_comm
base_comm=x.base_comm,
base_comm_nccl=x.base_comm_nccl
)

x_arr = x.local_array.reshape((self.N, self._local_ncols)).astype(output_dtype)
X_tile = x_arr[self._row_start:self._row_end, :]
A_local = self.At if hasattr(self, "At") else self.A.T.conj()
Y_local = ncp.matmul(A_local, X_tile)
y_layer = self._row_comm.allreduce(Y_local, op=MPI.SUM)
row_comm_nccl = self._row_comm_nccl if x.engine == "cupy" else None
y_layer = self._allreduce(
self._row_comm,
row_comm_nccl,
Y_local.ravel(),
op=MPI.SUM,
engine=x.engine,
).reshape(Y_local.shape)
y[:] = y_layer.flatten()
return y


class _MPISummaMatrixMult(MPILinearOperator):
class _MPISummaMatrixMult(DistributedMixIn, MPILinearOperator):
r"""MPI SUMMA Matrix multiplication

Implements distributed matrix-matrix multiplication using the SUMMA algorithm
Expand Down Expand Up @@ -512,7 +539,10 @@ def __init__(
saveAt: bool = False,
base_comm: MPI.Comm = MPI.COMM_WORLD,
dtype: DTypeLike = "float64",
base_comm_nccl: Optional[Any] = None,
) -> None:
if base_comm_nccl is not None and base_comm is not MPI.COMM_WORLD:
raise ValueError("base_comm_nccl requires base_comm=MPI.COMM_WORLD")
rank = base_comm.Get_rank()
size = base_comm.Get_size()

Expand All @@ -524,8 +554,17 @@ def __init__(
self._row_id, self._col_id = divmod(rank, self._P_prime)

self.base_comm = base_comm
self.base_comm_nccl = base_comm_nccl
self._row_comm = base_comm.Split(color=self._row_id, key=self._col_id)
self._col_comm = base_comm.Split(color=self._col_id, key=self._row_id)
if base_comm_nccl is not None:
mask_row = [r // self._P_prime for r in range(size)]
mask_col = [r % self._P_prime for r in range(size)]
self._row_comm_nccl = subcomm_split(mask_row, base_comm_nccl)
self._col_comm_nccl = subcomm_split(mask_col, base_comm_nccl)
else:
self._row_comm_nccl = None
self._col_comm_nccl = None

self.A = A.astype(np.dtype(dtype))

Expand Down Expand Up @@ -568,15 +607,16 @@ def _matvec(self, x: DistributedArray) -> DistributedArray:
local_n = bn if self._row_id != self._P_prime - 1 else self.N - (self._P_prime - 1) * bn
local_m = bm if self._col_id != self._P_prime - 1 else self.M - (self._P_prime - 1) * bm

local_shapes = self.base_comm.allgather(local_n * local_m)
local_shapes = self._allgather(self.base_comm, None, local_n * local_m)

y = DistributedArray(global_shape=(self.N * self.M),
mask=x.mask,
local_shapes=local_shapes,
partition=Partition.SCATTER,
engine=x.engine,
dtype=output_dtype,
base_comm=self.base_comm)
base_comm=x.base_comm,
base_comm_nccl=x.base_comm_nccl)

# Calculate expected padded dimensions for x
bk = self._K_padded // self._P_prime # block size in K dimension
Expand All @@ -600,8 +640,10 @@ def _matvec(self, x: DistributedArray) -> DistributedArray:
for k in range(self._P_prime):
Atemp = self.A.copy() if self._col_id == k else ncp.empty_like(self.A)
Xtemp = x_block.copy() if self._row_id == k else ncp.empty_like(x_block)
self._row_comm.Bcast(Atemp, root=k)
self._col_comm.Bcast(Xtemp, root=k)
row_comm_nccl = self._row_comm_nccl if x.engine == "cupy" else None
col_comm_nccl = self._col_comm_nccl if x.engine == "cupy" else None
Atemp = self._bcast(self._row_comm, row_comm_nccl, Atemp, root=k, engine=x.engine)
Xtemp = self._bcast(self._col_comm, col_comm_nccl, Xtemp, root=k, engine=x.engine)
Y_local += ncp.dot(Atemp, Xtemp)

Y_local_unpadded = Y_local[:local_n, :local_m]
Expand All @@ -622,7 +664,7 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray:
local_k = bk if self._row_id != self._P_prime - 1 else self.K - (self._P_prime - 1) * bk
local_m = bm if self._col_id != self._P_prime - 1 else self.M - (self._P_prime - 1) * bm

local_shapes = self.base_comm.allgather(local_k * local_m)
local_shapes = self._allgather(self.base_comm, None, local_k * local_m)
# - If A is real: A^H = A^T,
# so result_type(real_A, x.dtype) = x.dtype (if x is complex) or real (if x is real)
# - If A is complex: A^H is complex,
Expand All @@ -642,7 +684,8 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray:
partition=Partition.SCATTER,
engine=x.engine,
dtype=output_dtype,
base_comm=self.base_comm
base_comm=x.base_comm,
base_comm_nccl=x.base_comm_nccl
)

# Calculate expected padded dimensions for x
Expand All @@ -664,22 +707,28 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray:

A_local = self.At if hasattr(self, "At") else self.A.T.conj()
Y_local = ncp.zeros((self.A.shape[1], bm), dtype=output_dtype)

base_comm_nccl = self.base_comm_nccl if x.engine == "cupy" else None
for k in range(self._P_prime):
requests = []
ATtemp = ncp.empty_like(A_local)
srcA = k * self._P_prime + self._row_id
tagA = (100 + k) * 1000 + self.rank
requests.append(self.base_comm.Irecv(ATtemp, source=srcA, tag=tagA))
if self._row_id == k:
fixed_col = self._col_id
for moving_col in range(self._P_prime):
destA = fixed_col * self._P_prime + moving_col
tagA = (100 + k) * 1000 + destA
requests.append(self.base_comm.Isend(A_local, dest=destA, tag=tagA))
Xtemp = x_block.copy() if self._row_id == k else ncp.empty_like(x_block)
requests.append(self._col_comm.Ibcast(Xtemp, root=k))
MPI.Request.Waitall(requests)
col_comm_nccl = self._col_comm_nccl if x.engine == "cupy" else None
Xtemp = self._bcast(self._col_comm, col_comm_nccl, Xtemp, root=k, engine=x.engine)

ATtemp = None
srcA = k * self._P_prime + self._row_id
if srcA == self.rank:
ATtemp = A_local
for moving_col in range(self._P_prime):
if self._row_id == k:
destA = self._col_id * self._P_prime + moving_col
if destA != self.rank:
tagA = (100 + k) * 1000 + destA
self._send(self.base_comm, base_comm_nccl, A_local,
dest=destA, tag=tagA, engine=x.engine)
if self._col_id == moving_col and ATtemp is None:
tagA = (100 + k) * 1000 + self.rank
recv_buf = ncp.empty_like(A_local)
ATtemp = self._recv(self.base_comm, base_comm_nccl, recv_buf,
source=srcA, tag=tagA, engine=x.engine)
Y_local += ncp.dot(ATtemp, Xtemp)

Y_local_unpadded = Y_local[:local_k, :local_m]
Expand All @@ -693,7 +742,8 @@ def MPIMatrixMult(
saveAt: bool = False,
base_comm: MPI.Comm = MPI.COMM_WORLD,
kind: Literal["summa", "block"] = "summa",
dtype: DTypeLike = "float64"):
dtype: DTypeLike = "float64",
base_comm_nccl: Optional[Any] = None):
r"""
MPI Distributed Matrix Multiplication Operator

Expand All @@ -714,6 +764,8 @@ def MPIMatrixMult(
memory). Default is ``False``.
base_comm : :obj:`mpi4py.MPI.Comm`, optional
MPI communicator to use. Defaults to ``MPI.COMM_WORLD``.
base_comm_nccl : :obj:`cupy.cuda.nccl.NcclCommunicator`, optional
NCCL communicator to use when operating on ``cupy`` arrays.
kind : :obj:`str`, optional
Algorithm used to perform matrix multiplication: ``'block'`` for
block-row-column decomposition, and ``'summa'`` for SUMMA algorithm, or
Expand Down Expand Up @@ -784,8 +836,8 @@ def MPIMatrixMult(

"""
if kind == "summa":
return _MPISummaMatrixMult(A, M, saveAt, base_comm, dtype)
return _MPISummaMatrixMult(A, M, saveAt, base_comm, dtype, base_comm_nccl)
elif kind == "block":
return _MPIBlockMatrixMult(A, M, saveAt, base_comm, dtype)
return _MPIBlockMatrixMult(A, M, saveAt, base_comm, dtype, base_comm_nccl)
else:
raise NotImplementedError("kind must be summa or block")
Loading
Loading