77import sympy
88from sympy import Max , Min
99
10- from devito import configuration
1110from devito .data import CENTER , CORE , LEFT , OWNED , RIGHT
1211from devito .ir .support import Forward , Scope
12+ from devito .symbolics import IntDiv
1313from devito .symbolics .manipulation import _uxreplace_registry
1414from devito .tools import (
1515 EnrichedTuple , Reconstructable , Tag , as_tuple , filter_ordered , filter_sorted , flatten ,
@@ -137,11 +137,9 @@ def __init__(self, exprs, ispace):
137137 # Derive the halo exchanges
138138 self ._mapper = frozendict (classify (exprs , ispace ))
139139
140- # Track the IterationSpace offsets induced by SubDomains/SubDimensions.
141- # These should be honored in the derivation of the `omapper`
140+ # Track the IterationSpace offsets induced by SubDomains/SubDimensions,
141+ # which are honored in the derivation of the `omapper`
142142 self ._honored = {}
143- # SubDimensions are not necessarily included directly in
144- # ispace.dimensions and hence we need to first utilize the `_defines` method
145143 dims = set ().union (* [d ._defines for d in ispace .dimensions
146144 if d ._defines & self .dimensions ])
147145 subdims = [d for d in dims if d .is_Sub and not d .local ]
@@ -150,6 +148,12 @@ def __init__(self, exprs, ispace):
150148 self ._honored [i .root ] = frozenset ([(ltk , rtk )])
151149 self ._honored = frozendict (self ._honored )
152150
151+ # Further constraints on the `omapper` derivation. At construction time
152+ # there's none, but lowering passes may change this
153+ # * `_alignment` may be a positive integer representing the alignment
154+ # requirement, in number of *elements*, of the underlying expressions
155+ self ._alignment = None
156+
153157 def __repr__ (self ):
154158 fnames = "," .join (i .name for i in set (self ._mapper ))
155159 return f"HaloScheme<{ fnames } >"
@@ -165,11 +169,22 @@ def __len__(self):
165169 def __hash__ (self ):
166170 return hash ((self ._mapper .__hash__ (), self .honored .__hash__ ()))
167171
168- @classmethod
169- def build (cls , fmapper , honored ):
172+ def _rebuild (self , fmapper = None , honored = None , alignment = None ):
173+ """
174+ Rebuild a HaloScheme from the provided `fmapper` and `honored`. Reuse
175+ `self`'s values for the missing arguments.
176+ """
170177 obj = object .__new__ (HaloScheme )
178+
179+ if fmapper is None :
180+ fmapper = self ._mapper
181+ if honored is None :
182+ honored = self ._honored
183+
171184 obj ._mapper = frozendict (fmapper )
172185 obj ._honored = frozendict (honored )
186+ obj ._alignment = alignment or self ._alignment
187+
173188 return obj
174189
175190 @classmethod
@@ -223,7 +238,7 @@ def union(self, halo_schemes):
223238 for d , v in i .honored .items ():
224239 honored [d ] = honored .get (d , frozenset ()) | v
225240
226- return HaloScheme . build (fmapper , honored )
241+ return i . _rebuild (fmapper = fmapper , honored = honored )
227242
228243 @property
229244 def honored (self ):
@@ -241,10 +256,14 @@ def is_void(self):
241256 @cached_property
242257 def omapper (self ):
243258 """
244- Logical decomposition of the DOMAIN region into OWNED and CORE sub-regions.
259+ Logical decomposition of the DOMAIN region into OWNED and CORE sub-regions,
260+ "cumulative" over all DiscreteFunctions in the HaloScheme.
245261
246- This is "cumulative" over all DiscreteFunctions in the HaloScheme; it also
247- takes into account IterationSpace offsets induced by SubDomains/SubDimensions.
262+ The computed OMapper takes into account:
263+
264+ * The offsets induced by SubDomains/SubDimensions ("thickness");
265+ * Any data alignment requirement of the underlying expressions
266+ (`_alignment` attribute).
248267
249268 Examples
250269 --------
@@ -366,28 +385,62 @@ def omapper(self):
366385
367386 if s is CENTER :
368387 where .append ((d , CORE , s ))
369- mapper [d ] = (d .symbolic_min + osl ,
370- d .symbolic_max - osr )
388+
389+ mapper [d ] = (
390+ d .symbolic_min + osl ,
391+ d .symbolic_max - osr
392+ )
393+
371394 if nl != 0 :
372395 mapper [nl ] = (Max (nl - osl , 0 ),)
373396 if nr != 0 :
374397 mapper [nr ] = (Max (nr - osr , 0 ),)
375398 else :
376399 where .append ((d , OWNED , s ))
400+
377401 if s is LEFT :
378- mapper [d ] = (d .symbolic_min ,
379- Min (d .symbolic_min + osl - 1 , d .symbolic_max - nr ))
402+ mapper [d ] = (
403+ d .symbolic_min ,
404+ Min (d .symbolic_min + osl - 1 , d .symbolic_max - nr )
405+ )
406+
380407 if nl != 0 :
381408 mapper [nl ] = (nl ,)
382409 mapper [nr ] = (0 ,)
383410 else :
384- mapper [d ] = (Max (d .symbolic_max - osr + 1 , d .symbolic_min + nl ),
385- d .symbolic_max )
411+ mapper [d ] = (
412+ Max (d .symbolic_max - osr + 1 , d .symbolic_min + nl ),
413+ d .symbolic_max
414+ )
415+
386416 if nr != 0 :
387417 mapper [nl ] = (0 ,)
388418 mapper [nr ] = (nr ,)
419+
389420 processed .append ((tuple (where ), frozendict (mapper )))
390421
422+ # Apply the alignment constraints, if any
423+ # First, get the fastest varying (contiguous) Dimension, which is the
424+ # one that matters for alignment
425+ if self ._alignment :
426+ fvds = {f .dimensions [- 1 ] for f in self .fmapper }
427+ if len (fvds ) != 1 :
428+ raise HaloSchemeException (
429+ "Unexpected contiguous Dimensions found while computing the "
430+ f"`omapper`: { fvds } "
431+ )
432+ fvd = fvds .pop ()
433+
434+ for i , (where , mapper ) in enumerate (list (processed )):
435+ try :
436+ m , M = mapper [fvd ]
437+ except KeyError :
438+ continue
439+
440+ aligned_m = IntDiv (m , self ._alignment ) * self ._alignment
441+
442+ processed [i ] = (where , frozendict ({** mapper , fvd : (aligned_m , M )}))
443+
391444 _ , core = processed .pop (0 )
392445 owned = processed
393446
@@ -483,15 +536,15 @@ def project(self, functions):
483536 to the provided `functions`.
484537 """
485538 fmapper = {f : v for f , v in self .fmapper .items () if f in as_tuple (functions )}
486- return HaloScheme . build (fmapper , self . honored )
539+ return self . _rebuild (fmapper = fmapper )
487540
488541 def drop (self , functions ):
489542 """
490543 Create a new HaloScheme that contains all entries in `self` except those
491544 corresponding to the provided `functions`.
492545 """
493546 fmapper = {f : v for f , v in self .fmapper .items () if f not in as_tuple (functions )}
494- return HaloScheme . build (fmapper , self . honored )
547+ return self . _rebuild (fmapper = fmapper )
495548
496549 def add (self , f , hse ):
497550 """
@@ -503,7 +556,7 @@ def add(self, f, hse):
503556 if f in fmapper :
504557 hse = fmapper [f ].union (hse )
505558 fmapper [f ] = hse
506- return HaloScheme . build (fmapper , self . honored )
559+ return self . _rebuild (fmapper = fmapper )
507560
508561 def merge (self , hs ):
509562 """
@@ -512,20 +565,14 @@ def merge(self, hs):
512565 fmapper = dict (self .fmapper )
513566 for f , hse in hs .fmapper .items ():
514567 fmapper [f ] = fmapper .get (f , hse ).merge (hse )
515- return HaloScheme . build (fmapper , self . honored )
568+ return self . _rebuild (fmapper = fmapper )
516569
517570
518571def classify (exprs , ispace ):
519572 """
520573 Produce the mapper `Function -> HaloSchemeEntry`, which describes the necessary
521574 halo exchanges in the given Scope.
522575 """
523-
524- # Some MPI modes require pulling the `loc_indices` from the reads, others
525- # from the writes. It essentially depends on whether the halo exchange is
526- # performed before (reads) or after (writes) the OWNED region is computed
527- loc_indices_from_reads = configuration ['mpi' ] not in ('dual' ,)
528-
529576 scope = Scope (exprs )
530577
531578 mapper = {}
@@ -565,15 +612,17 @@ def classify(exprs, ispace):
565612 else :
566613 v [(d , LEFT )] = STENCIL
567614 v [(d , RIGHT )] = STENCIL
568- elif loc_indices_from_reads :
615+ else :
569616 v [(d , i [d ])] = NONE
570617
571618 # Does `i` actually require a halo exchange?
572619 if not any (hl is STENCIL for hl in v .values ()):
573620 continue
574621
575622 # Derive diagonal halo exchanges from the previous analysis
576- combs = list (product ([LEFT , CENTER , RIGHT ], repeat = len (f ._dist_dimensions )))
623+ combs = list (
624+ product ([LEFT , CENTER , RIGHT ], repeat = len (f ._dist_dimensions ))
625+ )
577626 combs .remove ((CENTER ,)* len (f ._dist_dimensions ))
578627 for c in combs :
579628 key = (f ._dist_dimensions , c )
@@ -598,13 +647,6 @@ def classify(exprs, ispace):
598647 if not halo_labels :
599648 continue
600649
601- # Augment `halo_labels` with `loc_indices`-related information if necessary
602- if not loc_indices_from_reads :
603- for i in scope .writes .get (f , []):
604- for d in i .findices :
605- if not f .grid .is_distributed (d ):
606- halo_labels [(d , i [d ])].add (NONE )
607-
608650 # Separate halo-exchange Dimensions from `loc_indices`
609651 raw_loc_indices , halos = defaultdict (list ), []
610652 for (d , s ), hl in halo_labels .items ():
@@ -613,15 +655,18 @@ def classify(exprs, ispace):
613655 if not hl :
614656 continue
615657 elif len (hl ) > 1 :
616- raise HaloSchemeException ("Inconsistency found while building a halo "
617- f"scheme for `{ f } ` along Dimension `{ d } `" )
658+ raise HaloSchemeException (
659+ "Inconsistency found while building a halo scheme for "
660+ f"`{ f } ` along Dimension `{ d } `" )
618661 elif hl .pop () is STENCIL :
619662 halos .append (Halo (d , s ))
620663 elif d ._defines & set (ispace .itdims ):
621664 raw_loc_indices [d ].append (s )
622665
623- loc_indices , loc_dirs = process_loc_indices (raw_loc_indices ,
624- ispace .directions )
666+ loc_indices , loc_dirs = process_loc_indices (
667+ raw_loc_indices , ispace .directions
668+ )
669+
625670 mapper [f ] = HaloSchemeEntry (loc_indices , loc_dirs , halos , dims )
626671
627672 return mapper
0 commit comments