1313from sympy .logic .boolalg import BooleanFunction
1414
1515from devito .ir .support .space import Forward , IterationDirection
16- from devito .symbolics import CondEq , CondNe , search
16+ from devito .symbolics import CondEq , CondNe , IntDiv , search
17+ from devito .symbolics .manipulation import _uxreplace_handle , _uxreplace_registry
1718from devito .tools import Pickable , as_tuple , frozendict , split
1819from devito .types import Dimension , LocalObject
1920
3132]
3233
3334
35+ @singledispatch
36+ def bound_index (expr , dim , dir ):
37+ if dir == Forward :
38+ return expr ._subs (dim , dim + 1 )
39+ else :
40+ return expr ._subs (dim , dim - 1 )
41+
42+
43+ @bound_index .register (Expr )
44+ def _ (expr , dim , dir ):
45+ if not expr .args :
46+ if dir == Forward :
47+ return expr ._subs (dim , dim + 1 )
48+ else :
49+ return expr ._subs (dim , dim - 1 )
50+ return expr .func (* [bound_index (a , dim , dir ) for a in expr .args ])
51+
52+
53+ @bound_index .register (IntDiv )
54+ def _ (expr , dim , dir ):
55+ v = dim .symbolic_factor
56+ p0 = dim .root
57+ if dir == Forward :
58+ return Mul ((((p0 + 1 ) + v - 1 ) / v ), v , evaluate = False )
59+ else :
60+ return (p0 - 1 ) - abs (p0 - 1 ) % v
61+
62+
3463class AbstractGuard :
3564 pass
3665
@@ -138,37 +167,29 @@ class BaseGuardBoundNext(Guard, Pickable):
138167 given `direction`.
139168 """
140169
141- __rargs__ = ('d' , 'direction' )
170+ __rargs__ = ('d' , 'index' , 'direction' )
171+ __rkwargs__ = ('d_min' , 'd_max' )
142172
143- def __new__ (cls , d , direction , ** kwargs ):
173+ def __new__ (cls , d , index , direction ,
174+ d_min = None , d_max = None , ** kwargs ):
144175 assert isinstance (d , Dimension )
145176 assert isinstance (direction , IterationDirection )
146177
147- if direction == Forward :
148- p0 = d .root
149- p1 = d .root .symbolic_max
178+ # Always take the next index in the iteration direction
179+ next_index = bound_index (index , d , direction )
150180
151- if d .is_Conditional :
152- v = d .symbolic_factor
153- # Round `p0 + 1` up to the nearest multiple of `v`
154- p0 = Mul ((((p0 + 1 ) + v - 1 ) / v ), v , evaluate = False )
155- else :
156- p0 = p0 + 1
181+ # The direction might be forward but accessing c - d
182+ # making the access backward w.r.t
183+ # Update direction according to access direction for valid guard
184+ if index .has (- d ):
185+ direction = - direction
157186
187+ if direction == Forward :
188+ p0 = next_index
189+ p1 = d_max or d .root .symbolic_max
158190 else :
159- p0 = d .root .symbolic_min
160- p1 = d .root
161-
162- if d .is_Conditional :
163- v = d .symbolic_factor
164- # Round `p1 - 1` down to the nearest sub-multiple of `v`
165- # NOTE: we use ABS to make sure we handle negative values properly.
166- # Once `p1 - 1` is negative (e.g. `iteration=time - 1` and `time=0`),
167- # as long as we get a negative number, rather than 0 and even if it's
168- # not `-v`, we're good
169- p1 = (p1 - 1 ) - abs (p1 - 1 ) % v
170- else :
171- p1 = p1 - 1
191+ p0 = d_min if d_min is not None else d .root .symbolic_min
192+ p1 = next_index
172193
173194 try :
174195 if cls .__base__ ._eval_relation (p0 , p1 ) is true :
@@ -180,12 +201,15 @@ def __new__(cls, d, direction, **kwargs):
180201
181202 obj .d = d
182203 obj .direction = direction
204+ obj .index = index
205+ obj .d_min = d_min
206+ obj .d_max = d_max
183207
184208 return obj
185209
186210 @property
187211 def _args_rebuild (self ):
188- return (self .d , self .direction )
212+ return (self .d , self .index , self . direction )
189213
190214
191215class GuardBoundNextLe (BaseGuardBoundNext , Le ):
@@ -544,3 +568,11 @@ def pairwise_or(*guards):
544568 pass
545569
546570 return guard
571+
572+
573+ _uxreplace_registry .register (BaseGuardBoundNext )
574+
575+
576+ @_uxreplace_handle .register (BaseGuardBoundNext )
577+ def _ (expr , args , kwargs ):
578+ return expr .func (expr .d , expr .index , expr .direction , ** kwargs )
0 commit comments