 Transport layer for distributed data redistribution.

 This module knows nothing about indexing or `Data`; it only moves contiguous
-buffers between MPI ranks. The single primitive, `nbx_exchange`, performs
-a sparse "all-to-some" exchange in which only the ranks that actually share data
-ever communicate. It can be swapped for neighbor collectives or a persistent
-graph communicator without affecting the layers above.
+buffers between MPI ranks. The single primitive, `sparse_exchange`, performs a
+sparse "all-to-some" exchange in which only the ranks that actually share data
+exchange payloads.
+
+Each rank first learns *how many* peers will send to it via a single small
+`Reduce_scatter_block` over one integer per rank, then posts the point-to-point
+messages and receives exactly that many. This relies only on standard, widely
+portable MPI calls (no synchronous-send / nonblocking-barrier consensus), so it
+behaves uniformly across MPI implementations; the payloads themselves still move
+strictly point-to-point, so no data all-to-all takes place.
 """

 import numpy as np

 from devito.mpi import MPI
 from devito.tools import dtype_to_mpidtype

-__all__ = ['nbx_exchange']
+__all__ = ['sparse_exchange']


-def nbx_exchange(comm, sendbufs, dtype, tag=0):
+def sparse_exchange(comm, sendbufs, dtype, tag=0):
     """
-    Sparse "all-to-some" exchange via nonblocking consensus (NBX).
+    Sparse "all-to-some" exchange of contiguous buffers.

-    Each rank sends a buffer to each peer listed in `sendbufs` and receives
-    from whichever ranks send to it, without any rank needing global knowledge
-    of the communication pattern and without any dense collective. Only ranks
-    that actually exchange data communicate; global termination is detected with
-    a single nonblocking barrier (log-depth).
+    Each rank sends a buffer to each peer listed in `sendbufs` and receives from
+    whichever ranks send to it. The number of incoming messages is discovered
+    with one `Reduce_scatter_block` over a length-`nprocs` 0/1 indicator (a few
+    bytes per rank); only ranks that share data then exchange payloads, strictly
+    point-to-point.

     Parameters
     ----------
@@ -44,16 +50,9 @@ def nbx_exchange(comm, sendbufs, dtype, tag=0):
     dict
         Maps each source rank to the 1D buffer received from it. The caller
         reshapes using its known payload shape.
-
-    Notes
-    -----
-    Implements the NBX algorithm (Hoefler et al., "Scalable Communication
-    Protocols for Dynamic Sparse Data Exchange"). Synchronous sends (`Issend`)
-    complete only once matched by a receive, so a rank can enter the nonblocking
-    barrier only after every message it sent has been taken. Once all ranks are
-    in the barrier no message is in flight, so probing can safely stop.
     """
     rank = comm.Get_rank()
+    nprocs = comm.Get_size()
     mpitype = dtype_to_mpidtype(dtype)

     recvd = {}
@@ -63,32 +62,36 @@ def nbx_exchange(comm, sendbufs, dtype, tag=0):
     if local is not None and local.size:
         recvd[rank] = np.ravel(np.ascontiguousarray(local))

-    # Post synchronous sends to every other peer. The buffers must stay alive
-    # until the matching requests complete, hence `live_bufs`.
+    # Discover how many peers will send to this rank: the column sum of a 0/1
+    # "rank r sends to rank c" matrix, scattered so each rank gets its own count.
+    indicator = np.zeros(nprocs, dtype=np.int32)
+    for peer, buf in sendbufs.items():
+        if peer != rank and buf.size:
+            indicator[peer] = 1
+    incoming = np.zeros(1, dtype=np.int32)
+    comm.Reduce_scatter_block([indicator, MPI.INT], [incoming, MPI.INT],
+                              op=MPI.SUM)
+
+    # Post the point-to-point sends. The buffers must stay alive until the
+    # matching requests complete, hence `live_bufs`.
     sends = []
     live_bufs = []
     for peer, buf in sendbufs.items():
         if peer == rank or buf.size == 0:
             continue
         buf = np.ascontiguousarray(buf)
         live_bufs.append(buf)
-        sends.append(comm.Issend([buf, mpitype], dest=peer, tag=tag))
+        sends.append(comm.Isend([buf, mpitype], dest=peer, tag=tag))

-    barrier = None
+    # Receive exactly the expected number of messages, sizing each from its probe
     status = MPI.Status()
-    while True:
-        if comm.Iprobe(source=MPI.ANY_SOURCE, tag=tag, status=status):
-            src = status.Get_source()
-            count = status.Get_count(mpitype)
-            buf = np.empty(count, dtype=dtype)
-            comm.Recv([buf, mpitype], source=src, tag=tag)
-            recvd[src] = buf
-        elif barrier is None:
-            if MPI.Request.Testall(sends):
-                # All my sends were matched -> announce I am done sending
-                barrier = comm.Ibarrier()
-        elif barrier.Test():
-            # Everyone is done sending and nothing is in flight
-            break
+    for _ in range(int(incoming[0])):
+        comm.Probe(source=MPI.ANY_SOURCE, tag=tag, status=status)
+        src = status.Get_source()
+        count = status.Get_count(mpitype)
+        buf = np.empty(count, dtype=dtype)
+        comm.Recv([buf, mpitype], source=src, tag=tag)
+        recvd[src] = buf

+    MPI.Request.Waitall(sends)
     return recvd
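
The count-discovery step in this change can be checked without an MPI runtime. The sketch below is not part of the commit; it simulates, for all ranks at once, what the `Reduce_scatter_block` with `MPI.SUM` computes: each rank contributes one row of a 0/1 "rank r sends to rank c" matrix, and rank c ends up with the sum of column c. The helper name `incoming_counts` and the `send_targets` mapping are hypothetical, introduced only for illustration.

```python
import numpy as np


def incoming_counts(send_targets, nprocs):
    """Simulate the count-discovery step for every rank at once.

    `send_targets` (hypothetical) maps each rank to the peers it sends to.
    """
    # Row r is the 0/1 indicator vector that rank r would contribute
    indicator = np.zeros((nprocs, nprocs), dtype=np.int32)
    for rank, peers in send_targets.items():
        for peer in peers:
            if peer != rank:  # self-sends are handled locally, not counted
                indicator[rank, peer] = 1
    # Element-wise sum across ranks, with rank c keeping entry c:
    # equivalent to taking the column sums of the matrix
    return indicator.sum(axis=0)


targets = {0: [1, 2], 1: [2], 2: [], 3: [0, 3]}
counts = incoming_counts(targets, nprocs=4)
print(counts.tolist())  # [1, 1, 2, 0]
```

In the actual function each rank only ever holds its own row and receives its own single count, so the full matrix is never materialized on any rank.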