Skip to content

Commit 86d0e63

Browse files
committed
compiler: Fix HaloScheme.omapper taking alignment into account
1 parent cdd3b2a commit 86d0e63

4 files changed

Lines changed: 93 additions & 16 deletions

File tree

devito/mpi/halo_scheme.py

Lines changed: 54 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from devito.data import CENTER, CORE, LEFT, OWNED, RIGHT
1111
from devito.ir.support import Forward, Scope
12+
from devito.symbolics import IntDiv
1213
from devito.symbolics.manipulation import _uxreplace_registry
1314
from devito.tools import (
1415
EnrichedTuple, Reconstructable, Tag, as_tuple, filter_ordered, filter_sorted, flatten,
@@ -147,6 +148,10 @@ def __init__(self, exprs, ispace):
147148
self._honored[i.root] = frozenset([(ltk, rtk)])
148149
self._honored = frozendict(self._honored)
149150

151+
# Further constraints on the `omapper` derivation. At construction time
152+
# there's none, but lowering passes may change this
153+
self._alignment = None
154+
150155
def __repr__(self):
151156
fnames = ",".join(i.name for i in set(self._mapper))
152157
return f"HaloScheme<{fnames}>"
@@ -162,7 +167,7 @@ def __len__(self):
162167
def __hash__(self):
163168
return hash((self._mapper.__hash__(), self.honored.__hash__()))
164169

165-
def _rebuild(self, fmapper=None, honored=None):
170+
def _rebuild(self, fmapper=None, honored=None, alignment=None):
166171
"""
167172
Rebuild a HaloScheme from the provided `fmapper`, `honored` and
168173
`alignment`. Reuse `self`'s values for the missing arguments.
@@ -176,6 +181,7 @@ def _rebuild(self, fmapper=None, honored=None):
176181

177182
obj._mapper = frozendict(fmapper)
178183
obj._honored = frozendict(honored)
184+
obj._alignment = alignment or self._alignment
179185

180186
return obj
181187

@@ -248,10 +254,14 @@ def is_void(self):
248254
@cached_property
249255
def omapper(self):
250256
"""
251-
Logical decomposition of the DOMAIN region into OWNED and CORE sub-regions.
257+
Logical decomposition of the DOMAIN region into OWNED and CORE sub-regions,
258+
"cumulative" over all DiscreteFunctions in the HaloScheme.
259+
260+
The computed OMapper takes into account:
252261
253-
This is "cumulative" over all DiscreteFunctions in the HaloScheme; it also
254-
takes into account IterationSpace offsets induced by SubDomains/SubDimensions.
262+
* The offsets induced by SubDomains/SubDimensions ("thickness");
263+
* Any data alignment requirement of the underlying expressions
264+
(`_alignment` attribute).
255265
256266
Examples
257267
--------
@@ -373,28 +383,62 @@ def omapper(self):
373383

374384
if s is CENTER:
375385
where.append((d, CORE, s))
376-
mapper[d] = (d.symbolic_min + osl,
377-
d.symbolic_max - osr)
386+
387+
mapper[d] = (
388+
d.symbolic_min + osl,
389+
d.symbolic_max - osr
390+
)
391+
378392
if nl != 0:
379393
mapper[nl] = (Max(nl - osl, 0),)
380394
if nr != 0:
381395
mapper[nr] = (Max(nr - osr, 0),)
382396
else:
383397
where.append((d, OWNED, s))
398+
384399
if s is LEFT:
385-
mapper[d] = (d.symbolic_min,
386-
Min(d.symbolic_min + osl - 1, d.symbolic_max - nr))
400+
mapper[d] = (
401+
d.symbolic_min,
402+
Min(d.symbolic_min + osl - 1, d.symbolic_max - nr)
403+
)
404+
387405
if nl != 0:
388406
mapper[nl] = (nl,)
389407
mapper[nr] = (0,)
390408
else:
391-
mapper[d] = (Max(d.symbolic_max - osr + 1, d.symbolic_min + nl),
392-
d.symbolic_max)
409+
mapper[d] = (
410+
Max(d.symbolic_max - osr + 1, d.symbolic_min + nl),
411+
d.symbolic_max
412+
)
413+
393414
if nr != 0:
394415
mapper[nl] = (0,)
395416
mapper[nr] = (nr,)
417+
396418
processed.append((tuple(where), frozendict(mapper)))
397419

420+
# Apply the alignment constraints, if any
421+
# First, get the fastest varying (contiguous) Dimension, which is the
422+
# one that matters for alignment
423+
if self._alignment:
424+
fvds = {f.dimensions[-1] for f in self.fmapper}
425+
if len(fvds) != 1:
426+
raise HaloSchemeException(
427+
"Unexpected contiguous Dimensions found while computing the "
428+
f"`omapper`: {fvds}"
429+
)
430+
fvd = fvds.pop()
431+
432+
for i, (where, mapper) in enumerate(list(processed)):
433+
try:
434+
m, M = mapper[fvd]
435+
except KeyError:
436+
continue
437+
438+
aligned_m = IntDiv(m, self._alignment) * self._alignment
439+
440+
processed[i] = (where, frozendict({**mapper, fvd: (aligned_m, M)}))
441+
398442
_, core = processed.pop(0)
399443
owned = processed
400444

devito/passes/iet/mpi.py

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from devito.mpi.reduction_scheme import DistReduce
1212
from devito.mpi.routines import HaloExchangeBuilder, ReductionBuilder
1313
from devito.passes.iet.engine import iet_pass
14-
from devito.symbolics import VectorAccess
14+
from devito.symbolics import VectorAccess, search
1515
from devito.tools import generator
1616
from devito.types import TensorMove
1717

@@ -287,11 +287,11 @@ def _mark_overlappable(iet):
287287
# Check legality. Comp/comm overlap is legal only if the OWNED regions
288288
# can grow arbitrarily, which means all of the dependencies must be
289289
# carried along a non-halo Dimension
290-
expressions = FindNodes(Expression).visit(hs)
291-
if not expressions:
290+
exprs = FindNodes(Expression).visit(hs)
291+
if not exprs:
292292
continue
293293

294-
scope = Scope(i.expr for i in expressions)
294+
scope = Scope([n.expr for n in exprs])
295295

296296
for dep in scope.d_all_gen():
297297
if dep.function in hs.functions:
@@ -305,11 +305,28 @@ def _mark_overlappable(iet):
305305
# f[x, y] = ...
306306
break
307307
else:
308-
# All good!
308+
# All good -- we can perform comp/comm overlap!
309309
found.append(hs)
310310

311+
# The underlying `exprs` might have data alignment constraints due to the
312+
# presence of objects such as VectorAccess or TensorMove, which expect the
313+
# starting address of the data to be aligned to a certain value. Comp/comm
314+
# overlap creates multiple iteration spaces (for the core and owned
315+
# regions), which might break the alignment contract if we don't play safe
316+
# -- imposing that these regions start at a carefully rounded-down point, at the
317+
# cost of potentially performing a bit of redundant compute
318+
mapper = {}
319+
for hs in found:
320+
exprs = [n.expr for n in FindNodes(Expression).visit(hs)]
321+
objs = search(exprs, (VectorAccess, TensorMove))
322+
alignment = max([i._expected_alignment for i in objs], default=None)
323+
324+
hsf = hs.halo_scheme._rebuild(alignment=alignment)
325+
hs1 = hs._rebuild(halo_scheme=hsf)
326+
327+
mapper[hs] = OverlappableHaloSpot(**hs1.args)
328+
311329
# Transform the IET replacing HaloSpots with OverlappableHaloSpots
312-
mapper = {hs: OverlappableHaloSpot(**hs.args) for hs in found}
313330
iet = Transformer(mapper, nested=True).visit(iet)
314331

315332
return iet

devito/symbolics/extended_sympy.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,10 @@ def __new__(cls, lhs, rhs, params=None):
119119
elif rhs == 1 or rhs is None:
120120
return lhs
121121

122+
if is_integer(lhs) and is_integer(rhs):
123+
# Both sides are plain integers -- perform the division right away
124+
return lhs // rhs
125+
122126
if not is_integer(rhs):
123127
# Perhaps it's a symbolic RHS -- but we wanna be sure it's of type int
124128
if not hasattr(rhs, 'dtype'):
@@ -890,6 +894,12 @@ class VectorAccess(Expr, Pickable, BasicWrapperMixin):
890894
Represent a vector access operation at high-level.
891895
"""
892896

897+
_expected_alignment = 16
898+
"""
899+
The expected alignment in bytes for the accessed vector. This must be
900+
honored by the compiler for correctness.
901+
"""
902+
893903
def __new__(cls, *args, **kwargs):
894904
return Expr.__new__(cls, *args)
895905

devito/types/parallel.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -421,6 +421,12 @@ class TensorMove(Expr, Reserved, Terminal):
421421

422422
__rargs__ = ('base', 'tid0', 'coords')
423423

424+
_expected_alignment = 16
425+
"""
426+
The expected alignment in bytes of the starting address of the data. This must be
427+
honored by the compiler for correctness.
428+
"""
429+
424430
def __new__(cls, base, tid0, coords, **kwargs):
425431
return super().__new__(cls, base, tid0, coords)
426432

0 commit comments

Comments
 (0)