
Commit 01fe71f

Added NCCL support
1 parent 7ed2933 commit 01fe71f

2 files changed: 295 additions & 4 deletions


pylops_mpi/basicoperators/MatrixMult.py

Lines changed: 9 additions & 4 deletions
@@ -161,7 +161,8 @@ def block_gather(x: DistributedArray, orig_shape: Tuple[int, int], comm: MPI.Comm
     if p_prime * p_prime != comm.Get_size():
         raise RuntimeError(f"Communicator size must be a perfect square, got {comm.Get_size()!r}")
 
-    all_blks = comm.allgather(x.local_array)
+    comm_nccl = x.base_comm_nccl if comm == x.base_comm else None
+    all_blks = x._allgather(comm, comm_nccl, x.local_array, engine=x.engine)
     nr, nc = orig_shape
     br, bc = math.ceil(nr / p_prime), math.ceil(nc / p_prime)
     C = ncp.zeros((nr, nc), dtype=all_blks[0].dtype)
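Note on this hunk: the ``comm == x.base_comm`` guard matters because the NCCL communicator stored on the array is bound to its base communicator, so a gather over any other (sub-)communicator must fall back to plain MPI. A minimal sketch of that dispatch, with ``nccl_allgather`` as a hypothetical stand-in for an NCCL wrapper (the actual routing lives inside ``x._allgather``):

    def gather_blocks(x, comm):
        # NCCL is only valid over the array's own base communicator
        comm_nccl = x.base_comm_nccl if comm == x.base_comm else None
        if comm_nccl is not None and x.engine == "cupy":
            return nccl_allgather(comm_nccl, x.local_array)  # hypothetical wrapper
        return comm.allgather(x.local_array)  # plain MPI fallback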
@@ -706,7 +707,7 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray:
 
         A_local = self.At if hasattr(self, "At") else self.A.T.conj()
         Y_local = ncp.zeros((self.A.shape[1], bm), dtype=output_dtype)
-
+        base_comm_nccl = self.base_comm_nccl if x.engine == "cupy" else None
         for k in range(self._P_prime):
             Xtemp = x_block.copy() if self._row_id == k else ncp.empty_like(x_block)
             col_comm_nccl = self._col_comm_nccl if x.engine == "cupy" else None
@@ -721,11 +722,13 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray:
             destA = self._col_id * self._P_prime + moving_col
             if destA != self.rank:
                 tagA = (100 + k) * 1000 + destA
-                self._send(self.base_comm, None, A_local, dest=destA, tag=tagA, engine=x.engine)
+                self._send(self.base_comm, base_comm_nccl, A_local,
+                           dest=destA, tag=tagA, engine=x.engine)
             if self._col_id == moving_col and ATtemp is None:
                 tagA = (100 + k) * 1000 + self.rank
                 recv_buf = ncp.empty_like(A_local)
-                ATtemp = self._recv(self.base_comm, None, recv_buf, source=srcA, tag=tagA, engine=x.engine)
+                ATtemp = self._recv(self.base_comm, base_comm_nccl, recv_buf,
+                                    source=srcA, tag=tagA, engine=x.engine)
             Y_local += ncp.dot(ATtemp, Xtemp)
 
         Y_local_unpadded = Y_local[:local_k, :local_m]
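The point-to-point changes follow the same rule: ``base_comm_nccl`` is forwarded only when the array engine is ``cupy``, and stays ``None`` otherwise so CPU runs keep using MPI. Since NCCL point-to-point has no notion of message tags, a plausible shape for the ``_send`` helper (an assumption for illustration, not the actual implementation) is:

    def _send(comm, comm_nccl, buf, dest, tag, engine):
        if comm_nccl is not None and engine == "cupy":
            # NCCL p2p carries no tags; ordering is guaranteed per peer pair
            nccl_send(comm_nccl, buf, dest)  # hypothetical NCCL wrapper
        else:
            comm.Send(buf, dest=dest, tag=tag)  # buffer-based mpi4py send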
@@ -761,6 +764,8 @@ def MPIMatrixMult(
         memory). Default is ``False``.
     base_comm : :obj:`mpi4py.MPI.Comm`, optional
         MPI communicator to use. Defaults to ``MPI.COMM_WORLD``.
+    base_comm_nccl : :obj:`cupy.cuda.nccl.NcclCommunicator`, optional
+        NCCL communicator to use when operating on ``cupy`` arrays.
     kind : :obj:`str`, optional
         Algorithm used to perform matrix multiplication: ``'block'`` for
         block-row-column decomposition, and ``'summa'`` for SUMMA algorithm, or
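With ``base_comm_nccl`` documented above, a GPU run can route the operator's communication through NCCL. A minimal usage sketch mirroring the test file below (assumes 4 MPI ranks forming a 2x2 grid with one GPU per rank; sizes are illustrative only):

    import cupy as cp
    from mpi4py import MPI
    from pylops_mpi import DistributedArray, Partition
    from pylops_mpi.basicoperators import MPIMatrixMult, local_block_split
    from pylops_mpi.utils._nccl import initialize_nccl_comm

    nccl_comm = initialize_nccl_comm()  # one GPU per MPI process
    base_comm = MPI.COMM_WORLD
    rank = base_comm.Get_rank()

    N = K = M = 8
    A_glob = cp.arange(N * K, dtype="float64").reshape(N, K)
    X_glob = cp.arange(K * M, dtype="float64").reshape(K, M)
    A_p = A_glob[local_block_split((N, K), rank, base_comm)]  # local block of A
    X_p = X_glob[local_block_split((K, M), rank, base_comm)]  # local block of X

    Aop = MPIMatrixMult(A_p, M, base_comm=base_comm, dtype="float64",
                        kind="summa", base_comm_nccl=nccl_comm)
    x = DistributedArray(global_shape=K * M,
                         local_shapes=base_comm.allgather(X_p.size),
                         partition=Partition.SCATTER,
                         base_comm_nccl=nccl_comm,
                         dtype="float64", engine="cupy")
    x.local_array[:] = X_p.ravel()
    y = Aop @ x       # forward product, communication over NCCL
    xadj = Aop.H @ y  # adjoint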

tests_nccl/test_matrixmult_nccl.py

Lines changed: 286 additions & 0 deletions
@@ -0,0 +1,286 @@
"""Test the MPIMatrixMult class with NCCL
Designed to run with n GPUs, where n is a perfect square (1 MPI process per GPU)
$ mpiexec -n 4 pytest test_matrixmult_nccl.py --with-mpi
"""
import math

import numpy as np
import cupy as cp
from numpy.testing import assert_allclose
from mpi4py import MPI
import pytest

from pylops.basicoperators import Conj, FirstDerivative
from pylops_mpi import DistributedArray, Partition
from pylops_mpi.basicoperators import MPIBlockDiag, MPIMatrixMult, \
    local_block_split, block_gather
from pylops_mpi.utils._nccl import initialize_nccl_comm

np.random.seed(42)

nccl_comm = initialize_nccl_comm()
base_comm = MPI.COMM_WORLD
size = base_comm.Get_size()
rank = base_comm.Get_rank()

# Define test cases: (N, K, M, dtype_str)
# N, K, M are the matrix dimensions of A(N, K) and X(K, M)
# P_prime is sqrt(size); the communicator must form a square grid.
test_params = [
    pytest.param(64, 64, 64, "float64", id="f64_64_64_64"),
    pytest.param(37, 37, 37, "float64", id="f64_37_37_37"),
    pytest.param(50, 30, 40, "float64", id="f64_50_30_40"),
    # temporarily removed as it sometimes crashed CI... to be investigated
    # pytest.param(22, 20, 16, "complex64", id="c64_22_20_16"),
    pytest.param(3, 4, 5, "float32", id="f32_3_4_5"),
    pytest.param(1, 2, 1, "float64", id="f64_1_2_1"),
    pytest.param(2, 1, 3, "float32", id="f32_2_1_3"),
]


def _ensure_square_grid():
    p_prime = math.isqrt(size)
    if p_prime * p_prime != size:
        pytest.skip("MPIMatrixMult NCCL tests require a square number of ranks")
    return p_prime


def _reorganize_local_matrix(x_dist, nrows, ncols, blk_cols, p_prime):
    """Re-organize distributed array into a local matrix"""
    x = x_dist.asarray(masked=True)
    col_counts = [min(blk_cols, ncols - j * blk_cols) for j in range(p_prime)]
    x_blocks = []
    offset = 0
    for cnt in col_counts:
        block_size = nrows * cnt
        x_block = x[offset: offset + block_size]
        if len(x_block) != 0:
            x_blocks.append(x_block.reshape(nrows, cnt))
        offset += block_size
    return cp.hstack(x_blocks)


@pytest.mark.mpi(min_size=2)
@pytest.mark.parametrize("N, K, M, dtype_str", test_params)
def test_MPIMatrixMult_block_nccl(N, K, M, dtype_str):
    """MPIMatrixMult operator with kind=`block` and NCCL"""
    p_prime = _ensure_square_grid()
    if min(N, M) < p_prime:
        pytest.skip("MPIMatrixMult block test requires N and M >= sqrt(size)")

    dtype = np.dtype(dtype_str)
    cmplx = 1j if np.issubdtype(dtype, np.complexfloating) else 0
    base_float_dtype = np.float32 if dtype == np.complex64 else np.float64

    row_id, col_id = divmod(rank, p_prime)
    cols_id = base_comm.allgather(col_id)

    # Calculate local matrix dimensions
    blk_rows_A = int(math.ceil(N / p_prime))
    row_start_A = col_id * blk_rows_A
    row_end_A = min(N, row_start_A + blk_rows_A)

    blk_cols_X = int(math.ceil(M / p_prime))
    col_start_X = row_id * blk_cols_X
    col_end_X = min(M, col_start_X + blk_cols_X)
    local_col_X_len = max(0, col_end_X - col_start_X)

    # Fill local matrices
    A_glob_real = cp.arange(N * K, dtype=base_float_dtype).reshape(N, K)
    A_glob_imag = cp.arange(N * K, dtype=base_float_dtype).reshape(N, K) * 0.5
    A_glob = (A_glob_real + cmplx * A_glob_imag).astype(dtype)

    X_glob_real = cp.arange(K * M, dtype=base_float_dtype).reshape(K, M)
    X_glob_imag = cp.arange(K * M, dtype=base_float_dtype).reshape(K, M) * 0.7
    X_glob = (X_glob_real + cmplx * X_glob_imag).astype(dtype)

    A_p = A_glob[row_start_A:row_end_A, :]
    X_p = X_glob[:, col_start_X:col_end_X]

    # Create MPIMatrixMult operator
    Aop = MPIMatrixMult(A_p, M, base_comm=base_comm,
                        dtype=dtype_str, kind="block",
                        base_comm_nccl=nccl_comm)

    # Create DistributedArray for input x (representing X flattened)
    all_local_col_len = base_comm.allgather(local_col_X_len)
    total_cols = np.sum(all_local_col_len)

    x_dist = DistributedArray(
        global_shape=(K * total_cols),
        local_shapes=[(K * cl_b) for cl_b in all_local_col_len],
        partition=Partition.SCATTER,
        base_comm_nccl=nccl_comm,
        mask=[i % p_prime for i in range(size)],
        dtype=dtype,
        engine="cupy"
    )

    x_dist.local_array[:] = X_p.ravel()

    # Forward operation: y = A @ x (distributed)
    y_dist = Aop @ x_dist

    # Adjoint operation: xadj = A.H @ y (distributed)
    xadj_dist = Aop.H @ y_dist

    # Re-organize in local matrix
    y = _reorganize_local_matrix(y_dist, N, M, blk_cols_X, p_prime)
    xadj = _reorganize_local_matrix(xadj_dist, K, M, blk_cols_X, p_prime)

    if rank == 0:
        A_glob_np = A_glob.get()
        X_glob_np = X_glob.get()
        y_loc = A_glob_np @ X_glob_np
        assert_allclose(
            y.get().squeeze(),
            y_loc.squeeze(),
            rtol=np.finfo(np.dtype(dtype)).resolution,
            err_msg=f"Rank {rank}: Forward verification failed."
        )

        xadj_loc = A_glob_np.conj().T @ y_loc
        assert_allclose(
            xadj.get().squeeze(),
            xadj_loc.squeeze(),
            rtol=np.finfo(np.dtype(dtype)).resolution,
            err_msg=f"Rank {rank}: Adjoint verification failed."
        )

    # Chain with another operator
    Dop = FirstDerivative(dims=(N, col_end_X - col_start_X),
                          axis=0, dtype=dtype)
    DBop = MPIBlockDiag(ops=[Dop, ], base_comm=base_comm, mask=cols_id)
    Op = DBop @ Aop

    y1_dist = Op @ x_dist
    xadj1_dist = Op.H @ y1_dist

    # Re-organize in local matrix
    y1 = _reorganize_local_matrix(y1_dist, N, M, blk_cols_X, p_prime)
    xadj1 = _reorganize_local_matrix(xadj1_dist, K, M, blk_cols_X, p_prime)

    if rank == 0:
        A_glob_np = A_glob.get()
        X_glob_np = X_glob.get()
        Dop_glob = FirstDerivative(dims=(N, M), axis=0, dtype=dtype)
        y1_loc = (Dop_glob @ (A_glob_np @ X_glob_np).ravel()).reshape(N, M)
        assert_allclose(
            y1.get().squeeze(),
            y1_loc.squeeze(),
            rtol=np.finfo(np.dtype(dtype)).resolution,
            err_msg=f"Rank {rank}: Forward verification failed."
        )

        xadj1_loc = A_glob_np.conj().T @ (Dop_glob.H @ y1_loc.ravel()).reshape(N, M)
        assert_allclose(
            xadj1.get().squeeze(),
            xadj1_loc.squeeze(),
            rtol=np.finfo(np.dtype(dtype)).resolution,
            err_msg=f"Rank {rank}: Adjoint verification failed."
        )


@pytest.mark.mpi(min_size=2)
@pytest.mark.parametrize("N, K, M, dtype_str", test_params)
def test_MPIMatrixMult_summa_nccl(N, K, M, dtype_str):
    """MPIMatrixMult operator with kind=`summa` and NCCL"""
    p_prime = _ensure_square_grid()
    if min(N, K, M) < p_prime:
        pytest.skip("MPIMatrixMult summa test requires N, K, M >= sqrt(size)")

    dtype = np.dtype(dtype_str)
    cmplx = 1j if np.issubdtype(dtype, np.complexfloating) else 0
    base_float_dtype = np.float32 if dtype == np.complex64 else np.float64

    # Fill local matrices
    A_glob_real = cp.arange(N * K, dtype=base_float_dtype).reshape(N, K)
    A_glob_imag = cp.arange(N * K, dtype=base_float_dtype).reshape(N, K) * 0.5
    A_glob = (A_glob_real + cmplx * A_glob_imag).astype(dtype)

    X_glob_real = cp.arange(K * M, dtype=base_float_dtype).reshape(K, M)
    X_glob_imag = cp.arange(K * M, dtype=base_float_dtype).reshape(K, M) * 0.7
    X_glob = (X_glob_real + cmplx * X_glob_imag).astype(dtype)

    A_slice = local_block_split((N, K), rank, base_comm)
    X_slice = local_block_split((K, M), rank, base_comm)

    A_p = A_glob[A_slice]
    X_p = X_glob[X_slice]

    # Create MPIMatrixMult operator
    Aop = MPIMatrixMult(A_p, M, base_comm=base_comm,
                        dtype=dtype_str, kind="summa",
                        base_comm_nccl=nccl_comm)

    x_dist = DistributedArray(
        global_shape=(K * M),
        local_shapes=base_comm.allgather(X_p.shape[0] * X_p.shape[1]),
        partition=Partition.SCATTER,
        base_comm_nccl=nccl_comm,
        dtype=dtype,
        engine="cupy",
    )

    x_dist.local_array[:] = X_p.ravel()

    # Forward operation: y = A @ x (distributed)
    y_dist = Aop @ x_dist

    # Adjoint operation: xadj = A.H @ y (distributed)
    xadj_dist = Aop.H @ y_dist

    # Re-organize in local matrix
    y = block_gather(y_dist, (N, M), base_comm)
    xadj = block_gather(xadj_dist, (K, M), base_comm)

    if rank == 0:
        A_glob_np = A_glob.get()
        X_glob_np = X_glob.get()
        y_loc = A_glob_np @ X_glob_np
        assert_allclose(
            y.get().squeeze(),
            y_loc.squeeze(),
            rtol=np.finfo(np.dtype(dtype)).resolution,
            err_msg=f"Rank {rank}: Forward verification failed."
        )

        xadj_loc = A_glob_np.conj().T @ y_loc
        assert_allclose(
            xadj.get().squeeze(),
            xadj_loc.squeeze(),
            rtol=np.finfo(np.dtype(dtype)).resolution,
            err_msg=f"Rank {rank}: Adjoint verification failed."
        )

    # Chain with another operator
    Dop = Conj(dims=(A_p.shape[0], X_p.shape[1]))
    DBop = MPIBlockDiag(ops=[Dop, ], base_comm=base_comm)
    Op = DBop @ Aop

    y1_dist = Op @ x_dist
    xadj1_dist = Op.H @ y1_dist

    # Re-organize in local matrix
    y1 = block_gather(y1_dist, (N, M), base_comm)
    xadj1 = block_gather(xadj1_dist, (K, M), base_comm)

    if rank == 0:
        A_glob_np = A_glob.get()
        X_glob_np = X_glob.get()
        y1_loc = ((A_glob_np @ X_glob_np).conj().ravel()).reshape(N, M)

        assert_allclose(
            y1.get().squeeze(),
            y1_loc.squeeze(),
            rtol=np.finfo(y1_loc.dtype).resolution,
            err_msg=f"Rank {rank}: Forward verification failed."
        )

        xadj1_loc = ((A_glob_np.conj().T @ y1_loc.conj()).ravel()).reshape(K, M)
        assert_allclose(
            xadj1.get().squeeze().ravel(),
            xadj1_loc.squeeze().ravel(),
            rtol=np.finfo(xadj1_loc.dtype).resolution,
            err_msg=f"Rank {rank}: Adjoint verification failed."
        )
