Skip to content

Commit 6d7df77

Browse files
committed
mpi: fix distribution dtype
1 parent f17919a commit 6d7df77

1 file changed

Lines changed: 11 additions & 5 deletions

File tree

devito/data/distributed/transport.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import numpy as np
1818

1919
from devito.mpi import MPI
20-
from devito.tools import dtype_to_mpidtype
20+
from devito.tools import dtype_to_mpidtype, mpi4py_mapper
2121

2222
__all__ = ['sparse_exchange']
2323

@@ -53,7 +53,13 @@ def sparse_exchange(comm, sendbufs, dtype, tag=0):
5353
"""
5454
rank = comm.Get_rank()
5555
nprocs = comm.Get_size()
56-
mpitype = dtype_to_mpidtype(dtype)
56+
57+
# Some MPI builds lack a native datatype for `dtype` (e.g. `float16`); send
58+
# over a same-size byte-equivalent wire type and view back on receipt, just
59+
# as the halo exchange does via `comm_dtype`. A no-op for mapped-to-self
60+
# types.
61+
wire = np.dtype(mpi4py_mapper.get(np.dtype(dtype).type, dtype))
62+
mpitype = dtype_to_mpidtype(wire)
5763

5864
recvd = {}
5965

@@ -79,7 +85,7 @@ def sparse_exchange(comm, sendbufs, dtype, tag=0):
7985
for peer, buf in sendbufs.items():
8086
if peer == rank or buf.size == 0:
8187
continue
82-
buf = np.ascontiguousarray(buf)
88+
buf = np.ascontiguousarray(buf).view(wire)
8389
live_bufs.append(buf)
8490
sends.append(comm.Isend([buf, mpitype], dest=peer, tag=tag))
8591

@@ -89,9 +95,9 @@ def sparse_exchange(comm, sendbufs, dtype, tag=0):
8995
comm.Probe(source=MPI.ANY_SOURCE, tag=tag, status=status)
9096
src = status.Get_source()
9197
count = status.Get_count(mpitype)
92-
buf = np.empty(count, dtype=dtype)
98+
buf = np.empty(count, dtype=wire)
9399
comm.Recv([buf, mpitype], source=src, tag=tag)
94-
recvd[src] = buf
100+
recvd[src] = buf.view(dtype)
95101

96102
MPI.Request.Waitall(sends)
97103
return recvd

0 commit comments

Comments
 (0)