Skip to content

Commit e7890ce

Browse files
Merge pull request #2926 from devitocodes/resurrect-mpi-dual-2
mpi: Drop dual mode and tweak overlap2/full modes for GPU backends
2 parents ddb2459 + 971467b commit e7890ce

11 files changed

Lines changed: 283 additions & 146 deletions

File tree

devito/ir/iet/visitors.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1494,7 +1494,7 @@ def visit_ParallelTree(self, o):
14941494
def visit_HaloSpot(self, o):
14951495
hs = o.halo_scheme
14961496
fmapper = {self.mapper.get(k, k): v for k, v in hs.fmapper.items()}
1497-
halo_scheme = hs.build(fmapper, hs.honored)
1497+
halo_scheme = hs._rebuild(fmapper=fmapper)
14981498
body = self._visit(o.body)
14991499
return o._rebuild(halo_scheme=halo_scheme, body=body)
15001500

devito/mpi/distributed.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,7 @@ def nprocs_local(self):
261261

262262
@property
263263
def topology(self):
264-
return DimensionTuple(*self._topology, getters=self.dimensions)
264+
return self._topology
265265

266266
@property
267267
def topology_logical(self):
@@ -345,15 +345,21 @@ def __init__(self, shape, dimensions, input_comm=None, topology=None):
345345
# each other as possible, using an appropriate divisibility
346346
# algorithm. Thus, in 3D:
347347
# * topology[0] >= topology[1] >= topology[2]
348-
# * topology[0] * topology[1] * topology[2] == self._input_comm.size
348+
# * topology[0]*topology[1]*topology[2] == self._input_comm.size
349+
349350
# However, `MPI.Compute_dims` is distro-dependent, so we have
350351
# to enforce some properties through our own wrapper (e.g.,
351352
# OpenMPI v3 does not guarantee that 9 ranks are arranged into
352353
# a 3x3 grid when shape=(9, 9))
353-
self._topology = compute_dims(self._input_comm.size, len(shape))
354+
self._topology = DimensionTuple(
355+
*compute_dims(self._input_comm.size, len(shape)),
356+
getters=dimensions
357+
)
354358
else:
355359
# A custom topology may contain integers or the wildcard '*'
356-
self._topology = CustomTopology(topology, self._input_comm)
360+
self._topology = CustomTopology(
361+
topology, self._input_comm, getters=dimensions
362+
)
357363

358364
if self._input_comm is not input_comm:
359365
# By default, Devito arranges processes into a cartesian topology.
@@ -896,7 +902,7 @@ def _arg_values(self, *args, **kwargs):
896902
return self._arg_defaults()
897903

898904

899-
class CustomTopology(tuple):
905+
class CustomTopology(DimensionTuple):
900906

