Skip to content

Commit ab5242a

Browse files
committed
compiler: support multi-buffering
1 parent 45c8a8e commit ab5242a

14 files changed

Lines changed: 537 additions & 162 deletions

File tree

devito/ir/clusters/algorithms.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -259,10 +259,17 @@ def guard(clusters):
259259
# Separate out the indirect ConditionalDimensions, which only serve
260260
# the purpose of protecting from OOB accesses
261261
cds = [d for d in cds if not d.indirect]
262+
modes = [cd.relation for cd in cds]
263+
if modes.count('strict') > 1:
264+
raise CompilationError("Only one `strict` condition"
265+
"can be used in an equation")
266+
elif 'strict' in modes:
267+
mode = 'strict'
268+
else:
269+
mode = sympy.And if sympy.And in modes else sympy.Or
262270

263271
# Chain together all `cds` conditions from all expressions in `c`
264272
guards = {}
265-
mode = sympy.Or
266273
for cd in cds:
267274
# `BOTTOM` parent implies a guard that lives outside of
268275
# any iteration space, which corresponds to the placeholder None
@@ -279,7 +286,6 @@ def guard(clusters):
279286

280287
# Pull `cd` from any expr
281288
condition = guards.setdefault(k, [])
282-
mode = mode and cd.relation
283289
for e in exprs:
284290
try:
285291
condition.append(e.conditionals[cd])
@@ -296,7 +302,10 @@ def guard(clusters):
296302

297303
# Combination `mode` is And by default.
298304
# If all conditions are Or then Or combination `mode` is used.
299-
guards = {d: mode(*v, evaluate=False) for d, v in guards.items()}
305+
if mode == 'strict':
306+
guards = {d: v[0] for d, v in guards.items()}
307+
else:
308+
guards = {d: mode(*v, evaluate=False) for d, v in guards.items()}
300309

301310
# Construct a guarded Cluster
302311
processed.append(c.rebuild(exprs=exprs, guards=guards))

devito/ir/equations/equation.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -237,14 +237,6 @@ def __new__(cls, *args, **kwargs):
237237
cond = d.relation(cond, GuardFactor(d))
238238
conditionals[d] = cond
239239

240-
# Replace the ConditionalDimensions in `expr`
241-
for d, cond in conditionals.items():
242-
# Replace dimension with index
243-
index = d.index
244-
index = index - relational_min(cond, d.parent)
245-
shift = relational_shift(cond, d.parent)
246-
expr = uxreplace(expr, {d: IntDiv(index, d.symbolic_factor) + shift})
247-
248240
# Merge conditionals when possible. E.g if we have an implicit_dim
249241
# and there is a dimension with the same parent, we can merge
250242
# its condition
@@ -254,14 +246,22 @@ def __new__(cls, *args, **kwargs):
254246
for cd in dict(conditionals):
255247
if cd.parent == d.parent and cd != d:
256248
cond = conditionals.pop(d)
257-
mode = cd.relation and d.relation
258-
if issubclass(mode, sympy.Or):
259-
conditionals[d] = cond
260-
conditionals.pop(cd)
249+
if d.relation == 'strict':
250+
conditionals[cd] = conditionals[d] = cond
261251
else:
252+
mode = cd.relation and d.relation
262253
conditionals[cd] = mode(cond, conditionals[cd])
263254
break
264255

256+
# Replace the ConditionalDimensions in `expr`
257+
for d, cond in conditionals.items():
258+
# Replace dimension with index
259+
index = d.index
260+
if d.condition is not None and d in expr.free_symbols:
261+
index = index - relational_min(cond, d.parent)
262+
shift = relational_shift(cond, d.parent)
263+
expr = uxreplace(expr, {d: IntDiv(index, d.symbolic_factor) + shift})
264+
265265
# Lower all Differentiable operations into SymPy operations
266266
rhs = diff2sympy(expr.rhs)
267267

devito/ir/support/guards.py

Lines changed: 48 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from sympy.logic.boolalg import BooleanFunction
1414

1515
from devito.ir.support.space import Forward, IterationDirection
16-
from devito.symbolics import CondEq, CondNe, search
16+
from devito.symbolics import CondEq, CondNe, IntDiv, search
1717
from devito.tools import Pickable, as_tuple, frozendict, split
1818
from devito.types import Dimension, LocalObject
1919

@@ -31,6 +31,34 @@
3131
]
3232

3333

34+
@singledispatch
def bound_index(expr, dim, dir):
    """
    Generic fallback: return `expr` with `dim` moved one step along `dir`.

    `dim` is substituted (through `expr._subs`) with `dim + 1` when `dir`
    compares equal to `Forward`, and with `dim - 1` for any other direction.
    """
    step = 1 if dir == Forward else -1
    return expr._subs(dim, dim + step)
