| 1 | +"""Test the MPIMatrixMult class with NCCL |
| 2 | + Designed to run with n GPUs (with 1 MPI process per GPU) |
| 3 | + $ mpiexec -n 10 pytest test_matrixmult_nccl.py --with-mpi |
| 4 | +""" |
import math

import numpy as np
import cupy as cp
from numpy.testing import assert_allclose
from mpi4py import MPI
import pytest

from pylops.basicoperators import Conj, FirstDerivative
from pylops_mpi import DistributedArray, Partition
from pylops_mpi.basicoperators import MPIBlockDiag, MPIMatrixMult, \
    local_block_split, block_gather
from pylops_mpi.utils._nccl import initialize_nccl_comm

np.random.seed(42)

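# NCCL communicator for GPU-to-GPU collectives (one GPU per MPI rank);
# the MPI world communicator is still used for rank bookkeeping and
# allgather of metadata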
nccl_comm = initialize_nccl_comm()
base_comm = MPI.COMM_WORLD
size = base_comm.Get_size()
rank = base_comm.Get_rank()

# Define test cases: (N, K, M, dtype_str)
# N, K, M are the matrix dimensions: A is (N, K), X is (K, M)
# p_prime is sqrt(size); tests are skipped unless size is a perfect square
test_params = [
    pytest.param(64, 64, 64, "float64", id="f64_64_64_64"),
    pytest.param(37, 37, 37, "float64", id="f64_37_37_37"),
    pytest.param(50, 30, 40, "float64", id="f64_50_30_40"),
    # temporarily removed as it sometimes crashed CI... to be investigated
    # pytest.param(22, 20, 16, "complex64", id="c64_22_20_16"),
    pytest.param(3, 4, 5, "float32", id="f32_3_4_5"),
    pytest.param(1, 2, 1, "float64", id="f64_1_2_1"),
    pytest.param(2, 1, 3, "float32", id="f32_2_1_3"),
]


def _ensure_square_grid():
    p_prime = math.isqrt(size)
    if p_prime * p_prime != size:
        pytest.skip("MPIMatrixMult NCCL tests require a square number of ranks")
    return p_prime


def _reorganize_local_matrix(x_dist, nrows, ncols, blk_cols, p_prime):
    """Reassemble a distributed array into a local matrix"""
    x = x_dist.asarray(masked=True)
    col_counts = [min(blk_cols, ncols - j * blk_cols) for j in range(p_prime)]
    x_blocks = []
    offset = 0
    for cnt in col_counts:
        block_size = nrows * cnt
        x_block = x[offset: offset + block_size]
        if len(x_block) != 0:
            x_blocks.append(x_block.reshape(nrows, cnt))
        offset += block_size
    return cp.hstack(x_blocks)


@pytest.mark.mpi(min_size=2)
@pytest.mark.parametrize("N, K, M, dtype_str", test_params)
def test_MPIMatrixMult_block_nccl(N, K, M, dtype_str):
    """MPIMatrixMult operator with kind=`block` and NCCL"""
    p_prime = _ensure_square_grid()
    if min(N, M) < p_prime:
        pytest.skip("MPIMatrixMult block test requires N and M >= sqrt(size)")

    dtype = np.dtype(dtype_str)
    cmplx = 1j if np.issubdtype(dtype, np.complexfloating) else 0
    base_float_dtype = np.float32 if dtype == np.complex64 else np.float64

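    # Map the flat rank onto the p_prime x p_prime process grid
    # (rank = row_id * p_prime + col_id)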
    row_id, col_id = divmod(rank, p_prime)
    cols_id = base_comm.allgather(col_id)

    # Calculate local matrix dimensions
    blk_rows_A = int(math.ceil(N / p_prime))
    row_start_A = col_id * blk_rows_A
    row_end_A = min(N, row_start_A + blk_rows_A)

    blk_cols_X = int(math.ceil(M / p_prime))
    col_start_X = row_id * blk_cols_X
    col_end_X = min(M, col_start_X + blk_cols_X)
    local_col_X_len = max(0, col_end_X - col_start_X)

    # Build the global test matrices and extract the local blocks
    A_glob_real = cp.arange(N * K, dtype=base_float_dtype).reshape(N, K)
    A_glob_imag = cp.arange(N * K, dtype=base_float_dtype).reshape(N, K) * 0.5
    A_glob = (A_glob_real + cmplx * A_glob_imag).astype(dtype)

    X_glob_real = cp.arange(K * M, dtype=base_float_dtype).reshape(K, M)
    X_glob_imag = cp.arange(K * M, dtype=base_float_dtype).reshape(K, M) * 0.7
    X_glob = (X_glob_real + cmplx * X_glob_imag).astype(dtype)

    A_p = A_glob[row_start_A:row_end_A, :]
    X_p = X_glob[:, col_start_X:col_end_X]

    # Create MPIMatrixMult operator
    Aop = MPIMatrixMult(A_p, M, base_comm=base_comm,
                        dtype=dtype_str, kind="block",
                        base_comm_nccl=nccl_comm)

    # Create DistributedArray for input x (representing X flattened)
    all_local_col_len = base_comm.allgather(local_col_X_len)
    total_cols = np.sum(all_local_col_len)

    x_dist = DistributedArray(
        global_shape=(K * total_cols),
        local_shapes=[(K * cl_b) for cl_b in all_local_col_len],
        partition=Partition.SCATTER,
        base_comm_nccl=nccl_comm,
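        # mask groups ranks by col_id (i % p_prime): ranks in a group hold
        # the same row block of A and together span all column blocks of X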
        mask=[i % p_prime for i in range(size)],
        dtype=dtype,
        engine="cupy"
    )

    x_dist.local_array[:] = X_p.ravel()

    # Forward operation: y = A @ x (distributed)
    y_dist = Aop @ x_dist

    # Adjoint operation: xadj = A.H @ y (distributed)
    xadj_dist = Aop.H @ y_dist

    # Reassemble the distributed results into local matrices
    y = _reorganize_local_matrix(y_dist, N, M, blk_cols_X, p_prime)
    xadj = _reorganize_local_matrix(xadj_dist, K, M, blk_cols_X, p_prime)

    if rank == 0:
        A_glob_np = A_glob.get()
        X_glob_np = X_glob.get()
        y_loc = A_glob_np @ X_glob_np
        assert_allclose(
            y.get().squeeze(),
            y_loc.squeeze(),
            rtol=np.finfo(np.dtype(dtype)).resolution,
            err_msg=f"Rank {rank}: Forward verification failed."
        )

        xadj_loc = A_glob_np.conj().T @ y_loc
        assert_allclose(
            xadj.get().squeeze(),
            xadj_loc.squeeze(),
            rtol=np.finfo(np.dtype(dtype)).resolution,
            err_msg=f"Rank {rank}: Adjoint verification failed."
        )

    # Chain with another operator
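    # (each rank differentiates along axis=0 of its local (N, local_cols)
    # block; mask=cols_id keeps the block-diagonal aligned with the
    # matmul's column groups)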
    Dop = FirstDerivative(dims=(N, col_end_X - col_start_X),
                          axis=0, dtype=dtype)
    DBop = MPIBlockDiag(ops=[Dop, ], base_comm=base_comm, mask=cols_id)
    Op = DBop @ Aop

    y1_dist = Op @ x_dist
    xadj1_dist = Op.H @ y1_dist

    # Reassemble the distributed results into local matrices
    y1 = _reorganize_local_matrix(y1_dist, N, M, blk_cols_X, p_prime)
    xadj1 = _reorganize_local_matrix(xadj1_dist, K, M, blk_cols_X, p_prime)

    if rank == 0:
        A_glob_np = A_glob.get()
        X_glob_np = X_glob.get()
        Dop_glob = FirstDerivative(dims=(N, M), axis=0, dtype=dtype)
        y1_loc = (Dop_glob @ (A_glob_np @ X_glob_np).ravel()).reshape(N, M)
        assert_allclose(
            y1.get().squeeze(),
            y1_loc.squeeze(),
            rtol=np.finfo(np.dtype(dtype)).resolution,
            err_msg=f"Rank {rank}: Forward verification failed."
        )

        xadj1_loc = A_glob_np.conj().T @ (Dop_glob.H @ y1_loc.ravel()).reshape(N, M)
        assert_allclose(
            xadj1.get().squeeze(),
            xadj1_loc.squeeze(),
            rtol=np.finfo(np.dtype(dtype)).resolution,
            err_msg=f"Rank {rank}: Adjoint verification failed."
        )


@pytest.mark.mpi(min_size=2)
@pytest.mark.parametrize("N, K, M, dtype_str", test_params)
def test_MPIMatrixMult_summa_nccl(N, K, M, dtype_str):
    """MPIMatrixMult operator with kind=`summa` and NCCL"""
    p_prime = _ensure_square_grid()
    if min(N, K, M) < p_prime:
        pytest.skip("MPIMatrixMult summa test requires N, K, M >= sqrt(size)")

    dtype = np.dtype(dtype_str)
    cmplx = 1j if np.issubdtype(dtype, np.complexfloating) else 0
    base_float_dtype = np.float32 if dtype == np.complex64 else np.float64

    # Build the global test matrices
    A_glob_real = cp.arange(N * K, dtype=base_float_dtype).reshape(N, K)
    A_glob_imag = cp.arange(N * K, dtype=base_float_dtype).reshape(N, K) * 0.5
    A_glob = (A_glob_real + cmplx * A_glob_imag).astype(dtype)

    X_glob_real = cp.arange(K * M, dtype=base_float_dtype).reshape(K, M)
    X_glob_imag = cp.arange(K * M, dtype=base_float_dtype).reshape(K, M) * 0.7
    X_glob = (X_glob_real + cmplx * X_glob_imag).astype(dtype)

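    # local_block_split returns the (row, col) slices of the global
    # matrix assigned to this rank on the p_prime x p_prime SUMMA grid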
    A_slice = local_block_split((N, K), rank, base_comm)
    X_slice = local_block_split((K, M), rank, base_comm)

    A_p = A_glob[A_slice]
    X_p = X_glob[X_slice]

    # Create MPIMatrixMult operator
    Aop = MPIMatrixMult(A_p, M, base_comm=base_comm,
                        dtype=dtype_str, kind="summa",
                        base_comm_nccl=nccl_comm)

    x_dist = DistributedArray(
        global_shape=(K * M),
        local_shapes=base_comm.allgather(X_p.shape[0] * X_p.shape[1]),
        partition=Partition.SCATTER,
        base_comm_nccl=nccl_comm,
        dtype=dtype,
        engine="cupy",
    )

    x_dist.local_array[:] = X_p.ravel()

    # Forward operation: y = A @ x (distributed)
    y_dist = Aop @ x_dist

    # Adjoint operation: xadj = A.H @ y (distributed)
    xadj_dist = Aop.H @ y_dist

    # Gather the distributed results into full matrices
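    # (block_gather is the counterpart of local_block_split, collecting
    # each rank's block back into the full matrix)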
    y = block_gather(y_dist, (N, M), base_comm)
    xadj = block_gather(xadj_dist, (K, M), base_comm)

    if rank == 0:
        A_glob_np = A_glob.get()
        X_glob_np = X_glob.get()
        y_loc = A_glob_np @ X_glob_np
        assert_allclose(
            y.get().squeeze(),
            y_loc.squeeze(),
            rtol=np.finfo(np.dtype(dtype)).resolution,
            err_msg=f"Rank {rank}: Forward verification failed."
        )

        xadj_loc = A_glob_np.conj().T @ y_loc
        assert_allclose(
            xadj.get().squeeze(),
            xadj_loc.squeeze(),
            rtol=np.finfo(np.dtype(dtype)).resolution,
            err_msg=f"Rank {rank}: Adjoint verification failed."
        )

    # Chain with another operator
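    # (a blockwise Conj equals a global elementwise conjugation, hence
    # the .conj() in the rank-0 reference below)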
    Dop = Conj(dims=(A_p.shape[0], X_p.shape[1]))
    DBop = MPIBlockDiag(ops=[Dop, ], base_comm=base_comm)
    Op = DBop @ Aop

    y1_dist = Op @ x_dist
    xadj1_dist = Op.H @ y1_dist

    # Gather the distributed results into full matrices
    y1 = block_gather(y1_dist, (N, M), base_comm)
    xadj1 = block_gather(xadj1_dist, (K, M), base_comm)

    if rank == 0:
        A_glob_np = A_glob.get()
        X_glob_np = X_glob.get()
        y1_loc = (A_glob_np @ X_glob_np).conj()

        assert_allclose(
            y1.get().squeeze(),
            y1_loc.squeeze(),
            rtol=np.finfo(y1_loc.dtype).resolution,
            err_msg=f"Rank {rank}: Forward verification failed."
        )

        xadj1_loc = A_glob_np.conj().T @ y1_loc.conj()
        assert_allclose(
            xadj1.get().squeeze().ravel(),
            xadj1_loc.squeeze().ravel(),
            rtol=np.finfo(xadj1_loc.dtype).resolution,
            err_msg=f"Rank {rank}: Adjoint verification failed."
        )