diff --git a/pylops_mpi/Distributed.py b/pylops_mpi/Distributed.py index 3f1cf068..905530c5 100644 --- a/pylops_mpi/Distributed.py +++ b/pylops_mpi/Distributed.py @@ -1,4 +1,4 @@ -from typing import Any, NewType, Optional, Union +from typing import Any, NewType, Optional from mpi4py import MPI from pylops.utils import NDArray @@ -194,13 +194,11 @@ def _allgather_subcomm(self, def _bcast(self, base_comm: MPI.Comm, base_comm_nccl: NcclCommunicatorType, - rank : int, - local_array: NDArray, - index: int, - value: Union[int, NDArray], + send_buf: NDArray, + root: int = 0, engine: str = "numpy", - ) -> None: - """BCast operation + ) -> NDArray: + """Broadcast operation Parameters ---------- @@ -208,24 +206,22 @@ def _bcast(self, Base MPI Communicator. base_comm_nccl : :obj:`cupy.cuda.nccl.NcclCommunicator` NCCL Communicator. - rank : :obj:`int` - Rank. - local_array : :obj:`numpy.ndarray` - Localy array to be broadcasted. - index : :obj:`int` or :obj:`slice` - Represents the index positions where a value needs to be assigned. - value : :obj:`int` or :obj:`numpy.ndarray` - Represents the value that will be assigned to the local array at - the specified index positions. + send_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray` + A buffer containing the data to be broadcast from the root rank. + root : :obj:`int`, optional + The rank of the process that holds the source data. engine : :obj:`str`, optional Engine used to store array (``numpy`` or ``cupy``) + Returns + ------- + send_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray` + The buffer containing the broadcasted data. """ if deps.nccl_enabled and base_comm_nccl is not None: - nccl_bcast(base_comm_nccl, local_array, index, value) - else: - mpi_bcast(base_comm, rank, local_array, index, value, - engine=engine) + nccl_bcast(base_comm_nccl, send_buf, root=root) + return send_buf + return mpi_bcast(base_comm, send_buf, root=root, engine=engine) def _send(self, base_comm: MPI.Comm, diff --git a/pylops_mpi/DistributedArray.py b/pylops_mpi/DistributedArray.py index f82ff0ab..9e71bf8e 100644 --- a/pylops_mpi/DistributedArray.py +++ b/pylops_mpi/DistributedArray.py @@ -204,9 +204,12 @@ def __setitem__(self, index, value): the specified index positions. """ if self.partition is Partition.BROADCAST: - self._bcast(self.base_comm, self.base_comm_nccl, - self.rank, self.local_array, - index, value, engine=self.engine) + view = self.local_array[index] + if self.rank == 0: + view[...] = value + view = self._bcast(self.base_comm, self.base_comm_nccl, + view, root=0, engine=self.engine) + self.local_array[index] = view else: self.local_array[index] = value diff --git a/pylops_mpi/basicoperators/MatrixMult.py b/pylops_mpi/basicoperators/MatrixMult.py index 9487c3ce..5fa08cba 100644 --- a/pylops_mpi/basicoperators/MatrixMult.py +++ b/pylops_mpi/basicoperators/MatrixMult.py @@ -7,7 +7,7 @@ import math import numpy as np from mpi4py import MPI -from typing import Tuple, Literal +from typing import Tuple, Literal, Optional, Any from pylops.utils.backend import get_module from pylops.utils.typing import DTypeLike, NDArray @@ -17,6 +17,8 @@ MPILinearOperator, Partition ) +from pylops_mpi.Distributed import DistributedMixIn +from pylops_mpi.DistributedArray import subcomm_split def active_grid_comm(base_comm: MPI.Comm, N: int, M: int): @@ -159,7 +161,8 @@ def block_gather(x: DistributedArray, orig_shape: Tuple[int, int], comm: MPI.Com if p_prime * p_prime != comm.Get_size(): raise RuntimeError(f"Communicator size must be a perfect square, got {comm.Get_size()!r}") - all_blks = comm.allgather(x.local_array) + comm_nccl = x.base_comm_nccl if comm == x.base_comm else None + all_blks = x._allgather(comm, comm_nccl, x.local_array, engine=x.engine) nr, nc = orig_shape br, bc = math.ceil(nr / p_prime), math.ceil(nc / p_prime) C = ncp.zeros((nr, nc), dtype=all_blks[0].dtype) @@ -172,7 +175,7 @@ def block_gather(x: DistributedArray, orig_shape: Tuple[int, int], comm: MPI.Com return C -class _MPIBlockMatrixMult(MPILinearOperator): +class _MPIBlockMatrixMult(DistributedMixIn, MPILinearOperator): r"""MPI Blocked Matrix multiplication Implement distributed matrix-matrix multiplication between a matrix @@ -281,7 +284,10 @@ def __init__( saveAt: bool = False, base_comm: MPI.Comm = MPI.COMM_WORLD, dtype: DTypeLike = "float64", + base_comm_nccl: Optional[Any] = None, ) -> None: + if base_comm_nccl is not None and base_comm is not MPI.COMM_WORLD: + raise ValueError("base_comm_nccl requires base_comm=MPI.COMM_WORLD") rank = base_comm.Get_rank() size = base_comm.Get_size() @@ -295,8 +301,17 @@ def __init__( self._row_id = rank // self._P_prime self.base_comm = base_comm + self.base_comm_nccl = base_comm_nccl self._row_comm = base_comm.Split(color=self._row_id, key=self._col_id) self._col_comm = base_comm.Split(color=self._col_id, key=self._row_id) + if base_comm_nccl is not None: + mask_row = [r // self._P_prime for r in range(size)] + mask_col = [r % self._P_prime for r in range(size)] + self._row_comm_nccl = subcomm_split(mask_row, base_comm_nccl) + self._col_comm_nccl = subcomm_split(mask_col, base_comm_nccl) + else: + self._row_comm_nccl = None + self._col_comm_nccl = None self.A = A.astype(np.dtype(dtype)) if saveAt: @@ -316,7 +331,7 @@ def __init__( self._col_end = min(self.M, self._col_start + block_cols) self._local_ncols = max(0, self._col_end - self._col_start) - self._rank_col_lens = self.base_comm.allgather(self._local_ncols) + self._rank_col_lens = self._allgather(self.base_comm, None, self._local_ncols) total_ncols = np.sum(self._rank_col_lens) self.dims = (self.K, total_ncols) @@ -336,17 +351,21 @@ def _matvec(self, x: DistributedArray) -> DistributedArray: partition=Partition.SCATTER, engine=x.engine, dtype=output_dtype, - base_comm=self.base_comm + base_comm=x.base_comm, + base_comm_nccl=x.base_comm_nccl ) my_own_cols = self._rank_col_lens[self.rank] x_arr = x.local_array.reshape((self.dims[0], my_own_cols)) X_local = x_arr.astype(output_dtype) - Y_local = ncp.vstack( - self._row_comm.allgather( - ncp.matmul(self.A, X_local) - ) + row_comm_nccl = self._row_comm_nccl if x.engine == "cupy" else None + Y_tiles = self._allgather( + self._row_comm, + row_comm_nccl, + ncp.matmul(self.A, X_local), + engine=x.engine, ) + Y_local = ncp.vstack(Y_tiles) y[:] = Y_local.flatten() return y @@ -374,19 +393,27 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray: partition=Partition.SCATTER, engine=x.engine, dtype=output_dtype, - base_comm=self.base_comm + base_comm=x.base_comm, + base_comm_nccl=x.base_comm_nccl ) x_arr = x.local_array.reshape((self.N, self._local_ncols)).astype(output_dtype) X_tile = x_arr[self._row_start:self._row_end, :] A_local = self.At if hasattr(self, "At") else self.A.T.conj() Y_local = ncp.matmul(A_local, X_tile) - y_layer = self._row_comm.allreduce(Y_local, op=MPI.SUM) + row_comm_nccl = self._row_comm_nccl if x.engine == "cupy" else None + y_layer = self._allreduce( + self._row_comm, + row_comm_nccl, + Y_local.ravel(), + op=MPI.SUM, + engine=x.engine, + ).reshape(Y_local.shape) y[:] = y_layer.flatten() return y -class _MPISummaMatrixMult(MPILinearOperator): +class _MPISummaMatrixMult(DistributedMixIn, MPILinearOperator): r"""MPI SUMMA Matrix multiplication Implements distributed matrix-matrix multiplication using the SUMMA algorithm @@ -512,7 +539,10 @@ def __init__( saveAt: bool = False, base_comm: MPI.Comm = MPI.COMM_WORLD, dtype: DTypeLike = "float64", + base_comm_nccl: Optional[Any] = None, ) -> None: + if base_comm_nccl is not None and base_comm is not MPI.COMM_WORLD: + raise ValueError("base_comm_nccl requires base_comm=MPI.COMM_WORLD") rank = base_comm.Get_rank() size = base_comm.Get_size() @@ -524,8 +554,17 @@ def __init__( self._row_id, self._col_id = divmod(rank, self._P_prime) self.base_comm = base_comm + self.base_comm_nccl = base_comm_nccl self._row_comm = base_comm.Split(color=self._row_id, key=self._col_id) self._col_comm = base_comm.Split(color=self._col_id, key=self._row_id) + if base_comm_nccl is not None: + mask_row = [r // self._P_prime for r in range(size)] + mask_col = [r % self._P_prime for r in range(size)] + self._row_comm_nccl = subcomm_split(mask_row, base_comm_nccl) + self._col_comm_nccl = subcomm_split(mask_col, base_comm_nccl) + else: + self._row_comm_nccl = None + self._col_comm_nccl = None self.A = A.astype(np.dtype(dtype)) @@ -568,7 +607,7 @@ def _matvec(self, x: DistributedArray) -> DistributedArray: local_n = bn if self._row_id != self._P_prime - 1 else self.N - (self._P_prime - 1) * bn local_m = bm if self._col_id != self._P_prime - 1 else self.M - (self._P_prime - 1) * bm - local_shapes = self.base_comm.allgather(local_n * local_m) + local_shapes = self._allgather(self.base_comm, None, local_n * local_m) y = DistributedArray(global_shape=(self.N * self.M), mask=x.mask, @@ -576,7 +615,8 @@ def _matvec(self, x: DistributedArray) -> DistributedArray: partition=Partition.SCATTER, engine=x.engine, dtype=output_dtype, - base_comm=self.base_comm) + base_comm=x.base_comm, + base_comm_nccl=x.base_comm_nccl) # Calculate expected padded dimensions for x bk = self._K_padded // self._P_prime # block size in K dimension @@ -600,8 +640,10 @@ def _matvec(self, x: DistributedArray) -> DistributedArray: for k in range(self._P_prime): Atemp = self.A.copy() if self._col_id == k else ncp.empty_like(self.A) Xtemp = x_block.copy() if self._row_id == k else ncp.empty_like(x_block) - self._row_comm.Bcast(Atemp, root=k) - self._col_comm.Bcast(Xtemp, root=k) + row_comm_nccl = self._row_comm_nccl if x.engine == "cupy" else None + col_comm_nccl = self._col_comm_nccl if x.engine == "cupy" else None + Atemp = self._bcast(self._row_comm, row_comm_nccl, Atemp, root=k, engine=x.engine) + Xtemp = self._bcast(self._col_comm, col_comm_nccl, Xtemp, root=k, engine=x.engine) Y_local += ncp.dot(Atemp, Xtemp) Y_local_unpadded = Y_local[:local_n, :local_m] @@ -622,7 +664,7 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray: local_k = bk if self._row_id != self._P_prime - 1 else self.K - (self._P_prime - 1) * bk local_m = bm if self._col_id != self._P_prime - 1 else self.M - (self._P_prime - 1) * bm - local_shapes = self.base_comm.allgather(local_k * local_m) + local_shapes = self._allgather(self.base_comm, None, local_k * local_m) # - If A is real: A^H = A^T, # so result_type(real_A, x.dtype) = x.dtype (if x is complex) or real (if x is real) # - If A is complex: A^H is complex, @@ -642,7 +684,8 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray: partition=Partition.SCATTER, engine=x.engine, dtype=output_dtype, - base_comm=self.base_comm + base_comm=x.base_comm, + base_comm_nccl=x.base_comm_nccl ) # Calculate expected padded dimensions for x @@ -664,22 +707,28 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray: A_local = self.At if hasattr(self, "At") else self.A.T.conj() Y_local = ncp.zeros((self.A.shape[1], bm), dtype=output_dtype) - + base_comm_nccl = self.base_comm_nccl if x.engine == "cupy" else None for k in range(self._P_prime): - requests = [] - ATtemp = ncp.empty_like(A_local) - srcA = k * self._P_prime + self._row_id - tagA = (100 + k) * 1000 + self.rank - requests.append(self.base_comm.Irecv(ATtemp, source=srcA, tag=tagA)) - if self._row_id == k: - fixed_col = self._col_id - for moving_col in range(self._P_prime): - destA = fixed_col * self._P_prime + moving_col - tagA = (100 + k) * 1000 + destA - requests.append(self.base_comm.Isend(A_local, dest=destA, tag=tagA)) Xtemp = x_block.copy() if self._row_id == k else ncp.empty_like(x_block) - requests.append(self._col_comm.Ibcast(Xtemp, root=k)) - MPI.Request.Waitall(requests) + col_comm_nccl = self._col_comm_nccl if x.engine == "cupy" else None + Xtemp = self._bcast(self._col_comm, col_comm_nccl, Xtemp, root=k, engine=x.engine) + + ATtemp = None + srcA = k * self._P_prime + self._row_id + if srcA == self.rank: + ATtemp = A_local + for moving_col in range(self._P_prime): + if self._row_id == k: + destA = self._col_id * self._P_prime + moving_col + if destA != self.rank: + tagA = (100 + k) * 1000 + destA + self._send(self.base_comm, base_comm_nccl, A_local, + dest=destA, tag=tagA, engine=x.engine) + if self._col_id == moving_col and ATtemp is None: + tagA = (100 + k) * 1000 + self.rank + recv_buf = ncp.empty_like(A_local) + ATtemp = self._recv(self.base_comm, base_comm_nccl, recv_buf, + source=srcA, tag=tagA, engine=x.engine) Y_local += ncp.dot(ATtemp, Xtemp) Y_local_unpadded = Y_local[:local_k, :local_m] @@ -693,7 +742,8 @@ def MPIMatrixMult( saveAt: bool = False, base_comm: MPI.Comm = MPI.COMM_WORLD, kind: Literal["summa", "block"] = "summa", - dtype: DTypeLike = "float64"): + dtype: DTypeLike = "float64", + base_comm_nccl: Optional[Any] = None): r""" MPI Distributed Matrix Multiplication Operator @@ -714,6 +764,8 @@ def MPIMatrixMult( memory). Default is ``False``. base_comm : :obj:`mpi4py.MPI.Comm`, optional MPI communicator to use. Defaults to ``MPI.COMM_WORLD``. + base_comm_nccl : :obj:`cupy.cuda.nccl.NcclCommunicator`, optional + NCCL communicator to use when operating on ``cupy`` arrays. kind : :obj:`str`, optional Algorithm used to perform matrix multiplication: ``'block'`` for # block-row-column decomposition, and ``'summa'`` for SUMMA algorithm, or @@ -784,8 +836,8 @@ def MPIMatrixMult( """ if kind == "summa": - return _MPISummaMatrixMult(A, M, saveAt, base_comm, dtype) + return _MPISummaMatrixMult(A, M, saveAt, base_comm, dtype, base_comm_nccl) elif kind == "block": - return _MPIBlockMatrixMult(A, M, saveAt, base_comm, dtype) + return _MPIBlockMatrixMult(A, M, saveAt, base_comm, dtype, base_comm_nccl) else: raise NotImplementedError("kind must be summa or block") diff --git a/pylops_mpi/signalprocessing/Fredholm1.py b/pylops_mpi/signalprocessing/Fredholm1.py index 2969e3c9..2ab7e65d 100644 --- a/pylops_mpi/signalprocessing/Fredholm1.py +++ b/pylops_mpi/signalprocessing/Fredholm1.py @@ -1,6 +1,8 @@ +import math import numpy as np from mpi4py import MPI +from typing import Optional, Any, Tuple from pylops.utils.backend import get_module from pylops.utils.typing import DTypeLike, NDArray @@ -9,8 +11,373 @@ MPILinearOperator, Partition ) +from pylops_mpi.Distributed import DistributedMixIn +from pylops_mpi.DistributedArray import subcomm_split +def _choose_pb_and_p(P: int, nsl: int) -> Tuple[int, int]: + """ + Choose Pb to minimize the α-β model under constraint P/Pb is a perfect square. + Heuristic: largest Pb <= nsl such that P % Pb == 0 and is_square(P/Pb). + """ + best = None + for Pb in range(min(P, nsl), 0, -1): + if P % Pb != 0: continue + P2 = P // Pb + p = int(math.isqrt(P2)) + if p * p == P2: + best = (Pb, p) + break + if best is None: + raise ValueError( + f"No valid (Pb,p) with Pb<=nsl and P/Pb square. P={P}, nsl={nsl}." + ) + return best + + +class MPIFredholm1SUMMA(DistributedMixIn, MPILinearOperator): + """ + Distributed Fredholm-1 using batched SUMMA on contraction: + d[k,:,:] = G[k,:,:] @ m[k,:,:] + + G is distributed as tiles (batch, x_tile, y_tile) over (batch_group, grid_row, grid_col). + m is distributed as tiles (batch, y_tile, z_tile) over (batch_group, grid_row, grid_col). + d is distributed as tiles (batch, x_tile, z_tile) over (batch_group, grid_row, grid_col). + + This operator uses Partition.SCATTER for both input and output. + + Parameters + ---------- + G_local : ndarray + Local tile of G of shape (B_g, nx_loc, ny_loc) for this rank. + nz : int + Global nz dimension. + nsl_global : int, optional + Global number of slices. If None, inferred from batch sizes across ranks. + saveGt : bool, optional + Save local conjugate-transpose of G tile for adjoint. + pb : int, optional + Number of batch groups. If None, auto-chosen. + base_comm : MPI.Comm + base_comm_nccl : optional NCCL comm (only if base_comm == COMM_WORLD) + dtype : str + """ + def __init__( + self, + G_local: NDArray, + nz: int, + nsl_global: Optional[int] = None, + saveGt: bool = False, + pb: Optional[int] = None, + base_comm: MPI.Comm = MPI.COMM_WORLD, + base_comm_nccl: Optional[Any] = None, + dtype: DTypeLike = "float64", + ) -> None: + if base_comm_nccl is not None and base_comm is not MPI.COMM_WORLD: + raise ValueError("base_comm_nccl requires base_comm=MPI.COMM_WORLD") + + self.base_comm = base_comm + self.base_comm_nccl = base_comm_nccl + self.rank = base_comm.Get_rank() + self.size = base_comm.Get_size() + + # Local batch size + if G_local.ndim != 3: raise ValueError(f"G_local must be 3D (B,nx_loc,ny_loc). Got {G_local.shape}") + self.B = int(G_local.shape[0]) + self.nz = int(nz) + + # Determine batch-grouping (Pb) and inner SUMMA grid size (p) + # Need nsl_global for optimal choice; if not provided, use a conservative estimate from all ranks: + if nsl_global is None: + # Infer: sum(B_rank) over all ranks = p^2 * nsl_global (only true after we pick p) + # So we first pick Pb,p using nsl_est = sum(B_rank)/min_square_factor_guess + # Practical approach: assume Pb=1 initially, require P square, then compute nsl_global + # If user doesn't provide nsl_global, we do *no* auto-optimization; pb must be given or P must be square + if pb is None: + p0 = int(math.isqrt(self.size)) + if p0 * p0 != self.size: + raise ValueError( + "If nsl_global is not provided, pb must be provided, " + "or P must be a perfect square (so pb=1 is valid)." + ) + pb = 1 + + if pb is None: + pb, p = _choose_pb_and_p(self.size, int(nsl_global)) + else: + # For now we error but we could do something like where we would deactivate certain procs + if self.size % pb != 0: + raise ValueError(f"pb must divide P. Got pb={pb}, P={self.size}.") + P2 = self.size // pb + p = int(math.isqrt(P2)) + if p * p != P2: + raise ValueError(f"P/pb must be a perfect square. Got P/pb={P2}.") + if nsl_global is not None and pb > nsl_global: + raise ValueError(f"pb must be <= nsl_global. Got pb={pb}, nsl_global={nsl_global}.") + + self.pb = int(pb) + self.p = int(p) + self.P2 = self.p * self.p + + # Batch-group id and rank within group + self.batch_id = self.rank // self.P2 + self.rank_in_group = self.rank % self.P2 + + if self.batch_id >= self.pb: + raise ValueError( + f"Rank mapping expects P == pb*p^2. " + f"Got P={self.size}, pb={self.pb}, p={self.p} => pb*p^2={self.pb*self.P2}." + ) + + # Create batch communicator + self.batch_comm = base_comm.Split(color=self.batch_id, key=self.rank_in_group) + + # Within group, 2D grid coords + self.row_id, self.col_id = divmod(self.rank_in_group, self.p) + + # Row/col communicators (within group) + self.row_comm = self.batch_comm.Split(color=self.row_id, key=self.col_id) + self.col_comm = self.batch_comm.Split(color=self.col_id, key=self.row_id) + + # # NCCL subcomms if provided + if base_comm_nccl is not None: + # subcomm_split expects mask per WORLD rank + # batch_comm: group by batch_id + mask_batch = [r // self.P2 for r in range(self.size)] + self.batch_comm_nccl = subcomm_split(mask_batch, base_comm_nccl) + + # row_comm: group by (batch_id,row_id) + mask_row = [] + mask_col = [] + for r in range(self.size): + bid = r // self.P2 + rig = r % self.P2 + rr, cc = divmod(rig, self.p) + mask_row.append(bid * self.p + rr) + mask_col.append(bid * self.p + cc) + self.row_comm_nccl = subcomm_split(mask_row, base_comm_nccl) + self.col_comm_nccl = subcomm_split(mask_col, base_comm_nccl) + else: + self.batch_comm_nccl = None + self.row_comm_nccl = None + self.col_comm_nccl = None + + # Store G tile and optional GT + self.G = G_local.astype(np.dtype(dtype)) + if saveGt: + self.GT = self.G.transpose(0, 2, 1).conj() # (B, nx_loc, ny_loc) -> (B, ny_loc, nx_loc) + + # Infer global nx, ny from within-group tiling + # A tile: (nx_loc, ny_loc) where nx is reduced on col_comm, ny on row_comm + nx_loc = self.G.shape[1] + ny_loc = self.G.shape[2] + self.nx = int(self.col_comm.allreduce(nx_loc, op=MPI.SUM)) + self.ny = int(self.row_comm.allreduce(ny_loc, op=MPI.SUM)) + + # Determine global nsl + if nsl_global is None: + # sum B over WORLD ranks = p^2 * sum(B over batch groups) = p^2 * nsl_global + Bsum = int(self.base_comm.allreduce(self.B, op=MPI.SUM)) + if Bsum % self.P2 != 0: + raise ValueError( + f"Cannot infer nsl_global cleanly: sum(B)={Bsum} not divisible by p^2={self.P2}." + ) + self.nsl = Bsum // self.P2 + else: + self.nsl = int(nsl_global) + + # Padding sizes for SUMMA blocks + self.nx_pad = math.ceil(self.nx / self.p) * self.p + self.ny_pad = math.ceil(self.ny / self.p) * self.p + self.nz_pad = math.ceil(self.nz / self.p) * self.p + + self.bn = self.nx_pad // self.p + self.bk = self.ny_pad // self.p + self.bm = self.nz_pad // self.p + + # Local (unpadded) extents for this rank’s output tile (x,z) and input tile (y,z) + self.local_n = max(0, min(self.bn, self.nx - self.row_id * self.bn)) + self.local_k = max(0, min(self.bk, self.ny - self.row_id * self.bk)) # for m (K rows) uses row_id + self.local_ka = max(0, min(self.bk, self.ny - self.col_id * self.bk)) # for G (K cols) uses col_id + self.local_m = max(0, min(self.bm, self.nz - self.col_id * self.bm)) + + # Operator global shapes (conceptual / unpadded) + self.dims_model = (self.nsl, self.ny, self.nz) + self.dims_data = (self.nsl, self.nx, self.nz) + shape = (int(np.prod(self.dims_data)), int(np.prod(self.dims_model))) + super().__init__(shape=shape, dtype=np.dtype(dtype), base_comm=base_comm) + + # Ensure local G matches expected tile sizes in (nx_loc, ny_loc) for A distribution + # We allow edge tiles to be smaller since we will pad later + if self.G.shape[1] != self.local_n or self.G.shape[2] != self.local_ka: + # Not necessarily fatal if user pre-padded; allow larger, but disallow mismatch that breaks slicing + if self.G.shape[1] < self.local_n or self.G.shape[2] < self.local_ka: + raise ValueError( + f"G_local tile too small for this rank. " + f"Expected at least ({self.B},{self.local_n},{self.local_ka}), got {self.G.shape}." + ) + + def _matvec(self, x: DistributedArray) -> DistributedArray: + ncp = get_module(x.engine) + if x.partition != Partition.SCATTER: + raise ValueError(f"x should have partition={Partition.SCATTER}. Got {x.partition} instead.") + + # Input local tile expected shape: (B, local_k (by row_id), local_m (by col_id)) + expected_in = self.B * self.local_k * self.local_m + if x.local_array.size != expected_in: + raise ValueError( + f"Local x size mismatch. Expected {expected_in} elements " + f"(B={self.B}, local_k={self.local_k}, local_m={self.local_m}), " + f"got {x.local_array.size}." + ) + + output_dtype = np.result_type(self.dtype, x.dtype) + + # Output local shapes for SCATTER vector + my_out = self.B * self.local_n * self.local_m + local_shapes = self.base_comm.allgather(my_out) + + y = DistributedArray( + global_shape=int(np.prod(self.dims_data)), + local_shapes=local_shapes, + mask=x.mask, + partition=Partition.SCATTER, + engine=x.engine, + dtype=output_dtype, + base_comm=x.base_comm, + base_comm_nccl=x.base_comm_nccl, + ) + + # Reshape local x tile and pad to (B, bk, bm) + X = x.local_array.reshape((self.B, self.local_k, self.local_m)).astype(output_dtype) + if self.local_k != self.bk or self.local_m != self.bm: + X_padded = ncp.zeros((self.B, self.bk, self.bm), dtype=output_dtype) + X_padded[:, :self.local_k, :self.local_m] = X + X = X_padded + + # Pad local G tile to (B, bn, bk) for SUMMA A tiles + G = self.G[:, :self.local_n, :self.local_ka].astype(output_dtype) + if self.local_n != self.bn or self.local_ka != self.bk: + G_padded = ncp.zeros((self.B, self.bn, self.bk), dtype=output_dtype) + G_padded[:, :self.local_n, :self.local_ka] = G + G = G_padded + + Y = ncp.zeros((self.B, self.bn, self.bm), dtype=output_dtype) + + row_nccl = self.row_comm_nccl if x.engine == "cupy" else None + col_nccl = self.col_comm_nccl if x.engine == "cupy" else None + + # Batched SUMMA + for k in range(self.p): + Atemp = G.copy() if self.col_id == k else ncp.empty_like(G) + Btemp = X.copy() if self.row_id == k else ncp.empty_like(X) + + Atemp = self._bcast(self.row_comm, row_nccl, Atemp, root=k, engine=x.engine) + Btemp = self._bcast(self.col_comm, col_nccl, Btemp, root=k, engine=x.engine) + + Y += ncp.matmul(Atemp, Btemp) + + Y = Y[:, :self.local_n, :self.local_m] # Unpad to local (B, local_n, local_m) and write out + y[:] = Y.ravel() + return y + + def _rmatvec(self, x: DistributedArray) -> DistributedArray: + ncp = get_module(x.engine) + if x.partition != Partition.SCATTER: + raise ValueError(f"x should have partition={Partition.SCATTER}. Got {x.partition} instead.") + + # Input to adjoint is data tile: (B, local_n, local_m) + expected_in = self.B * self.local_n * self.local_m + if x.local_array.size != expected_in: + raise ValueError( + f"Local x size mismatch for adjoint. Expected {expected_in} elements " + f"(B={self.B}, local_n={self.local_n}, local_m={self.local_m}), got {x.local_array.size}." + ) + + # Output dtype rules similar to your matrix-mult operators + if np.iscomplexobj(self.G): + output_dtype = np.result_type(self.dtype, x.dtype) + else: + output_dtype = x.dtype if np.iscomplexobj(x.local_array) else self.dtype + output_dtype = np.result_type(self.dtype, output_dtype) + + # Output local shapes for SCATTER model vector + my_out = self.B * self.local_k * self.local_m # (B, local_k(row_id), local_m(col_id)) + local_shapes = self.base_comm.allgather(my_out) + + y = DistributedArray( + global_shape=int(np.prod(self.dims_model)), + local_shapes=local_shapes, + mask=x.mask, + partition=Partition.SCATTER, + engine=x.engine, + dtype=output_dtype, + base_comm=x.base_comm, + base_comm_nccl=x.base_comm_nccl, + ) + + # Reshape x tile and pad to (B, bn, bm) + X = x.local_array.reshape((self.B, self.local_n, self.local_m)).astype(output_dtype) + if self.local_n != self.bn or self.local_m != self.bm: + X_padded = ncp.zeros((self.B, self.bn, self.bm), dtype=output_dtype) + X_padded[:, :self.local_n, :self.local_m] = X + X = X_padded + + # Local A^H tile (transpose-conj of A tile): (B, bk, bn) + if hasattr(self, "GT"): + AT_local = self.GT[:, :self.local_ka, :self.local_n].astype(output_dtype) + else: + AT_local = self.G[:, :self.local_n, :self.local_ka].transpose(0, 2, 1).conj().astype(output_dtype) + + if self.local_ka != self.bk or self.local_n != self.bn: + AT_padded = ncp.zeros((self.B, self.bk, self.bn), dtype=output_dtype) + AT_padded[:, :self.local_ka, :self.local_n] = AT_local + AT_local = AT_padded + AT_local = ncp.ascontiguousarray(AT_local) + + Y = ncp.zeros((self.B, self.bk, self.bm), dtype=output_dtype) + + base_nccl = self.base_comm_nccl if x.engine == "cupy" else None + col_nccl = self.col_comm_nccl if x.engine == "cupy" else None + + # Batched adjoint SUMMA variant matching your existing _MPISummaMatrixMult._rmatvec: + # - broadcast X panels down col_comm + # - move AT blocks across WORLD ranks to emulate transposed distribution + for k in range(self.p): + Xtemp = X.copy() if self.row_id == k else ncp.empty_like(X) + Xtemp = self._bcast(self.col_comm, col_nccl, Xtemp, root=k, engine=x.engine) + + # Determine source rank for A^T block needed this iteration + # WORLD rank mapping inside batch group: + # world_rank = batch_id*P2 + (row*p + col) + # Need AT from srcA = (row=k, col=row_id) within this batch group: + srcA_in_group = k * self.p + self.row_id + srcA = self.batch_id * self.P2 + srcA_in_group + + ATtemp = AT_local if (self.rank == srcA) else None + + # Send from ranks with row_id==k (within group) to row=col_id targets, across all columns (within group), + # using WORLD communicator for explicit point-to-point + for moving_col in range(self.p): + if self.row_id == k: + # sender is (row=k, col=self.col_id) + dest_in_group = self.col_id * self.p + moving_col + destA = self.batch_id * self.P2 + dest_in_group + if destA != self.rank: + tagA = (100 + k) * 100000 + destA + self._send(self.base_comm, base_nccl, AT_local, dest=destA, tag=tagA, engine=x.engine) + + if self.col_id == moving_col and ATtemp is None: + tagA = (100 + k) * 100000 + self.rank + recv_buf = ncp.empty_like(AT_local) + ATtemp = self._recv(self.base_comm, base_nccl, recv_buf, source=srcA, tag=tagA, engine=x.engine) + + Y += ncp.matmul(ATtemp, Xtemp) + + Y = Y[:, :self.local_k, :self.local_m] # Unpad output to (B, local_k(row_id), local_m) + y[:] = Y.ravel() + return y + class MPIFredholm1(MPILinearOperator): r"""Fredholm integral of first kind. @@ -166,6 +533,5 @@ def _rmatvec(self, x: NDArray) -> NDArray: y1[isl] = ncp.dot(x[isl].T.conj(), self.G[isl]).T.conj() # gather results - y[:] = ncp.vstack(y._allgather(y.base_comm, y.base_comm_nccl, y1, - engine=y.engine)).ravel() + y[:] = ncp.vstack(y._allgather(y.base_comm, y.base_comm_nccl, y1, engine=y.engine)).ravel() return y diff --git a/pylops_mpi/utils/_mpi.py b/pylops_mpi/utils/_mpi.py index d0c8c73f..db2590c1 100644 --- a/pylops_mpi/utils/_mpi.py +++ b/pylops_mpi/utils/_mpi.py @@ -6,7 +6,7 @@ "mpi_recv", ] -from typing import Optional, Union +from typing import Optional from mpi4py import MPI from pylops.utils import NDArray @@ -104,41 +104,38 @@ def mpi_allreduce(base_comm: MPI.Comm, def mpi_bcast(base_comm: MPI.Comm, - rank: int, - local_array: NDArray, - index: int, - value: Union[int, NDArray], + send_buf: NDArray, + root: int = 0, engine: Optional[str] = "numpy", - ) -> None: + ) -> NDArray: """MPI_Bcast/bcast Dispatch bcast routine based on type of input and availability of - CUDA-Aware MPI + CUDA-Aware MPI. Parameters ---------- base_comm : :obj:`MPI.Comm` Base MPI Communicator. - rank : :obj:`int` - Rank. - local_array : :obj:`numpy.ndarray` - Localy array to be broadcasted. - index : :obj:`int` or :obj:`slice` - Represents the index positions where a value needs to be assigned. - value : :obj:`int` or :obj:`numpy.ndarray` - Represents the value that will be assigned to the local array at - the specified index positions. + send_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray` + The data buffer to be broadcasted to the other ranks from the broadcasting root rank. + root : :obj:`int`, optional + The rank of the broadcasting process. engine : :obj:`str`, optional Engine used to store array (``numpy`` or ``cupy``) + Returns + ------- + send_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray` + The buffer containing the broadcasted data. + """ if deps.cuda_aware_mpi_enabled or engine == "numpy": - if rank == 0: - local_array[index] = value - base_comm.Bcast(local_array[index]) - else: - # CuPy with non-CUDA-aware MPI - local_array[index] = base_comm.bcast(value) + base_comm.Bcast(send_buf, root=root) + return send_buf + # CuPy with non-CUDA-aware MPI: use object broadcast + value = send_buf if base_comm.Get_rank() == root else None + return base_comm.bcast(value, root=root) def mpi_send(base_comm: MPI.Comm, diff --git a/pylops_mpi/utils/_nccl.py b/pylops_mpi/utils/_nccl.py index 3ad5b022..a57ce8bf 100644 --- a/pylops_mpi/utils/_nccl.py +++ b/pylops_mpi/utils/_nccl.py @@ -238,28 +238,23 @@ def nccl_allreduce(nccl_comm, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM) -> return recv_buf -def nccl_bcast(nccl_comm, local_array, index, value) -> None: - """ NCCL equivalent of MPI_Bcast. Broadcasts a single value at the given index - from the root GPU (rank 0) to all other GPUs. - +def nccl_bcast(nccl_comm, send_buf, root: int = 0) -> None: + """NCCL equivalent of MPI_Bcast for an array buffer. Parameters ---------- nccl_comm : :obj:`cupy.cuda.nccl.NcclCommunicator` The NCCL communicator used for collective communication. - local_array : :obj:`cupy.ndarray` - The local array on each GPU. The value at `index` will be broadcasted. - index : :obj:`int` - The index in the array to be broadcasted. - value : :obj:`scalar` - The value to broadcast (only used by the root GPU, rank 0). + send_buf: :obj:`numpy.ndarray` or :obj:`cupy.ndarray` + The data buffer to be broadcasted to the other ranks from the broadcasting root rank. + root: :obj:`int` + The rank of the broadcasting process. """ - if nccl_comm.rank_id() == 0: - local_array[index] = value + send_buf = send_buf if isinstance(send_buf, cp.ndarray) else cp.asarray(send_buf) nccl_comm.bcast( - local_array[index].data.ptr, - _nccl_buf_size(local_array[index]), - cupy_to_nccl_dtype[str(local_array[index].dtype)], - 0, + send_buf.data.ptr, + _nccl_buf_size(send_buf), + cupy_to_nccl_dtype[str(send_buf.dtype)], + root, cp.cuda.Stream.null.ptr, ) diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..6bd7bb3e --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,8 @@ +import sys +import pytest +from mpi4py import MPI + +def pytest_itemcollected(item): + """Append MPI rank to the test ID as it is collected.""" + rank = MPI.COMM_WORLD.Get_rank() + item._nodeid += f"[Rank {rank}]" diff --git a/tests/test_fredholm.py b/tests/test_fredholm.py index 883b2940..694812d5 100644 --- a/tests/test_fredholm.py +++ b/tests/test_fredholm.py @@ -15,6 +15,7 @@ backend = "numpy" import numpy as npp +import math from mpi4py import MPI import pytest @@ -24,6 +25,7 @@ from pylops_mpi import DistributedArray from pylops_mpi.DistributedArray import local_split, Partition from pylops_mpi.signalprocessing import MPIFredholm1 +from pylops_mpi.signalprocessing.Fredholm1 import MPIFredholm1SUMMA from pylops_mpi.utils.dottest import dottest np.random.seed(42) @@ -94,6 +96,58 @@ "dtype": "float32", } # real, unsaved Gt, nz=1 +parsumma1 = { + "nsl": 3, + "ny": 5, + "nx": 4, + "nz": 3, + "saveGt": True, + "imag": 0, + "dtype": "float32", +} +parsumma2 = { + "nsl": 2, + "ny": 4, + "nx": 5, + "nz": 3, + "saveGt": False, + "imag": 1j, + "dtype": "complex64", +} + + +def _active_summa_comm(base_comm): + size = base_comm.Get_size() + p = math.isqrt(size) + active_size = p * p + if base_comm.Get_rank() >= active_size: + return None, False + if active_size == size: + return base_comm, True + active_ranks = list(range(active_size)) + group = base_comm.Get_group().Incl(active_ranks) + comm = base_comm.Create_group(group) + return comm, True + + +def _assemble_summa_tiles(op, local_tiles, nsl, nrows, ncols, + row_block, col_block, engine): + comm = op.base_comm + comm_nccl = op.base_comm_nccl if engine == "cupy" else None + tiles = op._allgather(comm, comm_nccl, local_tiles, engine=engine) + out = np.zeros((nsl, nrows, ncols), dtype=local_tiles.dtype) + for rank_in_group in range(comm.Get_size()): + row_id, col_id = divmod(rank_in_group, op.p) + tile = tiles[rank_in_group] + if tile.size == 0: + continue + rs = row_id * row_block + cs = col_id * col_block + rn = tile.shape[1] + cn = tile.shape[2] + out[:, rs:rs + rn, cs:cs + cn] = tile + return out + """Seems to stop next tests from running @pytest.mark.mpi(min_size=2) @@ -165,3 +219,114 @@ def test_Fredholm1(par): y_adj_np = Fop.H @ y_np assert_allclose(y, y_np, rtol=1e-14) assert_allclose(y_adj, y_adj_np, rtol=1e-14) + + +@pytest.mark.mpi(min_size=1) +@pytest.mark.parametrize("par", [(parsumma1), (parsumma2)]) +def test_Fredholm1SUMMA(par): + """MPIFredholm1SUMMA operator""" + np.random.seed(42) + + comm, is_active = _active_summa_comm(MPI.COMM_WORLD) + if not is_active: + return + + comm_rank = comm.Get_rank() + p = math.isqrt(comm.Get_size()) + + dtype = np.dtype(par["dtype"]) + if dtype == np.complex64 or dtype == np.float32: + base_float_dtype = np.float32 + else: + base_float_dtype = np.float64 + + nsl = par["nsl"] + nx = par["nx"] + ny = par["ny"] + nz = par["nz"] + + _G = np.arange(nsl * nx * ny, + dtype=base_float_dtype).reshape(nsl, nx, ny) + G = (_G - par["imag"] * _G).astype(dtype) + + _M = np.arange(nsl * ny * nz, + dtype=base_float_dtype).reshape(nsl, ny, nz) + M = (_M + par["imag"] * _M).astype(dtype) + + bn = (nx + p - 1) // p + bk = (ny + p - 1) // p + bm = (nz + p - 1) // p + + row_id, col_id = divmod(comm_rank, p) + rs = row_id * bn + re = min(rs + bn, nx) + cs = col_id * bk + ce = min(cs + bk, ny) + ms = row_id * bk + me = min(ms + bk, ny) + zs = col_id * bm + ze = min(zs + bm, nz) + + G_local = G[:, rs:re, cs:ce] + M_local = M[:, ms:me, zs:ze] + + Fop_MPI = MPIFredholm1SUMMA( + G_local, + nz=nz, + nsl_global=nsl, + saveGt=par["saveGt"], + pb=1, + base_comm=comm, + dtype=par["dtype"], + ) + + local_k = max(0, me - ms) + local_n = max(0, re - rs) + local_m = max(0, ze - zs) + local_x_size = nsl * local_k * local_m + local_shapes = comm.allgather(local_x_size) + + x_dist = DistributedArray( + global_shape=nsl * ny * nz, + local_shapes=local_shapes, + partition=Partition.SCATTER, + base_comm=comm, + dtype=par["dtype"], + engine=backend, + ) + x_dist.local_array[:] = M_local.ravel() + + # Forward and adjoint + y_dist = Fop_MPI @ x_dist + xadj_dist = Fop_MPI.H @ y_dist + + # Dot test + dottest(Fop_MPI, x_dist, y_dist, + nsl * nx * nz, nsl * ny * nz) + + y_local = y_dist.local_array.reshape(nsl, local_n, local_m) + y = _assemble_summa_tiles( + Fop_MPI, y_local, nsl, nx, nz, bn, bm, backend + ) + + xadj_local = xadj_dist.local_array.reshape(nsl, local_k, local_m) + xadj = _assemble_summa_tiles( + Fop_MPI, xadj_local, nsl, ny, nz, bk, bm, backend + ) + + if comm_rank == 0: + y_np = np.matmul(G, M) + xadj_np = np.matmul(G.conj().transpose(0, 2, 1), y_np) + rtol = np.finfo(base_float_dtype).resolution + assert_allclose( + y.squeeze(), + y_np.squeeze(), + rtol=rtol, + err_msg=f"Rank {comm_rank}: Forward verification failed." + ) + assert_allclose( + xadj.squeeze(), + xadj_np.squeeze(), + rtol=rtol, + err_msg=f"Rank {comm_rank}: Adjoint verification failed." + ) diff --git a/tests/test_matrixmult.py b/tests/test_matrixmult.py index e41618af..38bf6e9b 100644 --- a/tests/test_matrixmult.py +++ b/tests/test_matrixmult.py @@ -37,6 +37,7 @@ # M, K, N are matrix dimensions A(N,K), B(K,M) # P_prime will be ceil(sqrt(size)). test_params = [ + pytest.param(64, 64, 64, "float64", id="f32_64_64_64"), pytest.param(37, 37, 37, "float64", id="f32_37_37_37"), pytest.param(50, 30, 40, "float64", id="f64_50_30_40"), # temporarely removed as sometimes crashed CI... to be investigated diff --git a/tests_nccl/test_matrixmult_nccl.py b/tests_nccl/test_matrixmult_nccl.py new file mode 100644 index 00000000..2757ee96 --- /dev/null +++ b/tests_nccl/test_matrixmult_nccl.py @@ -0,0 +1,286 @@ +"""Test the MPIMatrixMult class with NCCL + Designed to run with n GPUs (with 1 MPI process per GPU) + $ mpiexec -n 10 pytest test_matrixmult_nccl.py --with-mpi +""" +import math + +import numpy as np +import cupy as cp +from numpy.testing import assert_allclose +from mpi4py import MPI +import pytest + +from pylops.basicoperators import Conj, FirstDerivative +from pylops_mpi import DistributedArray, Partition +from pylops_mpi.basicoperators import MPIBlockDiag, MPIMatrixMult, \ + local_block_split, block_gather +from pylops_mpi.utils._nccl import initialize_nccl_comm + +np.random.seed(42) + +nccl_comm = initialize_nccl_comm() +base_comm = MPI.COMM_WORLD +size = base_comm.Get_size() +rank = base_comm.Get_rank() + +# Define test cases: (N, K, M, dtype_str) +# M, K, N are matrix dimensions A(N,K), B(K,M) +# P_prime will be ceil(sqrt(size)). +test_params = [ + pytest.param(64, 64, 64, "float64", id="f32_64_64_64"), + pytest.param(37, 37, 37, "float64", id="f32_37_37_37"), + pytest.param(50, 30, 40, "float64", id="f64_50_30_40"), + # temporarely removed as sometimes crashed CI... to be investigated + # pytest.param(22, 20, 16, "complex64", id="c64_22_20_16"), + pytest.param(3, 4, 5, "float32", id="f32_3_4_5"), + pytest.param(1, 2, 1, "float64", id="f64_1_2_1",), + pytest.param(2, 1, 3, "float32", id="f32_2_1_3",), +] + + +def _ensure_square_grid(): + p_prime = math.isqrt(size) + if p_prime * p_prime != size: + pytest.skip("MPIMatrixMult NCCL tests require a square number of ranks") + return p_prime + + +def _reorganize_local_matrix(x_dist, nrows, ncols, blk_cols, p_prime): + """Re-organize distributed array in local matrix""" + x = x_dist.asarray(masked=True) + col_counts = [min(blk_cols, ncols - j * blk_cols) for j in range(p_prime)] + x_blocks = [] + offset = 0 + for cnt in col_counts: + block_size = nrows * cnt + x_block = x[offset: offset + block_size] + if len(x_block) != 0: + x_blocks.append(x_block.reshape(nrows, cnt)) + offset += block_size + return cp.hstack(x_blocks) + + +@pytest.mark.mpi(min_size=2) +@pytest.mark.parametrize("N, K, M, dtype_str", test_params) +def test_MPIMatrixMult_block_nccl(N, K, M, dtype_str): + """MPIMatrixMult operator with kind=`block` and NCCL""" + p_prime = _ensure_square_grid() + if min(N, M) < p_prime: + pytest.skip("MPIMatrixMult block test requires N and M >= sqrt(size)") + + dtype = np.dtype(dtype_str) + cmplx = 1j if np.issubdtype(dtype, np.complexfloating) else 0 + base_float_dtype = np.float32 if dtype == np.complex64 else np.float64 + + row_id, col_id = divmod(rank, p_prime) + cols_id = base_comm.allgather(col_id) + + # Calculate local matrix dimensions + blk_rows_A = int(math.ceil(N / p_prime)) + row_start_A = col_id * blk_rows_A + row_end_A = min(N, row_start_A + blk_rows_A) + + blk_cols_X = int(math.ceil(M / p_prime)) + col_start_X = row_id * blk_cols_X + col_end_X = min(M, col_start_X + blk_cols_X) + local_col_X_len = max(0, col_end_X - col_start_X) + + # Fill local matrices + A_glob_real = cp.arange(N * K, dtype=base_float_dtype).reshape(N, K) + A_glob_imag = cp.arange(N * K, dtype=base_float_dtype).reshape(N, K) * 0.5 + A_glob = (A_glob_real + cmplx * A_glob_imag).astype(dtype) + + X_glob_real = cp.arange(K * M, dtype=base_float_dtype).reshape(K, M) + X_glob_imag = cp.arange(K * M, dtype=base_float_dtype).reshape(K, M) * 0.7 + X_glob = (X_glob_real + cmplx * X_glob_imag).astype(dtype) + + A_p = A_glob[row_start_A:row_end_A, :] + X_p = X_glob[:, col_start_X:col_end_X] + + # Create MPIMatrixMult operator + Aop = MPIMatrixMult(A_p, M, base_comm=base_comm, + dtype=dtype_str, kind="block", + base_comm_nccl=nccl_comm) + + # Create DistributedArray for input x (representing B flattened) + all_local_col_len = base_comm.allgather(local_col_X_len) + total_cols = np.sum(all_local_col_len) + + x_dist = DistributedArray( + global_shape=(K * total_cols), + local_shapes=[(K * cl_b) for cl_b in all_local_col_len], + partition=Partition.SCATTER, + base_comm_nccl=nccl_comm, + mask=[i % p_prime for i in range(size)], + dtype=dtype, + engine="cupy" + ) + + x_dist.local_array[:] = X_p.ravel() + + # Forward operation: y = A @ x (distributed) + y_dist = Aop @ x_dist + + # Adjoint operation: xadj = A.H @ y (distributed) + xadj_dist = Aop.H @ y_dist + + # Re-organize in local matrix + y = _reorganize_local_matrix(y_dist, N, M, blk_cols_X, p_prime) + xadj = _reorganize_local_matrix(xadj_dist, K, M, blk_cols_X, p_prime) + + if rank == 0: + A_glob_np = A_glob.get() + X_glob_np = X_glob.get() + y_loc = A_glob_np @ X_glob_np + assert_allclose( + y.get().squeeze(), + y_loc.squeeze(), + rtol=np.finfo(np.dtype(dtype)).resolution, + err_msg=f"Rank {rank}: Forward verification failed." + ) + + xadj_loc = A_glob_np.conj().T @ y_loc + assert_allclose( + xadj.get().squeeze(), + xadj_loc.squeeze(), + rtol=np.finfo(np.dtype(dtype)).resolution, + err_msg=f"Rank {rank}: Adjoint verification failed." + ) + + # Chain with another operator + Dop = FirstDerivative(dims=(N, col_end_X - col_start_X), + axis=0, dtype=dtype) + DBop = MPIBlockDiag(ops=[Dop, ], base_comm=base_comm, mask=cols_id) + Op = DBop @ Aop + + y1_dist = Op @ x_dist + xadj1_dist = Op.H @ y1_dist + + # Re-organize in local matrix + y1 = _reorganize_local_matrix(y1_dist, N, M, blk_cols_X, p_prime) + xadj1 = _reorganize_local_matrix(xadj1_dist, K, M, blk_cols_X, p_prime) + + if rank == 0: + A_glob_np = A_glob.get() + X_glob_np = X_glob.get() + Dop_glob = FirstDerivative(dims=(N, M), axis=0, dtype=dtype) + y1_loc = (Dop_glob @ (A_glob_np @ X_glob_np).ravel()).reshape(N, M) + assert_allclose( + y1.get().squeeze(), + y1_loc.squeeze(), + rtol=np.finfo(np.dtype(dtype)).resolution, + err_msg=f"Rank {rank}: Forward verification failed." + ) + + xadj1_loc = A_glob_np.conj().T @ (Dop_glob.H @ y1_loc.ravel()).reshape(N, M) + assert_allclose( + xadj1.get().squeeze(), + xadj1_loc.squeeze(), + rtol=np.finfo(np.dtype(dtype)).resolution, + err_msg=f"Rank {rank}: Adjoint verification failed." + ) + + +@pytest.mark.mpi(min_size=2) +@pytest.mark.parametrize("N, K, M, dtype_str", test_params) +def test_MPIMatrixMult_summa_nccl(N, K, M, dtype_str): + """MPIMatrixMult operator with kind=`summa` and NCCL""" + p_prime = _ensure_square_grid() + if min(N, K, M) < p_prime: + pytest.skip("MPIMatrixMult summa test requires N, K, M >= sqrt(size)") + + dtype = np.dtype(dtype_str) + cmplx = 1j if np.issubdtype(dtype, np.complexfloating) else 0 + base_float_dtype = np.float32 if dtype == np.complex64 else np.float64 + + # Fill local matrices + A_glob_real = cp.arange(N * K, dtype=base_float_dtype).reshape(N, K) + A_glob_imag = cp.arange(N * K, dtype=base_float_dtype).reshape(N, K) * 0.5 + A_glob = (A_glob_real + cmplx * A_glob_imag).astype(dtype) + + X_glob_real = cp.arange(K * M, dtype=base_float_dtype).reshape(K, M) + X_glob_imag = cp.arange(K * M, dtype=base_float_dtype).reshape(K, M) * 0.7 + X_glob = (X_glob_real + cmplx * X_glob_imag).astype(dtype) + + A_slice = local_block_split((N, K), rank, base_comm) + X_slice = local_block_split((K, M), rank, base_comm) + + A_p = A_glob[A_slice] + X_p = X_glob[X_slice] + + # Create MPIMatrixMult operator + Aop = MPIMatrixMult(A_p, M, base_comm=base_comm, + dtype=dtype_str, kind="summa", + base_comm_nccl=nccl_comm) + + x_dist = DistributedArray( + global_shape=(K * M), + local_shapes=base_comm.allgather(X_p.shape[0] * X_p.shape[1]), + partition=Partition.SCATTER, + base_comm_nccl=nccl_comm, + dtype=dtype, + engine="cupy", + ) + + x_dist.local_array[:] = X_p.ravel() + + # Forward operation: y = A @ x (distributed) + y_dist = Aop @ x_dist + + # Adjoint operation: xadj = A.H @ y (distributed) + xadj_dist = Aop.H @ y_dist + + # Re-organize in local matrix + y = block_gather(y_dist, (N, M), base_comm) + xadj = block_gather(xadj_dist, (K, M), base_comm) + + if rank == 0: + A_glob_np = A_glob.get() + X_glob_np = X_glob.get() + y_loc = A_glob_np @ X_glob_np + assert_allclose( + y.get().squeeze(), + y_loc.squeeze(), + rtol=np.finfo(np.dtype(dtype)).resolution, + err_msg=f"Rank {rank}: Forward verification failed." + ) + + xadj_loc = A_glob_np.conj().T @ y_loc + assert_allclose( + xadj.get().squeeze(), + xadj_loc.squeeze(), + rtol=np.finfo(np.dtype(dtype)).resolution, + err_msg=f"Rank {rank}: Adjoint verification failed." + ) + + # Chain with another operator + Dop = Conj(dims=(A_p.shape[0], X_p.shape[1])) + DBop = MPIBlockDiag(ops=[Dop, ], base_comm=base_comm) + Op = DBop @ Aop + + y1_dist = Op @ x_dist + xadj1_dist = Op.H @ y1_dist + + # Re-organize in local matrix + y1 = block_gather(y1_dist, (N, M), base_comm) + xadj1 = block_gather(xadj1_dist, (K, M), base_comm) + + if rank == 0: + A_glob_np = A_glob.get() + X_glob_np = X_glob.get() + y1_loc = ((A_glob_np @ X_glob_np).conj().ravel()).reshape(N, M) + + assert_allclose( + y1.get().squeeze(), + y1_loc.squeeze(), + rtol=np.finfo(y1_loc.dtype).resolution, + err_msg=f"Rank {rank}: Forward verification failed." + ) + + xadj1_loc = ((A_glob_np.conj().T @ y1_loc.conj()).ravel()).reshape(K, M) + assert_allclose( + xadj1.get().squeeze().ravel(), + xadj1_loc.squeeze().ravel(), + rtol=np.finfo(xadj1_loc.dtype).resolution, + err_msg=f"Rank {rank}: Adjoint verification failed." + )