Skip to content

Commit 9de8eb7

Browse files
committed
mpi: Polish _mpi_advanced_... utilities
1 parent 0b87d5d commit 9de8eb7

3 files changed

Lines changed: 19 additions & 11 deletions

File tree

devito/data/data.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -261,13 +261,16 @@ def T(self):
261261

262262
def _mpi_advanced_1d_target(self, glb_idx, axis):
263263
"""
264-
Return the raw local ndarray addressed by ``glb_idx`` without
265-
advanced indexing along ``axis``.
264+
Return a raw local view for MPI advanced-indexing communication.
266265
267-
The MPI advanced-indexing helper code in ``devito.data.utils`` owns
268-
the communication pattern; this hook is kept on ``Data`` because only
269-
the subclass can bypass its own ``__getitem__`` and obtain a plain
270-
ndarray view.
266+
The returned view is indexed by ``glb_idx`` on all dimensions except
267+
``axis``, which is replaced by ``slice(None)`` so the helper code in
268+
``devito.data.utils`` can pack or unpack the requested global integer
269+
entries itself. ``target_axis`` is the position of ``axis`` in the
270+
returned view after scalar-indexed dimensions have been dropped.
271+
272+
This hook is kept on ``Data`` because only the subclass can bypass its
273+
own ``__getitem__`` and obtain a plain ndarray view.
271274
"""
272275
target_idx = list(glb_idx)
273276
target_idx[axis] = slice(None)

devito/data/utils.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -218,10 +218,15 @@ def mpi_advanced_1d_set(data, glb_idx, val, axis, indices, decomposition,
218218

219219

220220
def _mpi_advanced_1d_error(data, error):
221-
"""Raise the first error reported by any rank, on every rank."""
221+
"""Raise rank-local advanced-indexing errors on every rank."""
222222
if data._distributor.nprocs > 1:
223223
errors = data._distributor.comm.allgather(error)
224-
error = next((i for i in errors if i is not None), None)
224+
errors = [
225+
f"rank {rank}: {msg}"
226+
for rank, msg in enumerate(errors)
227+
if msg is not None
228+
]
229+
error = "; ".join(errors) if errors else None
225230

226231
if error is not None:
227232
raise ValueError(error)

tests/test_data.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1499,16 +1499,16 @@ def test_advanced_indexing_errors(self, mode):
14991499
np.empty(0, dtype=np.int64)
15001500
duplicate_data = np.ones(duplicate_index.size, dtype=f.dtype)
15011501

1502-
with pytest.raises(ValueError, match="Duplicate global indices"):
1502+
with pytest.raises(ValueError, match="rank 0:.*Duplicate global indices"):
15031503
f.data[duplicate_index] = duplicate_data
15041504

15051505
oob_index = np.array([8]) if rank == 0 else np.empty(0, dtype=np.int64)
15061506
oob_data = np.ones(oob_index.size, dtype=f.dtype)
15071507

1508-
with pytest.raises(ValueError, match="out-of-bounds"):
1508+
with pytest.raises(ValueError, match="rank 0:.*out-of-bounds"):
15091509
f.data[oob_index] = oob_data
15101510

1511-
with pytest.raises(ValueError, match="out-of-bounds"):
1511+
with pytest.raises(ValueError, match="rank 0:.*out-of-bounds"):
15121512
f.data[oob_index]
15131513

15141514
@pytest.mark.parallel(mode=4)

0 commit comments

Comments
 (0)