Skip to content

Commit bb9d272

Browse files
committed
compiler: support mutli-buffering
1 parent 02b311b commit bb9d272

14 files changed

Lines changed: 629 additions & 206 deletions

File tree

devito/ir/clusters/algorithms.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -259,10 +259,17 @@ def guard(clusters):
259259
# Separate out the indirect ConditionalDimensions, which only serve
260260
# the purpose of protecting from OOB accesses
261261
cds = [d for d in cds if not d.indirect]
262+
modes = [cd.relation for cd in cds]
263+
if modes.count('strict') > 1:
264+
raise CompilationError("Only one `strict` condition"
265+
"can be used in an equation")
266+
elif 'strict' in modes:
267+
mode = 'strict'
268+
else:
269+
mode = sympy.And if sympy.And in modes else sympy.Or
262270

263271
# Chain together all `cds` conditions from all expressions in `c`
264272
guards = {}
265-
mode = sympy.Or
266273
for cd in cds:
267274
# `BOTTOM` parent implies a guard that lives outside of
268275
# any iteration space, which corresponds to the placeholder None
@@ -279,7 +286,6 @@ def guard(clusters):
279286

280287
# Pull `cd` from any expr
281288
condition = guards.setdefault(k, [])
282-
mode = mode and cd.relation
283289
for e in exprs:
284290
try:
285291
condition.append(e.conditionals[cd])
@@ -296,7 +302,10 @@ def guard(clusters):
296302

297303
# Combination `mode` is And by default.
298304
# If all conditions are Or then Or combination `mode` is used.
299-
guards = {d: mode(*v, evaluate=False) for d, v in guards.items()}
305+
if mode == 'strict':
306+
guards = {d: v[0] for d, v in guards.items()}
307+
else:
308+
guards = {d: mode(*v, evaluate=False) for d, v in guards.items()}
300309

301310
# Construct a guarded Cluster
302311
processed.append(c.rebuild(exprs=exprs, guards=guards))

devito/ir/equations/equation.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -237,14 +237,6 @@ def __new__(cls, *args, **kwargs):
237237
cond = d.relation(cond, GuardFactor(d))
238238
conditionals[d] = cond
239239

240-
# Replace the ConditionalDimensions in `expr`
241-
for d, cond in conditionals.items():
242-
# Replace dimension with index
243-
index = d.index
244-
index = index - relational_min(cond, d.parent)
245-
shift = relational_shift(cond, d.parent)
246-
expr = uxreplace(expr, {d: IntDiv(index, d.symbolic_factor) + shift})
247-
248240
# Merge conditionals when possible. E.g if we have an implicit_dim
249241
# and there is a dimension with the same parent, we ca merged
250242
# its condition
@@ -254,14 +246,22 @@ def __new__(cls, *args, **kwargs):
254246
for cd in dict(conditionals):
255247
if cd.parent == d.parent and cd != d:
256248
cond = conditionals.pop(d)
257-
mode = cd.relation and d.relation
258-
if issubclass(mode, sympy.Or):
259-
conditionals[d] = cond
260-
conditionals.pop(cd)
249+
if d.relation == 'strict':
250+
conditionals[cd] = conditionals[d] = cond
261251
else:
252+
mode = cd.relation and d.relation
262253
conditionals[cd] = mode(cond, conditionals[cd])
263254
break
264255

256+
# Replace the ConditionalDimensions in `expr`
257+
for d, cond in conditionals.items():
258+
# Replace dimension with index
259+
index = d.index
260+
if d.condition is not None and d in expr.free_symbols:
261+
index = index - relational_min(cond, d.parent)
262+
shift = relational_shift(cond, d.parent)
263+
expr = uxreplace(expr, {d: IntDiv(index, d.symbolic_factor) + shift})
264+
265265
# Lower all Differentiable operations into SymPy operations
266266
rhs = diff2sympy(expr.rhs)
267267

devito/ir/support/guards.py

Lines changed: 58 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@
1313
from sympy.logic.boolalg import BooleanFunction
1414

1515
from devito.ir.support.space import Forward, IterationDirection
16-
from devito.symbolics import CondEq, CondNe, search
16+
from devito.symbolics import CondEq, CondNe, IntDiv, search
17+
from devito.symbolics.manipulation import _uxreplace_handle, _uxreplace_registry
1718
from devito.tools import Pickable, as_tuple, frozendict, split
1819
from devito.types import Dimension, LocalObject
1920

@@ -31,6 +32,34 @@
3132
]
3233

3334