901907
"""
902908
The CustomTopology class provides a mechanism to describe parametric domain
@@ -954,7 +960,7 @@ class CustomTopology(tuple):
954960
'xy': ('*', '*', 1),
955961
}
956962

957-
def __new__(cls, items, input_comm):
963+
def __new__(cls, items, input_comm, **kwargs):
958964
# Keep track of nstars and already defined decompositions
959965
nstars = items.count('*')
960966

@@ -992,11 +998,15 @@ def __new__(cls, items, input_comm):
992998
# Final check that topology matches the communicator size
993999
assert np.prod(processed) == input_comm.size
9941000

995-
obj = super().__new__(cls, processed)
1001+
obj = super().__new__(cls, *processed, **kwargs)
9961002
obj.logical = items
9971003

9981004
return obj
9991005

1006+
def __repr__(self):
1007+
return (f"CustomTopology(logical={self.logical}, "
1008+
f"physical={super().__repr__()})")
1009+
10001010

10011011
def compute_dims(nprocs, ndim):
10021012
# We don't do anything clever here. In fact, we do something very basic --

devito/mpi/halo_scheme.py

Lines changed: 85 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@
77
import sympy
88
from sympy import Max, Min
99

10-
from devito import configuration
1110
from devito.data import CENTER, CORE, LEFT, OWNED, RIGHT
1211
from devito.ir.support import Forward, Scope
12+
from devito.symbolics import IntDiv
1313
from devito.symbolics.manipulation import _uxreplace_registry
1414
from devito.tools import (
1515
EnrichedTuple, Reconstructable, Tag, as_tuple, filter_ordered, filter_sorted, flatten,
@@ -137,11 +137,9 @@ def __init__(self, exprs, ispace):
137137
# Derive the halo exchanges
138138
self._mapper = frozendict(classify(exprs, ispace))
139139

140-
# Track the IterationSpace offsets induced by SubDomains/SubDimensions.
141-
# These should be honored in the derivation of the `omapper`
140+
# Track the IterationSpace offsets induced by SubDomains/SubDimensions,
141+
# which are honored in the derivation of the `omapper`
142142
self._honored = {}
143-
# SubDimensions are not necessarily included directly in
144-
# ispace.dimensions and hence we need to first utilize the `_defines` method
145143
dims = set().union(*[d._defines for d in ispace.dimensions
146144
if d._defines & self.dimensions])
147145
subdims = [d for d in dims if d.is_Sub and not d.local]
@@ -150,6 +148,12 @@ def __init__(self, exprs, ispace):
150148
self._honored[i.root] = frozenset([(ltk, rtk)])
151149
self._honored = frozendict(self._honored)
152150

151+
# Further constraints on the `omapper` derivation. At construction time
152+
# there's none, but lowering passes may change this
153+
# * `_alignment` may be a positive integer representing the alignment
154+
# requirement, in number of *elements*, of the underlying expressions
155+
self._alignment = None
156+
153157
def __repr__(self):
154158
fnames = ",".join(i.name for i in set(self._mapper))
155159
return f"HaloScheme<{fnames}>"
@@ -165,11 +169,22 @@ def __len__(self):
165169
def __hash__(self):
166170
return hash((self._mapper.__hash__(), self.honored.__hash__()))
167171

168-
@classmethod
169-
def build(cls, fmapper, honored):
172+
def _rebuild(self, fmapper=None, honored=None, alignment=None):
173+
"""
174+
Rebuild a HaloScheme from the provided `fmapper` and `honored`. Reuse
175+
`self`'s values for the missing arguments.
176+
"""
170177
obj = object.__new__(HaloScheme)
178+
179+
if fmapper is None:
180+
fmapper = self._mapper
181+
if honored is None:
182+
honored = self._honored
183+
171184
obj._mapper = frozendict(fmapper)
172185
obj._honored = frozendict(honored)
186+
obj._alignment = alignment or self._alignment
187+
173188
return obj
174189

175190
@classmethod
@@ -223,7 +238,7 @@ def union(self, halo_schemes):
223238
for d, v in i.honored.items():
224239
honored[d] = honored.get(d, frozenset()) | v
225240

226-
return HaloScheme.build(fmapper, honored)
241+
return i._rebuild(fmapper=fmapper, honored=honored)
227242

228243
@property
229244
def honored(self):
@@ -241,10 +256,14 @@ def is_void(self):
241256
@cached_property
242257
def omapper(self):
243258
"""
244-
Logical decomposition of the DOMAIN region into OWNED and CORE sub-regions.
259+
Logical decomposition of the DOMAIN region into OWNED and CORE sub-regions,
260+
"cumulative" over all DiscreteFunctions in the HaloScheme.
245261
246-
This is "cumulative" over all DiscreteFunctions in the HaloScheme; it also
247-
takes into account IterationSpace offsets induced by SubDomains/SubDimensions.
262+
The computed OMapper takes into account:
263+
264+
* The offsets induced by SubDomains/SubDimensions ("thickness");
265+
* Any data alignment requirement of the underlying expressions
266+
(`_alignment` attribute).
248267
249268
Examples
250269
--------
@@ -366,28 +385,62 @@ def omapper(self):
366385

367386
if s is CENTER:
368387
where.append((d, CORE, s))
369-
mapper[d] = (d.symbolic_min + osl,
370-
d.symbolic_max - osr)
388+
389+
mapper[d] = (
390+
d.symbolic_min + osl,
391+
d.symbolic_max - osr
392+
)
393+
371394
if nl != 0:
372395
mapper[nl] = (Max(nl - osl, 0),)
373396
if nr != 0:
374397
mapper[nr] = (Max(nr - osr, 0),)
375398
else:
376399
where.append((d, OWNED, s))
400+
377401
if s is LEFT:
378-
mapper[d] = (d.symbolic_min,
379-
Min(d.symbolic_min + osl - 1, d.symbolic_max - nr))
402+
mapper[d] = (
403+
d.symbolic_min,
404+
Min(d.symbolic_min + osl - 1, d.symbolic_max - nr)
405+
)
406+
380407
if nl != 0:
381408
mapper[nl] = (nl,)
382409
mapper[nr] = (0,)
383410
else:
384-
mapper[d] = (Max(d.symbolic_max - osr + 1, d.symbolic_min + nl),
385-
d.symbolic_max)
411+
mapper[d] = (
412+
Max(d.symbolic_max - osr + 1, d.symbolic_min + nl),
413+
d.symbolic_max
414+
)
415+
386416
if nr != 0:
387417
mapper[nl] = (0,)
388418
mapper[nr] = (nr,)
419+
389420
processed.append((tuple(where), frozendict(mapper)))
390421

422+
# Apply the alignment constraints, if any
423+
# First, get the fastest varying (contiguous) Dimension, which is the
424+
# one that matters for alignment
425+
if self._alignment:
426+
fvds = {f.dimensions[-1] for f in self.fmapper}
427+
if len(fvds) != 1:
428+
raise HaloSchemeException(
429+
"Unexpected contiguous Dimensions found while computing the "
430+
f"`omapper`: {fvds}"
431+
)
432+
fvd = fvds.pop()
433+
434+
for i, (where, mapper) in enumerate(list(processed)):
435+
try:
436+
m, M = mapper[fvd]
437+
except KeyError:
438+
continue
439+
440+
aligned_m = IntDiv(m, self._alignment) * self._alignment
441+
442+
processed[i] = (where, frozendict({**mapper, fvd: (aligned_m, M)}))
443+
391444
_, core = processed.pop(0)
392445
owned = processed
393446