40+
41+
42+
@bound_index.register(Expr)
def _(expr, dim, dir):
    """
    `Expr` overload of `bound_index`.

    Compound expressions are rebuilt from their recursively processed
    arguments, so that each sub-expression is dispatched on its own type.
    Leaves (no `args`) get `dim` replaced by `dim + 1` (`Forward`) or
    `dim - 1` (any other direction).
    """
    if expr.args:
        # Recurse so that type-specific overloads get a chance to handle
        # the individual operands
        return expr.func(*(bound_index(arg, dim, dir) for arg in expr.args))

    shifted = dim + 1 if dir == Forward else dim - 1
    return expr._subs(dim, shifted)
50+
51+
52+
@bound_index.register(IntDiv)
def _(expr, dim, dir):
    """
    `IntDiv` overload of `bound_index`.

    The result is built solely from `dim.root` and `dim.symbolic_factor`;
    the `expr` argument itself is only used for dispatching.
    """
    factor = dim.symbolic_factor
    root = dim.root

    if dir == Forward:
        # Round `root + 1` up to the nearest multiple of `factor`
        rounded = ((root + 1) + factor - 1) / factor
        return Mul(rounded, factor, evaluate=False)

    # Round `root - 1` down to the nearest sub-multiple of `factor`.
    # NOTE: `abs` is used so that negative values of `root - 1` still
    # yield a negative result, rather than 0
    prev = root - 1
    return prev - abs(prev) % factor
60+
61+
3462
class AbstractGuard:
3563
pass
3664