35+
@singledispatch
36+
def bound_index(expr, dim, dir):
37+
if dir == Forward:
38+
return expr._subs(dim, dim + 1)
39+
else:
40+
return expr._subs(dim, dim - 1)
41+
42+
43+
@bound_index.register(Expr)
44+
def _(expr, dim, dir):
45+
if not expr.args:
46+
if dir == Forward:
47+
return expr._subs(dim, dim + 1)
48+
else:
49+
return expr._subs(dim, dim - 1)
50+
return expr.func(*[bound_index(a, dim, dir) for a in expr.args])
51+
52+
53+
@bound_index.register(IntDiv)
54+
def _(expr, dim, dir):
55+
v = dim.symbolic_factor
56+
p0 = dim.root
57+
if dir == Forward:
58+
return Mul((((p0 + 1) + v - 1) / v), v, evaluate=False)
59+
else:
60+
return (p0 - 1) - abs(p0 - 1) % v
61+
62+
3463
class AbstractGuard:
3564
pass
3665

@@ -138,37 +167,29 @@ class BaseGuardBoundNext(Guard, Pickable):
138167
given `direction`.
139168
"""
140169

141-
__rargs__ = ('d', 'direction')
170+
__rargs__ = ('d', 'index', 'direction')
171+
__rkwargs__ = ('d_min', 'd_max')
142172

143-
def __new__(cls, d, direction, **kwargs):
173+
def __new__(cls, d, index, direction,
174+
d_min=None, d_max=None, **kwargs):
144175
assert isinstance(d, Dimension)
145176
assert isinstance(direction, IterationDirection)
146177

147-
if direction == Forward:
148-
p0 = d.root
149-
p1 = d.root.symbolic_max
178+
# Always take the next index in the iteration direction
179+
next_index = bound_index(index, d, direction)
150180

151-
if d.is_Conditional:
152-
v = d.symbolic_factor
153-
# Round `p0 + 1` up to the nearest multiple of `v`
154-
p0 = Mul((((p0 + 1) + v - 1) / v), v, evaluate=False)
155-
else:
156-
p0 = p0 + 1
181+
# The direction might be forward but accessing c - d
182+
# making the access backward w.r.t
183+
# Update direction according to access direction for valid guard
184+
if index.has(-d):
185+
direction = -direction
157186

187+
if direction == Forward:
188+
p0 = next_index
189+
p1 = d_max or d.root.symbolic_max
158190
else:
159-
p0 = d.root.symbolic_min
160-
p1 = d.root
161-
162-
if d.is_Conditional:
163-
v = d.symbolic_factor
164-
# Round `p1 - 1` down to the nearest sub-multiple of `v`
165-
# NOTE: we use ABS to make sure we handle negative values properly.
166-
# Once `p1 - 1` is negative (e.g. `iteration=time - 1` and `time=0`),
167-
# as long as we get a negative number, rather than 0 and even if it's
168-
# not `-v`, we're good
169-
p1 = (p1 - 1) - abs(p1 - 1) % v
170-
else:
171-
p1 = p1 - 1
191+
p0 = d_min if d_min is not None else d.root.symbolic_min
192+
p1 = next_index
172193

173194
try:
174195
if cls.__base__._eval_relation(p0, p1) is true:
@@ -180,12 +201,15 @@ def __new__(cls, d, direction, **kwargs):
180201

181202
obj.d = d
182203
obj.direction = direction
204+
obj.index = index
205+
obj.d_min = d_min
206+
obj.d_max = d_max
183207

184208
return obj
185209

186210
@property
187211
def _args_rebuild(self):
188-
return (self.d, self.direction)
212+
return (self.d, self.index, self.direction)
189213

190214

191215
class GuardBoundNextLe(BaseGuardBoundNext, Le):
@@ -541,3 +565,11 @@ def pairwise_or(*guards):
541565
pass
542566

543567
return guard
568+
569+
570+
_uxreplace_registry.register(BaseGuardBoundNext)
571+
572+
573+
@_uxreplace_handle.register(BaseGuardBoundNext)
574+
def _(expr, args, kwargs):
575+
return expr.func(expr.d, expr.index, expr.direction, **kwargs)

devito/ir/support/space.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -601,6 +601,14 @@ def __repr__(self):
601601
def __hash__(self):
602602
return hash(self._name)
603603

604+
def __neg__(self):
605+
if self._name == '++':
606+
return Backward
607+
elif self._name == '--':
608+
return Forward
609+
else:
610+
return Any
611+
604612

605613
Forward = IterationDirection('++')
606614
"""Forward iteration direction ('++')."""

0 commit comments

Comments
 (0)