Skip to content

Commit 8d4afb6

Browse files
committed
compiler: support mutli-buffering
1 parent cd1f1e4 commit 8d4afb6

14 files changed

Lines changed: 629 additions & 208 deletions

File tree

devito/ir/clusters/algorithms.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -259,10 +259,17 @@ def guard(clusters):
259259
# Separate out the indirect ConditionalDimensions, which only serve
260260
# the purpose of protecting from OOB accesses
261261
cds = [d for d in cds if not d.indirect]
262+
modes = [cd.relation for cd in cds]
263+
if modes.count('strict') > 1:
264+
raise CompilationError("Only one `strict` condition"
265+
"can be used in an equation")
266+
elif 'strict' in modes:
267+
mode = 'strict'
268+
else:
269+
mode = sympy.And if sympy.And in modes else sympy.Or
262270

263271
# Chain together all `cds` conditions from all expressions in `c`
264272
guards = {}
265-
mode = sympy.Or
266273
for cd in cds:
267274
# `BOTTOM` parent implies a guard that lives outside of
268275
# any iteration space, which corresponds to the placeholder None
@@ -279,7 +286,6 @@ def guard(clusters):
279286

280287
# Pull `cd` from any expr
281288
condition = guards.setdefault(k, [])
282-
mode = mode and cd.relation
283289
for e in exprs:
284290
try:
285291
condition.append(e.conditionals[cd])
@@ -296,7 +302,10 @@ def guard(clusters):
296302

297303
# Combination `mode` is And by default.
298304
# If all conditions are Or then Or combination `mode` is used.
299-
guards = {d: mode(*v, evaluate=False) for d, v in guards.items()}
305+
if mode == 'strict':
306+
guards = {d: v[0] for d, v in guards.items()}
307+
else:
308+
guards = {d: mode(*v, evaluate=False) for d, v in guards.items()}
300309

301310
# Construct a guarded Cluster
302311
processed.append(c.rebuild(exprs=exprs, guards=guards))

devito/ir/equations/equation.py

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,6 @@
1313
from devito.tools import (
1414
Pickable, Tag, as_hashable, filter_sorted, frozendict, reuse_if_unchanged
1515
)
16-
from devito.symbolics import IntDiv, limits_mapper, uxreplace
17-
from devito.tools import Pickable, Tag, frozendict
1816
from devito.types import (
1917
Eq, Inc, ReduceMax, ReduceMin, ReduceMinMax, relational_min, relational_shift
2018
)
@@ -309,14 +307,6 @@ def __new__(cls, *args, **kwargs):
309307
cond = d.relation(cond, GuardFactor(d))
310308
conditionals[d] = cond
311309

312-
# Replace the ConditionalDimensions in `expr`
313-
for d, cond in conditionals.items():
314-
# Replace dimension with index
315-
index = d.index
316-
index = index - relational_min(cond, d.parent)
317-
shift = relational_shift(cond, d.parent)
318-
expr = uxreplace(expr, {d: IntDiv(index, d.symbolic_factor) + shift})
319-
320310
# Merge conditionals when possible. E.g if we have an implicit_dim
321311
# and there is a dimension with the same parent, we ca merged
322312
# its condition
@@ -326,14 +316,22 @@ def __new__(cls, *args, **kwargs):
326316
for cd in dict(conditionals):
327317
if cd.parent == d.parent and cd != d:
328318
cond = conditionals.pop(d)
329-
mode = cd.relation and d.relation
330-
if issubclass(mode, sympy.Or):
331-
conditionals[d] = cond
332-
conditionals.pop(cd)
319+
if d.relation == 'strict':
320+
conditionals[cd] = conditionals[d] = cond
333321
else:
322+
mode = cd.relation and d.relation
334323
conditionals[cd] = mode(cond, conditionals[cd])
335324
break
336325

326+
# Replace the ConditionalDimensions in `expr`
327+
for d, cond in conditionals.items():
328+
# Replace dimension with index
329+
index = d.index
330+
if d.condition is not None and d in expr.free_symbols:
331+
index = index - relational_min(cond, d.parent)
332+
shift = relational_shift(cond, d.parent)
333+
expr = uxreplace(expr, {d: IntDiv(index, d.symbolic_factor) + shift})
334+
337335
# Lower all Differentiable operations into SymPy operations
338336
rhs = diff2sympy(expr.rhs)
339337

devito/ir/support/guards.py

Lines changed: 58 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@
1313
from sympy.logic.boolalg import BooleanFunction
1414

1515
from devito.ir.support.space import Forward, IterationDirection
16-
from devito.symbolics import CondEq, CondNe, search
16+
from devito.symbolics import CondEq, CondNe, IntDiv, search
17+
from devito.symbolics.manipulation import _uxreplace_handle, _uxreplace_registry
1718
from devito.tools import Pickable, as_tuple, frozendict, split
1819
from devito.types import Dimension, LocalObject
1920