@@ -138,37 +166,29 @@ class BaseGuardBoundNext(Guard, Pickable):
138166
given `direction`.
139167
"""
140168

141-
__rargs__ = ('d', 'direction')
169+
__rargs__ = ('d', 'index', 'direction')
170+
__rkwargs__ = ('d_min', 'd_max')
142171

143-
def __new__(cls, d, direction, **kwargs):
172+
def __new__(cls, d, index, direction,
173+
d_min=None, d_max=None, **kwargs):
144174
assert isinstance(d, Dimension)
145175
assert isinstance(direction, IterationDirection)
146176

147-
if direction == Forward:
148-
p0 = d.root
149-
p1 = d.root.symbolic_max
177+
# Always take the next index in the iteration direction
178+
next_index = bound_index(index, d, direction)
150179

151-
if d.is_Conditional:
152-
v = d.symbolic_factor
153-
# Round `p0 + 1` up to the nearest multiple of `v`
154-
p0 = Mul((((p0 + 1) + v - 1) / v), v, evaluate=False)
155-
else:
156-
p0 = p0 + 1
180+
# The direction might be forward while the index accesses `c - d`,
181+
# making the access backward w.r.t. `d`.
182+
# Update the direction according to the access direction for a valid guard
183+
if index.has(-d):
184+
direction = -direction
157185

186+
if direction == Forward:
187+
p0 = next_index
188+
p1 = d_max or d.root.symbolic_max
158189
else:
159-
p0 = d.root.symbolic_min
160-
p1 = d.root
161-
162-
if d.is_Conditional:
163-
v = d.symbolic_factor
164-
# Round `p1 - 1` down to the nearest sub-multiple of `v`
165-
# NOTE: we use ABS to make sure we handle negative values properly.
166-
# Once `p1 - 1` is negative (e.g. `iteration=time - 1` and `time=0`),
167-
# as long as we get a negative number, rather than 0 and even if it's
168-
# not `-v`, we're good
169-
p1 = (p1 - 1) - abs(p1 - 1) % v
170-
else:
171-
p1 = p1 - 1
190+
p0 = d_min if d_min is not None else d.root.symbolic_min
191+
p1 = next_index
172192

173193
try:
174194
if cls.__base__._eval_relation(p0, p1) is true:
@@ -180,6 +200,9 @@ def __new__(cls, d, direction, **kwargs):
180200

181201
obj.d = d
182202
obj.direction = direction
203+
obj.index = index
204+
obj.d_min = d_min
205+
obj.d_max = d_max
183206

184207
return obj
185208

devito/ir/support/space.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -601,6 +601,14 @@ def __repr__(self):
601601
def __hash__(self):
602602
return hash(self._name)
603603

604+
def __neg__(self):
    """
    Return the IterationDirection obtained by negating `self`.

    A direction named '++' maps to `Backward`, one named '--' maps to
    `Forward`, and any other direction maps to `Any`.
    """
    flipped = {'++': Backward, '--': Forward}
    return flipped.get(self._name, Any)
611+
604612

605613
Forward = IterationDirection('++')
606614
"""Forward iteration direction ('++')."""

devito/passes/clusters/asynchrony.py

Lines changed: 81 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,74 @@
11
from collections import defaultdict
2+
from functools import singledispatch
23

3-
from sympy import true
4+
from sympy import Expr, Mod, true
45

56
from devito.ir import (
67
Backward, Forward, GuardBoundNext, PrefetchUpdate, Queue, ReleaseLock, SyncArray,
78
WaitLock, WithLock, normalize_syncs
89
)
910
from devito.passes.clusters.utils import in_critical_region, is_memcpy
10-
from devito.symbolics import IntDiv, uxreplace
11+
from devito.symbolics import IntDiv, retrieve_terminals, uxreplace
1112
from devito.tools import OrderedSet, is_integer, timed_pass
12-
from devito.types import CustomDimension, Lock
13+
from devito.types import CustomDimension, Lock, VirtualDimension
1314

1415
__all__ = ['memcpy_prefetch', 'tasking']
1516

1617

18+
@singledispatch
def next_index(expr, dim, dir):
    """
    Generic fallback: return `expr` with `dim` replaced by `dim + dir`.

    `dir` is a numeric shift that is added to `dim`; the replacement is
    carried out through `expr._subs`.
    """
    shifted = dim + dir
    return expr._subs(dim, shifted)
21+
22+
23+
@next_index.register(Expr)
def _(expr, dim, dir):
    """
    `Expr` overload of `next_index`.

    Compound expressions are rebuilt from their recursively processed
    arguments; leaves (no `args`) get `dim` replaced by `dim + dir`.
    """
    if expr.args:
        # Recurse so that type-specific overloads (e.g. for IntDiv) get a
        # chance to handle the individual operands
        return expr.func(*(next_index(arg, dim, dir) for arg in expr.args))

    return expr._subs(dim, dim + dir)
28+
29+
30+
@next_index.register(IntDiv)
def _(expr, dim, dir):
    """
    `IntDiv` overload of `next_index`.

    Forward and backward fetches are treated differently in order to cope
    with non-canonical index expressions such as:

        t//factor + cond(t)

    in which ``cond(t)`` is a piecewise correction term.

    A forward fetch moves to the next coarse-grained slot, while the
    correction is evaluated at the next time point:

        t//factor + cond(t)
            -> (t//factor + 1) + cond(t + 1)

    A backward fetch cannot, in general, be obtained by simply turning the
    ``+1`` into a ``-1``: the correction may already be in effect at the
    current time point, which makes the two fetches asymmetric.

    As an example, take ``factor=2`` and ``cond(t) := (t == a)``. At
    ``t=a=3`` the index is:

        3//2 + 1 = 2

    whereas the previous index is:

        2//2 + 0 = 1

    A symmetric backward transformation would wrongly produce:

        3//2 - 1 + 0 = 0
    """
    # Nothing to do unless the `_defines` sets of `expr.lhs` and `dim`
    # intersect
    if not (expr.lhs._defines & dim._defines):
        return expr

    # Forward: bump the whole quotient by one slot.
    # Otherwise: shift `dim` inside the quotient instead
    return expr + dir if dir == 1 else expr._subs(dim, dim + dir)
70+
71+
1772
def async_trigger(c, dims):
1873
"""
1974
Return the Dimension in `c`'s IterationSpace that triggers the
@@ -78,7 +133,9 @@ def callback(self, clusters, prefix):
78133
d = self.key0(c0)
79134
if d is not dim:
80135
continue
81-
136+
g = c0.guards.get(d)
137+
if g is not None and (not g.has(Mod) and d in retrieve_terminals(g)):
138+
continue
82139
protected = self._schedule_waitlocks(c0, d, clusters, locks, syncs)
83140
self._schedule_withlocks(c0, d, protected, locks, syncs)
84141

@@ -240,7 +297,15 @@ def _actions_from_update_memcpy(c, d, clusters, actions, sregistry):
240297

241298
fetch = e.rhs.indices[d]
242299
fshift = {Forward: 1, Backward: -1}.get(direction, 0)
243-
findex = fetch + fshift if fetch.find(IntDiv) else fetch._subs(pd, pd + fshift)
300+
findex = next_index(fetch, pd, fshift)
301+
302+
# Maximum allowed access along d
303+
if function.dimensions[d].is_Conditional:
304+
nslot = function.dimension_shape[d]
305+
v = function.dimensions[d].symbolic_factor
306+
fd_max = v * (nslot - 1)
307+
else:
308+
fd_max = None
244309

245310
# If fetching into e.g. `ub[t1]` we might need to prefetch into e.g. `ub[t0]`
246311
tindex0 = e.lhs.indices[d]
@@ -271,8 +336,17 @@ def _actions_from_update_memcpy(c, d, clusters, actions, sregistry):
271336
ispace = c.ispace.augment({pd: tindex}) if tindex is not tindex0 else c.ispace
272337

273338
guard0 = c.guards.get(d, true)._subs(fetch, findex)
274-
guard1 = GuardBoundNext(function.indices[d], direction)
275-
guards = c.guards.impose(d, guard0 & guard1)
339+
guard1 = GuardBoundNext(function.indices[d], e.rhs.indices[d], direction,
340+
d_min=0, d_max=fd_max)
341+
342+
# Evaluate guard1 first: only if guard1 holds can we safely evaluate guard0,
343+
# which is then guaranteed to have valid indices into f
344+
vdnext = VirtualDimension(name=f'vdnext_{d.name}', parent=pd)
345+
ispace = ispace.insert(pd, vdnext)
346+
# Check valid tindex first
347+
guards = c.guards.impose(d, guard1)
348+
# Then check valid access
349+
guards = guards.impose(vdnext, guard0)
276350

277351
syncs = {d: [
278352
ReleaseLock(handle, target),

0 commit comments

Comments
 (0)