Skip to content

Commit 60d8c80

Browse files
committed
data: Replace NBX consensus with a portable sparse exchange
The point-to-point router used the NBX nonblocking-consensus algorithm (Issend + Iprobe(ANY_SOURCE) + Ibarrier). It works under IntelMPI and mpirun, but deadlocks under OpenMPI inside the ipyparallel example notebooks -- whereas ordinary collective/halo MPI (e.g. op.apply) runs fine there. So the NBX pattern itself is the fragile part. Replace it with a reduce-scatter-based sparse exchange: each rank learns how many peers will send to it via one Reduce_scatter_block over a length-nprocs 0/1 indicator, then Isend / Probe / Recv / Waitall. These are the same standard, portable calls that work elsewhere; payloads still move strictly point-to-point (no data all-to-all). Rename nbx_exchange -> sparse_exchange and nbx_push -> sparse_push accordingly, and revert the ineffective OpenMPI yield CI env var. Correctness verified under mpirun: TestDataDistributed, TestDataGather and the sparse advanced-indexing tests pass (modes 4 and 6).
1 parent dbf3379 commit 60d8c80

5 files changed

Lines changed: 57 additions & 59 deletions

File tree

.github/workflows/examples-mpi.yaml

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,11 +44,6 @@ jobs:
4444
DEVITO_ARCH: "gcc"
4545
CC: "gcc"
4646
CXX: "g++"
47-
# Make OpenMPI yield the CPU while waiting instead of busy-spinning. With
48-
# 4 ipyparallel engines contending for cores, the point-to-point routing
49-
# in the data notebooks otherwise livelocks under OpenMPI (IntelMPI yields
50-
# by default, so it is unaffected and ignores this OMPI_* setting).
51-
OMPI_MCA_mpi_yield_when_idle: "1"
5247

5348
steps:
5449
- name: Checkout devito

devito/data/distributed/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,4 +19,4 @@
1919
from devito.data.distributed.plan import ExchangePlan # noqa
2020
from devito.data.distributed.redistribution import redistribute_set # noqa
2121
from devito.data.distributed.selection import Selection # noqa
22-
from devito.data.distributed.transport import nbx_exchange # noqa
22+
from devito.data.distributed.transport import sparse_exchange # noqa

devito/data/distributed/plan.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,11 @@
2323
import numpy as np
2424

2525
from devito.data.distributed.selection import Affine, IndexScalar
26-
from devito.data.distributed.transport import nbx_exchange
26+
from devito.data.distributed.transport import sparse_exchange
2727
from devito.mpi import MPI
2828
from devito.tools import prod
2929

30-
__all__ = ['ExchangePlan', 'nbx_push']
30+
__all__ = ['ExchangePlan', 'sparse_push']
3131

3232

3333
class ExchangePlan:
@@ -206,7 +206,7 @@ def get(self, local):
206206
# Send each owner the offsets of the elements we want from it...
207207
headers = {r: _encode(ps, self._block_offsets, dist_lin)
208208
for r, (_, dist_lin) in self._peers.items()}
209-
requests = nbx_exchange(comm, headers, np.int64, tag=41)
209+
requests = sparse_exchange(comm, headers, np.int64, tag=41)
210210

211211
# ...and reply to whoever asked us with the requested values
212212
moved = self._moved(local)
@@ -215,7 +215,7 @@ def get(self, local):
215215
block_offsets, dist_lin = _decode(buf)
216216
midx = self._owner_apply(moved, dist_lin, block_offsets)
217217
replies[src] = np.ascontiguousarray(moved[midx]).reshape(-1)
218-
payloads = nbx_exchange(comm, replies, dtype, tag=42)
218+
payloads = sparse_exchange(comm, replies, dtype, tag=42)
219219