@@ -31,6 +32,34 @@
3132
]
3233

3334

35+
@singledispatch
36+
def bound_index(expr, dim, dir):
37+
if dir == Forward:
38+
return expr._subs(dim, dim + 1)
39+
else:
40+
return expr._subs(dim, dim - 1)
41+
42+
43+
@bound_index.register(Expr)
44+
def _(expr, dim, dir):
45+
if not expr.args:
46+
if dir == Forward:
47+
return expr._subs(dim, dim + 1)
48+
else:
49+
return expr._subs(dim, dim - 1)
50+
return expr.func(*[bound_index(a, dim, dir) for a in expr.args])
51+
52+
53+
@bound_index.register(IntDiv)
54+
def _(expr, dim, dir):
55+
v = dim.symbolic_factor
56+
p0 = dim.root
57+
if dir == Forward:
58+
return Mul((((p0 + 1) + v - 1) / v), v, evaluate=False)
59+
else:
60+
return (p0 - 1) - abs(p0 - 1) % v
61+
62+
3463
class AbstractGuard:
3564
pass
3665

@@ -138,37 +167,29 @@ class BaseGuardBoundNext(Guard, Pickable):
138167
given `direction`.
139168
"""
140169

141-
__rargs__ = ('d', 'direction')
170+
__rargs__ = ('d', 'index', 'direction')
171+
__rkwargs__ = ('d_min', 'd_max')
142172

143-
def __new__(cls, d, direction, **kwargs):
173+
def __new__(cls, d, index, direction,
174+
d_min=None, d_max=None, **kwargs):
144175
assert isinstance(d, Dimension)
145176
assert isinstance(direction, IterationDirection)
146177

147-
if direction == Forward:
148-
p0 = d.root
149-
p1 = d.root.symbolic_max
178+
# Always take the next index in the iteration direction
179+
next_index = bound_index(index, d, direction)
150180

151-
if d.is_Conditional:
152-
v = d.symbolic_factor
153-
# Round `p0 + 1` up to the nearest multiple of `v`
154-
p0 = Mul((((p0 + 1) + v - 1) / v), v, evaluate=False)
155-
else:
156-
p0 = p0 + 1
181+
# The direction might be forward but accessing c - d
182+
# making the access backward w.r.t
183+
# Update direction according to access direction for valid guard
184+
if index.has(-d):
185+
direction = -direction
157186

187+
if direction == Forward:
188+
p0 = next_index
189+
p1 = d_max or d.root.symbolic_max
158190
else:
159-
p0 = d.root.symbolic_min
160-
p1 = d.root
161-
162-
if d.is_Conditional:
163-
v = d.symbolic_factor
164-
# Round `p1 - 1` down to the nearest sub-multiple of `v`
165-
# NOTE: we use ABS to make sure we handle negative values properly.
166-
# Once `p1 - 1` is negative (e.g. `iteration=time - 1` and `time=0`),
167-
# as long as we get a negative number, rather than 0 and even if it's
168-
# not `-v`, we're good
169-
p1 = (p1 - 1) - abs(p1 - 1) % v
170-
else:
171-
p1 = p1 - 1
191+
p0 = d_min if d_min is not None else d.root.symbolic_min
192+
p1 = next_index
172193

173194
try:
174195
if cls.__base__._eval_relation(p0, p1) is true:
@@ -180,12 +201,15 @@ def __new__(cls, d, direction, **kwargs):
180201

181202
obj.d = d
182203
obj.direction = direction
204+
obj.index = index
205+
obj.d_min = d_min
206+
obj.d_max = d_max
183207

184208
return obj
185209

186210
@property
187211
def _args_rebuild(self):
188-
return (self.d, self.direction)
212+
return (self.d, self.index, self.direction)
189213

190214

191215
class GuardBoundNextLe(BaseGuardBoundNext, Le):
@@ -544,3 +568,11 @@ def pairwise_or(*guards):
544568
pass
545569

546570
return guard
571+
572+
573+
_uxreplace_registry.register(BaseGuardBoundNext)
574+
575+
576+
@_uxreplace_handle.register(BaseGuardBoundNext)
577+
def _(expr, args, kwargs):
578+
return expr.func(expr.d, expr.index, expr.direction, **kwargs)

devito/ir/support/space.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -628,6 +628,14 @@ def __repr__(self):
628628
def __hash__(self):
629629
return hash(self._name)
630630

631+
def __neg__(self):
632+
if self._name == '++':
633+
return Backward
634+
elif self._name == '--':
635+
return Forward
636+
else:
637+
return Any
638+
631639

632640
Forward = IterationDirection('++')
633641
"""Forward iteration direction ('++')."""

0 commit comments

Comments
 (0)