99
1010from devito .data import CENTER , CORE , LEFT , OWNED , RIGHT
1111from devito .ir .support import Forward , Scope
12+ from devito .symbolics import IntDiv
1213from devito .symbolics .manipulation import _uxreplace_registry
1314from devito .tools import (
1415 EnrichedTuple , Reconstructable , Tag , as_tuple , filter_ordered , filter_sorted , flatten ,
@@ -147,6 +148,10 @@ def __init__(self, exprs, ispace):
147148 self ._honored [i .root ] = frozenset ([(ltk , rtk )])
148149 self ._honored = frozendict (self ._honored )
149150
151+ # Further constraints on the `omapper` derivation. At construction time
152+ # there's none, but lowering passes may change this
153+ self ._alignment = None
154+
150155 def __repr__ (self ):
151156 fnames = "," .join (i .name for i in set (self ._mapper ))
152157 return f"HaloScheme<{ fnames } >"
@@ -162,7 +167,7 @@ def __len__(self):
162167 def __hash__ (self ):
163168 return hash ((self ._mapper .__hash__ (), self .honored .__hash__ ()))
164169
165- def _rebuild (self , fmapper = None , honored = None ):
170+ def _rebuild (self , fmapper = None , honored = None , alignment = None ):
166171 """
167172 Rebuild a HaloScheme from the provided `fmapper` and `honored`. Reuse
168173 `self`'s values for the missing arguments.
@@ -176,6 +181,7 @@ def _rebuild(self, fmapper=None, honored=None):
176181
177182 obj ._mapper = frozendict (fmapper )
178183 obj ._honored = frozendict (honored )
184+ obj ._alignment = alignment or self ._alignment
179185
180186 return obj
181187
@@ -248,10 +254,14 @@ def is_void(self):
248254 @cached_property
249255 def omapper (self ):
250256 """
251- Logical decomposition of the DOMAIN region into OWNED and CORE sub-regions.
257+ Logical decomposition of the DOMAIN region into OWNED and CORE sub-regions,
258+ "cumulative" over all DiscreteFunctions in the HaloScheme.
259+
260+ The computed OMapper takes into account:
252261
253- This is "cumulative" over all DiscreteFunctions in the HaloScheme; it also
254- takes into account IterationSpace offsets induced by SubDomains/SubDimensions.
262+ * The offsets induced by SubDomains/SubDimensions ("thickness");
263+ * Any data alignment requirement of the underlying expressions
264+ (`_alignment` attribute).
255265
256266 Examples
257267 --------
@@ -373,28 +383,62 @@ def omapper(self):
373383
374384 if s is CENTER :
375385 where .append ((d , CORE , s ))
376- mapper [d ] = (d .symbolic_min + osl ,
377- d .symbolic_max - osr )
386+
387+ mapper [d ] = (
388+ d .symbolic_min + osl ,
389+ d .symbolic_max - osr
390+ )
391+
378392 if nl != 0 :
379393 mapper [nl ] = (Max (nl - osl , 0 ),)
380394 if nr != 0 :
381395 mapper [nr ] = (Max (nr - osr , 0 ),)
382396 else :
383397 where .append ((d , OWNED , s ))
398+
384399 if s is LEFT :
385- mapper [d ] = (d .symbolic_min ,
386- Min (d .symbolic_min + osl - 1 , d .symbolic_max - nr ))
400+ mapper [d ] = (
401+ d .symbolic_min ,
402+ Min (d .symbolic_min + osl - 1 , d .symbolic_max - nr )
403+ )
404+
387405 if nl != 0 :
388406 mapper [nl ] = (nl ,)
389407 mapper [nr ] = (0 ,)
390408 else :
391- mapper [d ] = (Max (d .symbolic_max - osr + 1 , d .symbolic_min + nl ),
392- d .symbolic_max )
409+ mapper [d ] = (
410+ Max (d .symbolic_max - osr + 1 , d .symbolic_min + nl ),
411+ d .symbolic_max
412+ )
413+
393414 if nr != 0 :
394415 mapper [nl ] = (0 ,)
395416 mapper [nr ] = (nr ,)
417+
396418 processed .append ((tuple (where ), frozendict (mapper )))
397419
420+ # Apply the alignment constraints, if any
421+ # First, get the fastest varying (contiguous) Dimension, which is the
422+ # one that matters for alignment
423+ if self ._alignment :
424+ fvds = {f .dimensions [- 1 ] for f in self .fmapper }
425+ if len (fvds ) != 1 :
426+ raise HaloSchemeException (
427+ "Unexpected contiguous Dimensions found while computing the "
428+ f"`omapper`: { fvds } "
429+ )
430+ fvd = fvds .pop ()
431+
432+ for i , (where , mapper ) in enumerate (list (processed )):
433+ try :
434+ m , M = mapper [fvd ]
435+ except KeyError :
436+ continue
437+
438+ aligned_m = IntDiv (m , self ._alignment ) * self ._alignment
439+
440+ processed [i ] = (where , frozendict ({** mapper , fvd : (aligned_m , M )}))
441+
398442 _ , core = processed .pop (0 )
399443 owned = processed
400444
0 commit comments