220220
# Scatter the received values back into result-row order
221221
rows_flat = np.zeros((self._nrows(), ps), dtype=dtype)
@@ -239,9 +239,9 @@ def put(self, local, value):
239239
"""
240240
self._raise_on_error(check_dup=True)
241241
rows_flat = self._value_to_rows(value, local.dtype)
242-
nbx_push(self.comm, self.layout.distributed_axes, self._repl_total,
243-
self._peers, self._block_offsets, self.payload_size, rows_flat,
244-
local)
242+
sparse_push(self.comm, self.layout.distributed_axes, self._repl_total,
243+
self._peers, self._block_offsets, self.payload_size,
244+
rows_flat, local)
245245

246246
# ------------------------------------------------------- result <-> rows
247247

@@ -422,8 +422,8 @@ def _group_peers(layout, owners, dist_local, sub, gcoords):
422422
return peers, oob_error, dup_error
423423

424424

425-
def nbx_push(comm, distributed_axes, repl_total, peers, block_offsets,
426-
payload_size, rows_flat, local):
425+
def sparse_push(comm, distributed_axes, repl_total, peers, block_offsets,
426+
payload_size, rows_flat, local):
427427
"""
428428
Route `rows_flat` to the owner ranks (NBX) and scatter each received
429429
payload into `local` at its owner-local position.
@@ -456,8 +456,8 @@ def nbx_push(comm, distributed_axes, repl_total, peers, block_offsets,
456456
for r, (_, dist_lin) in peers.items()}
457457
payloads = {r: rows_flat[rows].reshape(-1)
458458
for r, (rows, _) in peers.items() if rows.size}
459-
requests = nbx_exchange(comm, headers, np.int64, tag=43)
460-
values = nbx_exchange(comm, payloads, rows_flat.dtype, tag=44)
459+
requests = sparse_exchange(comm, headers, np.int64, tag=43)
460+
values = sparse_exchange(comm, payloads, rows_flat.dtype, tag=44)
461461

462462
# ...then scatter whatever we received into our own local array
463463
moved = np.moveaxis(local, distributed_axes, range(len(distributed_axes)))

devito/data/distributed/redistribution.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import numpy as np
2020

2121
from devito.data.distributed.layout import Layout
22-
from devito.data.distributed.plan import _group_peers, _resolve_owners, nbx_push
22+
from devito.data.distributed.plan import _group_peers, _resolve_owners, sparse_push
2323
from devito.data.distributed.selection import Affine, Selection
2424
from devito.mpi import MPI
2525

@@ -121,12 +121,12 @@ def _push(layout, gcoords, values, local):
121121
Push `values` (one per global coordinate in `gcoords`) to their owners.
122122
123123
A structured assignment has exactly one value per distributed point and no
124-
replicated payload, so it is `nbx_push` with `payload_size == 1`
124+
replicated payload, so it is `sparse_push` with `payload_size == 1`
125125
(`block_offsets == [0]`, `repl_total == 1`).
126126
"""
127127
owners, dist_local, sub = _resolve_owners(None, layout, gcoords)
128128
peers, _, _ = _group_peers(layout, owners, dist_local, sub, gcoords)
129129

130130
block_offsets = np.zeros(1, dtype=np.int64) # no replicated payload
131-
nbx_push(layout.distributor.comm, layout.distributed_axes, 1, peers,
132-
block_offsets, 1, values.reshape(-1, 1), local)
131+
sparse_push(layout.distributor.comm, layout.distributed_axes, 1, peers,
132+
block_offsets, 1, values.reshape(-1, 1), local)

devito/data/distributed/transport.py

Lines changed: 41 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -2,29 +2,35 @@
22
Transport layer for distributed data redistribution.
33
44
This module knows nothing about indexing or `Data`; it only moves contiguous
5-
buffers between MPI ranks. The single primitive, `nbx_exchange`, performs
6-
a sparse "all-to-some" exchange in which only the ranks that actually share data
7-
ever communicate. It can be swapped for neighbor collectives or a persistent
8-
graph communicator without affecting the layers above.
5+
buffers between MPI ranks. The single primitive, `sparse_exchange`, performs a
6+
sparse "all-to-some" exchange in which only the ranks that actually share data
7+
exchange payloads.
8+
9+
Each rank first learns *how many* peers will send to it via a single small
10+
`Reduce_scatter_block` over one integer per rank, then posts the point-to-point
11+
messages and receives exactly that many. This relies only on standard, widely
12+
portable MPI calls (no synchronous-send / nonblocking-barrier consensus), so it
13+
behaves uniformly across MPI implementations; the payloads themselves still move
14+
strictly point-to-point, so no data all-to-all takes place.
915
"""
1016

1117
import numpy as np
1218

1319
from devito.mpi import MPI
1420
from devito.tools import dtype_to_mpidtype
1521

16-
__all__ = ['nbx_exchange']
22+
__all__ = ['sparse_exchange']
1723

1824

