11from collections import defaultdict
2+ from functools import singledispatch
23
3- from sympy import true
4+ from sympy import Expr , Mod , true
45
56from devito .ir import (
67 Backward , Forward , GuardBoundNext , PrefetchUpdate , Queue , ReleaseLock , SyncArray ,
78 WaitLock , WithLock , normalize_syncs
89)
910from devito .passes .clusters .utils import in_critical_region , is_memcpy
10- from devito .symbolics import IntDiv , uxreplace
11+ from devito .symbolics import IntDiv , retrieve_terminals , uxreplace
1112from devito .tools import OrderedSet , is_integer , timed_pass
12- from devito .types import CustomDimension , Lock
13+ from devito .types import CustomDimension , Lock , VirtualDimension
1314
1415__all__ = ['memcpy_prefetch' , 'tasking' ]
1516
1617
18+ @singledispatch
19+ def next_index (expr , dim , dir ):
20+ return expr ._subs (dim , dim + dir )
21+
22+
23+ @next_index .register (Expr )
24+ def _ (expr , dim , dir ):
25+ if not expr .args :
26+ return expr ._subs (dim , dim + dir )
27+ return expr .func (* [next_index (a , dim , dir ) for a in expr .args ])
28+
29+
30+ @next_index .register (IntDiv )
31+ def _ (expr , dim , dir ):
32+ """
33+ Handle forward and backward fetches separately to handle non-canonical index
34+ expressions of the form:
35+
36+ t//factor + cond(t)
37+
38+ where ``cond(t)`` is a piecewise correction term.
39+
40+ The forward fetch advances to the next coarse-grained slot while evaluating
41+ the correction at the next time point:
42+
43+ t//factor + cond(t)
44+ -> (t//factor + 1) + cond(t + 1)
45+
46+ The backward fetch is not, in general, the inverse transformation obtained by
47+ replacing ``+1`` with ``-1``. The correction may already be applied at the
48+ current time point, causing the forward and backward fetches to be asymmetric.
49+
50+ For example, with ``factor=2`` and ``cond(t) := (t == a)``, the index at
51+ ``t=a=3`` is:
52+
53+ 3//2 + 1 = 2
54+
55+ while the previous index is:
56+
57+ 2//2 + 0 = 1
58+
59+ A symmetric backward transformation would instead yield:
60+
61+ 3//2 - 1 + 0 = 0
62+ """
63+ if expr .lhs ._defines & dim ._defines :
64+ if dir == 1 :
65+ return expr + dir
66+ else :
67+ return expr ._subs (dim , dim + dir )
68+ else :
69+ return expr
70+
71+
1772def async_trigger (c , dims ):
1873 """
1974 Return the Dimension in `c`'s IterationSpace that triggers the
@@ -78,7 +133,9 @@ def callback(self, clusters, prefix):
78133 d = self .key0 (c0 )
79134 if d is not dim :
80135 continue
81-
136+ g = c0 .guards .get (d )
137+ if g is not None and (not g .has (Mod ) and d in retrieve_terminals (g )):
138+ continue
82139 protected = self ._schedule_waitlocks (c0 , d , clusters , locks , syncs )
83140 self ._schedule_withlocks (c0 , d , protected , locks , syncs )
84141
@@ -181,6 +238,7 @@ def memcpy_prefetch(clusters, key0, sregistry):
181238 """
182239 _ , key = keys (key0 )
183240 actions = defaultdict (Actions )
241+ bounds = {}
184242
185243 for c in clusters :
186244 d = key (c )
@@ -191,9 +249,9 @@ def memcpy_prefetch(clusters, key0, sregistry):
191249 continue
192250
193251 if c .properties .is_prefetchable (d ._defines ):
194- _actions_from_update_memcpy (c , d , clusters , actions , sregistry )
252+ _actions_from_update_memcpy (c , d , clusters , actions , sregistry , bounds )
195253 elif d .is_Custom and is_integer (c .ispace [d ].size ):
196- _actions_from_init (c , d , clusters , actions )
254+ _actions_from_init (c , d , actions )
197255
198256 # Attach the computed Actions
199257 processed = []
@@ -214,7 +272,7 @@ def memcpy_prefetch(clusters, key0, sregistry):
214272 return processed
215273
216274
217- def _actions_from_init (c , d , clusters , actions ):
275+ def _actions_from_init (c , d , actions ):
218276 e = c .exprs [0 ]
219277 function = e .rhs .function
220278 target = e .lhs .function
@@ -230,7 +288,7 @@ def _actions_from_init(c, d, clusters, actions):
230288 )
231289
232290
233- def _actions_from_update_memcpy (c , d , clusters , actions , sregistry ):
291+ def _actions_from_update_memcpy (c , d , clusters , actions , sregistry , bounds ):
234292 pd = d .root # E.g., `vd -> time`
235293 direction = c .ispace [pd ].direction
236294
@@ -240,7 +298,15 @@ def _actions_from_update_memcpy(c, d, clusters, actions, sregistry):
240298
241299 fetch = e .rhs .indices [d ]
242300 fshift = {Forward : 1 , Backward : - 1 }.get (direction , 0 )
243- findex = fetch + fshift if fetch .find (IntDiv ) else fetch ._subs (pd , pd + fshift )
301+ findex = next_index (fetch , pd , fshift )
302+
303+ # Maximum allowed access along d
304+ if function .dimensions [d ].is_Conditional :
305+ nslot = function .dimension_shape [d ]
306+ v = function .dimensions [d ].symbolic_factor
307+ fd_max = bounds .setdefault (d , v * (nslot - 1 ))
308+ else :
309+ fd_max = bounds .setdefault (d , function .dimension_shape [d ] - 1 )
244310
245311 # If fetching into e.g. `ub[t1]` we might need to prefetch into e.g. `ub[t0]`
246312 tindex0 = e .lhs .indices [d ]
@@ -271,8 +337,17 @@ def _actions_from_update_memcpy(c, d, clusters, actions, sregistry):
271337 ispace = c .ispace .augment ({pd : tindex }) if tindex is not tindex0 else c .ispace
272338
273339 guard0 = c .guards .get (d , true )._subs (fetch , findex )
274- guard1 = GuardBoundNext (function .indices [d ], direction )
275- guards = c .guards .impose (d , guard0 & guard1 )
340+ guard1 = GuardBoundNext (function .indices [d ], e .rhs .indices [d ], direction ,
341+ d_min = 0 , d_max = fd_max )
342+
343+ # First guard1 then if guard1 is valid we can safely evaluate guard0
344+ # that will have valid indices into f
345+ vdnext = VirtualDimension (name = f'vdnext_{ d .name } ' , parent = pd )
346+ ispace = ispace .insert (pd , vdnext )
347+ # Check valid tindex first
348+ guards = c .guards .impose (d , guard1 )
349+ # THen check valid access
350+ guards = guards .impose (vdnext , guard0 )
276351
277352 syncs = {d : [
278353 ReleaseLock (handle , target ),
0 commit comments