|
1 | 1 | from collections import defaultdict |
| 2 | +from functools import singledispatch |
2 | 3 |
|
3 | | -from sympy import true |
| 4 | +from sympy import Expr, Mod, true |
4 | 5 |
|
5 | 6 | from devito.ir import ( |
6 | 7 | Backward, Forward, GuardBoundNext, PrefetchUpdate, Queue, ReleaseLock, SyncArray, |
7 | 8 | WaitLock, WithLock, normalize_syncs |
8 | 9 | ) |
9 | 10 | from devito.passes.clusters.utils import in_critical_region, is_memcpy |
10 | | -from devito.symbolics import IntDiv, uxreplace |
| 11 | +from devito.symbolics import IntDiv, retrieve_terminals, uxreplace |
11 | 12 | from devito.tools import OrderedSet, is_integer, timed_pass |
12 | | -from devito.types import CustomDimension, Lock |
| 13 | +from devito.types import CustomDimension, Lock, VirtualDimension |
13 | 14 |
|
14 | 15 | __all__ = ['memcpy_prefetch', 'tasking'] |
15 | 16 |
|
16 | 17 |
|
| 18 | +@singledispatch |
| 19 | +def next_index(expr, dim, dir): |
| 20 | + return expr._subs(dim, dim + dir) |
| 21 | + |
| 22 | + |
| 23 | +@next_index.register(Expr) |
| 24 | +def _(expr, dim, dir): |
| 25 | + if not expr.args: |
| 26 | + return expr._subs(dim, dim + dir) |
| 27 | + return expr.func(*[next_index(a, dim, dir) for a in expr.args]) |
| 28 | + |
| 29 | + |
| 30 | +@next_index.register(IntDiv) |
| 31 | +def _(expr, dim, dir): |
| 32 | + """ |
| 33 | + Handle forward and backward fetches separately to handle non-canonical index |
| 34 | + expressions of the form: |
| 35 | +
|
| 36 | + t//factor + cond(t) |
| 37 | +
|
| 38 | + where ``cond(t)`` is a piecewise correction term. |
| 39 | +
|
| 40 | + The forward fetch advances to the next coarse-grained slot while evaluating |
| 41 | + the correction at the next time point: |
| 42 | +
|
| 43 | + t//factor + cond(t) |
| 44 | + -> (t//factor + 1) + cond(t + 1) |
| 45 | +
|
| 46 | + The backward fetch is not, in general, the inverse transformation obtained by |
| 47 | + replacing ``+1`` with ``-1``. The correction may already be applied at the |
| 48 | + current time point, causing the forward and backward fetches to be asymmetric. |
| 49 | +
|
| 50 | + For example, with ``factor=2`` and ``cond(t) := (t == a)``, the index at |
| 51 | + ``t=a=3`` is: |
| 52 | +
|
| 53 | + 3//2 + 1 = 2 |
| 54 | +
|
| 55 | + while the previous index is: |
| 56 | +
|
| 57 | + 2//2 + 0 = 1 |
| 58 | +
|
| 59 | + A symmetric backward transformation would instead yield: |
| 60 | +
|
| 61 | + 3//2 - 1 + 0 = 0 |
| 62 | + """ |
| 63 | + if expr.lhs._defines & dim._defines: |
| 64 | + if dir == 1: |
| 65 | + return expr + dir |
| 66 | + else: |
| 67 | + return expr._subs(dim, dim + dir) |
| 68 | + else: |
| 69 | + return expr |
| 70 | + |
| 71 | + |
17 | 72 | def async_trigger(c, dims): |
18 | 73 | """ |
19 | 74 | Return the Dimension in `c`'s IterationSpace that triggers the |
@@ -78,7 +133,9 @@ def callback(self, clusters, prefix): |
78 | 133 | d = self.key0(c0) |
79 | 134 | if d is not dim: |
80 | 135 | continue |
81 | | - |
| 136 | + g = c0.guards.get(d) |
| 137 | + if g is not None and (not g.has(Mod) and d in retrieve_terminals(g)): |
| 138 | + continue |
82 | 139 | protected = self._schedule_waitlocks(c0, d, clusters, locks, syncs) |
83 | 140 | self._schedule_withlocks(c0, d, protected, locks, syncs) |
84 | 141 |
|
@@ -240,7 +297,15 @@ def _actions_from_update_memcpy(c, d, clusters, actions, sregistry): |
240 | 297 |
|
241 | 298 | fetch = e.rhs.indices[d] |
242 | 299 | fshift = {Forward: 1, Backward: -1}.get(direction, 0) |
243 | | - findex = fetch + fshift if fetch.find(IntDiv) else fetch._subs(pd, pd + fshift) |
| 300 | + findex = next_index(fetch, pd, fshift) |
| 301 | + |
| 302 | + # Maximum allowed access along d |
| 303 | + if function.dimensions[d].is_Conditional: |
| 304 | + nslot = function.dimension_shape[d] |
| 305 | + v = function.dimensions[d].symbolic_factor |
| 306 | + fd_max = v * (nslot - 1) |
| 307 | + else: |
| 308 | + fd_max = None |
244 | 309 |
|
245 | 310 | # If fetching into e.g. `ub[t1]` we might need to prefetch into e.g. `ub[t0]` |
246 | 311 | tindex0 = e.lhs.indices[d] |
@@ -271,8 +336,17 @@ def _actions_from_update_memcpy(c, d, clusters, actions, sregistry): |
271 | 336 | ispace = c.ispace.augment({pd: tindex}) if tindex is not tindex0 else c.ispace |
272 | 337 |
|
273 | 338 | guard0 = c.guards.get(d, true)._subs(fetch, findex) |
274 | | - guard1 = GuardBoundNext(function.indices[d], direction) |
275 | | - guards = c.guards.impose(d, guard0 & guard1) |
| 339 | + guard1 = GuardBoundNext(function.indices[d], e.rhs.indices[d], direction, |
| 340 | + d_min=0, d_max=fd_max) |
| 341 | + |
| 342 | + # First guard1 then if guard1 is valid we can safely evaluate guard0 |
| 343 | + # that will have valid indices into f |
| 344 | + vdnext = VirtualDimension(name=f'vdnext_{d.name}', parent=pd) |
| 345 | + ispace = ispace.insert(pd, vdnext) |
| 346 | + # Check valid tindex first |
| 347 | + guards = c.guards.impose(d, guard1) |
| 348 | + # THen check valid access |
| 349 | + guards = guards.impose(vdnext, guard0) |
276 | 350 |
|
277 | 351 | syncs = {d: [ |
278 | 352 | ReleaseLock(handle, target), |
|
0 commit comments