19-
def nbx_exchange(comm, sendbufs, dtype, tag=0):
25+
def sparse_exchange(comm, sendbufs, dtype, tag=0):
2026
"""
21-
Sparse "all-to-some" exchange via nonblocking consensus (NBX).
27+
Sparse "all-to-some" exchange of contiguous buffers.
2228
23-
Each rank sends a buffer to each peer listed in `sendbufs` and receives
24-
from whichever ranks send to it, without any rank needing global knowledge
25-
of the communication pattern and without any dense collective. Only ranks
26-
that actually exchange data communicate; global termination is detected with
27-
a single nonblocking barrier (log-depth).
29+
Each rank sends a buffer to each peer listed in `sendbufs` and receives from
30+
whichever ranks send to it. The number of incoming messages is discovered
31+
with one `Reduce_scatter_block` over a length-`nprocs` 0/1 indicator (a few
32+
bytes per rank); only ranks that share data then exchange payloads, strictly
33+
point-to-point.
2834
2935
Parameters
3036
----------
@@ -44,16 +50,9 @@ def nbx_exchange(comm, sendbufs, dtype, tag=0):
4450
dict
4551
Maps each source rank to the 1D buffer received from it. The caller
4652
reshapes using its known payload shape.
47-
48-
Notes
49-
-----
50-
Implements the NBX algorithm (Hoefler et al., "Scalable Communication
51-
Protocols for Dynamic Sparse Data Exchange"). Synchronous sends (`Issend`)
52-
complete only once matched by a receive, so a rank can enter the nonblocking
53-
barrier only after every message it sent has been taken. Once all ranks are
54-
in the barrier no message is in flight, so probing can safely stop.
5553
"""
5654
rank = comm.Get_rank()
55+
nprocs = comm.Get_size()
5756
mpitype = dtype_to_mpidtype(dtype)
5857

5958
recvd = {}
@@ -63,32 +62,36 @@ def nbx_exchange(comm, sendbufs, dtype, tag=0):
6362
if local is not None and local.size:
6463
recvd[rank] = np.ravel(np.ascontiguousarray(local))
6564

66-
# Post synchronous sends to every other peer. The buffers must stay alive
67-
# until the matching requests complete, hence `live_bufs`.
65+
# Discover how many peers will send to this rank: the column sum of a 0/1
66+
# "rank r sends to rank c" matrix, scattered so each rank gets its own count.
67+
indicator = np.zeros(nprocs, dtype=np.int32)
68+
for peer, buf in sendbufs.items():
69+
if peer != rank and buf.size:
70+
indicator[peer] = 1
71+
incoming = np.zeros(1, dtype=np.int32)
72+
comm.Reduce_scatter_block([indicator, MPI.INT], [incoming, MPI.INT],
73+
op=MPI.SUM)
74+
75+
# Post the point-to-point sends. The buffers must stay alive until the
76+
# matching requests complete, hence `live_bufs`.
6877
sends = []
6978
live_bufs = []
7079
for peer, buf in sendbufs.items():
7180
if peer == rank or buf.size == 0:
7281
continue
7382
buf = np.ascontiguousarray(buf)
7483
live_bufs.append(buf)
75-
sends.append(comm.Issend([buf, mpitype], dest=peer, tag=tag))
84+
sends.append(comm.Isend([buf, mpitype], dest=peer, tag=tag))
7685

77-
barrier = None
86+
# Receive exactly the expected number of messages, sizing each from its probe
7887
status = MPI.Status()
79-
while True:
80-
if comm.Iprobe(source=MPI.ANY_SOURCE, tag=tag, status=status):
81-
src = status.Get_source()
82-
count = status.Get_count(mpitype)
83-
buf = np.empty(count, dtype=dtype)
84-
comm.Recv([buf, mpitype], source=src, tag=tag)
85-
recvd[src] = buf
86-
elif barrier is None:
87-
if MPI.Request.Testall(sends):
88-
# All my sends were matched -> announce I am done sending
89-
barrier = comm.Ibarrier()
90-
elif barrier.Test():
91-
# Everyone is done sending and nothing is in flight
92-
break
88+
for _ in range(int(incoming[0])):
89+
comm.Probe(source=MPI.ANY_SOURCE, tag=tag, status=status)
90+
src = status.Get_source()
91+
count = status.Get_count(mpitype)
92+
buf = np.empty(count, dtype=dtype)
93+
comm.Recv([buf, mpitype], source=src, tag=tag)
94+
recvd[src] = buf
9395

96+
MPI.Request.Waitall(sends)
9497
return recvd

0 commit comments

Comments
 (0)