@@ -483,15 +536,15 @@ def project(self, functions):
483536
to the provided `functions`.
484537
"""
485538
fmapper = {f: v for f, v in self.fmapper.items() if f in as_tuple(functions)}
486-
return HaloScheme.build(fmapper, self.honored)
539+
return self._rebuild(fmapper=fmapper)
487540

488541
def drop(self, functions):
489542
"""
490543
Create a new HaloScheme that contains all entries in `self` except those
491544
corresponding to the provided `functions`.
492545
"""
493546
fmapper = {f: v for f, v in self.fmapper.items() if f not in as_tuple(functions)}
494-
return HaloScheme.build(fmapper, self.honored)
547+
return self._rebuild(fmapper=fmapper)
495548

496549
def add(self, f, hse):
497550
"""
@@ -503,7 +556,7 @@ def add(self, f, hse):
503556
if f in fmapper:
504557
hse = fmapper[f].union(hse)
505558
fmapper[f] = hse
506-
return HaloScheme.build(fmapper, self.honored)
559+
return self._rebuild(fmapper=fmapper)
507560

508561
def merge(self, hs):
509562
"""
@@ -512,20 +565,14 @@ def merge(self, hs):
512565
fmapper = dict(self.fmapper)
513566
for f, hse in hs.fmapper.items():
514567
fmapper[f] = fmapper.get(f, hse).merge(hse)
515-
return HaloScheme.build(fmapper, self.honored)
568+
return self._rebuild(fmapper=fmapper)
516569

517570

518571
def classify(exprs, ispace):
519572
"""
520573
Produce the mapper `Function -> HaloSchemeEntry`, which describes the necessary
521574
halo exchanges in the given Scope.
522575
"""
523-
524-
# Some MPI modes require pulling the `loc_indices` from the reads, others
525-
# from the writes. It essentially depends on whether the halo exchange is
526-
# performed before (reads) or after (writes) the OWNED region is computed
527-
loc_indices_from_reads = configuration['mpi'] not in ('dual',)
528-
529576
scope = Scope(exprs)
530577

531578
mapper = {}
@@ -565,15 +612,17 @@ def classify(exprs, ispace):
565612
else:
566613
v[(d, LEFT)] = STENCIL
567614
v[(d, RIGHT)] = STENCIL
568-
elif loc_indices_from_reads:
615+
else:
569616
v[(d, i[d])] = NONE
570617

571618
# Does `i` actually require a halo exchange?
572619
if not any(hl is STENCIL for hl in v.values()):
573620
continue
574621

575622
# Derive diagonal halo exchanges from the previous analysis
576-
combs = list(product([LEFT, CENTER, RIGHT], repeat=len(f._dist_dimensions)))
623+
combs = list(
624+
product([LEFT, CENTER, RIGHT], repeat=len(f._dist_dimensions))
625+
)
577626
combs.remove((CENTER,)*len(f._dist_dimensions))
578627
for c in combs:
579628
key = (f._dist_dimensions, c)
@@ -598,13 +647,6 @@ def classify(exprs, ispace):
598647
if not halo_labels:
599648
continue
600649

601-
# Augment `halo_labels` with `loc_indices`-related information if necessary
602-
if not loc_indices_from_reads:
603-
for i in scope.writes.get(f, []):
604-
for d in i.findices:
605-
if not f.grid.is_distributed(d):
606-
halo_labels[(d, i[d])].add(NONE)
607-
608650
# Separate halo-exchange Dimensions from `loc_indices`
609651
raw_loc_indices, halos = defaultdict(list), []
610652
for (d, s), hl in halo_labels.items():
@@ -613,15 +655,18 @@ def classify(exprs, ispace):
613655
if not hl:
614656
continue
615657
elif len(hl) > 1:
616-
raise HaloSchemeException("Inconsistency found while building a halo "
617-
f"scheme for `{f}` along Dimension `{d}`")
658+
raise HaloSchemeException(
659+
"Inconsistency found while building a halo scheme for "
660+
f"`{f}` along Dimension `{d}`")
618661
elif hl.pop() is STENCIL:
619662
halos.append(Halo(d, s))
620663
elif d._defines & set(ispace.itdims):
621664
raw_loc_indices[d].append(s)
622665

623-
loc_indices, loc_dirs = process_loc_indices(raw_loc_indices,
624-
ispace.directions)
666+
loc_indices, loc_dirs = process_loc_indices(
667+
raw_loc_indices, ispace.directions
668+
)
669+
625670
mapper[f] = HaloSchemeEntry(loc_indices, loc_dirs, halos, dims)
626671

627672
return mapper

0 commit comments

Comments
 (0)