1717import numpy as np
1818
1919from devito .mpi import MPI
20- from devito .tools import dtype_to_mpidtype
20+ from devito .tools import dtype_to_mpidtype , mpi4py_mapper
2121
2222__all__ = ['sparse_exchange' ]
2323
@@ -53,7 +53,13 @@ def sparse_exchange(comm, sendbufs, dtype, tag=0):
5353 """
5454 rank = comm .Get_rank ()
5555 nprocs = comm .Get_size ()
56- mpitype = dtype_to_mpidtype (dtype )
56+
57+ # Some MPI builds lack a native datatype for `dtype` (e.g. `float16`); send
58+ # over a same-size byte-equivalent wire type and view back on receipt, just
59+ # as the halo exchange does via `comm_dtype`. A no-op for mapped-to-self
60+ # types.
61+ wire = np .dtype (mpi4py_mapper .get (np .dtype (dtype ).type , dtype ))
62+ mpitype = dtype_to_mpidtype (wire )
5763
5864 recvd = {}
5965
@@ -79,7 +85,7 @@ def sparse_exchange(comm, sendbufs, dtype, tag=0):
7985 for peer , buf in sendbufs .items ():
8086 if peer == rank or buf .size == 0 :
8187 continue
82- buf = np .ascontiguousarray (buf )
88+ buf = np .ascontiguousarray (buf ). view ( wire )
8389 live_bufs .append (buf )
8490 sends .append (comm .Isend ([buf , mpitype ], dest = peer , tag = tag ))
8591
@@ -89,9 +95,9 @@ def sparse_exchange(comm, sendbufs, dtype, tag=0):
8995 comm .Probe (source = MPI .ANY_SOURCE , tag = tag , status = status )
9096 src = status .Get_source ()
9197 count = status .Get_count (mpitype )
92- buf = np .empty (count , dtype = dtype )
98+ buf = np .empty (count , dtype = wire )
9399 comm .Recv ([buf , mpitype ], source = src , tag = tag )
94- recvd [src ] = buf
100+ recvd [src ] = buf . view ( dtype )
95101
96102 MPI .Request .Waitall (sends )
97103 return recvd
0 commit comments