Skip to content

Commit 449a453

Browse files
committed
compiler: support mutli-buffering
1 parent 9cc52bb commit 449a453

8 files changed

Lines changed: 365 additions & 122 deletions

File tree

devito/ir/clusters/algorithms.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -259,10 +259,18 @@ def guard(clusters):
259259
# Separate out the indirect ConditionalDimensions, which only serve
260260
# the purpose of protecting from OOB accesses
261261
cds = [d for d in cds if not d.indirect]
262+
modes = [cd.relation for cd in cds]
263+
if modes.count('strict') > 1:
264+
print(modes, cds, {m == 'strict' for m in modes})
265+
raise CompilationError("Only one `strict` condition"
266+
"can be used in an equation")
267+
elif 'strict' in modes:
268+
mode = 'strict'
269+
else:
270+
mode = sympy.And if sympy.And in modes else sympy.Or
262271

263272
# Chain together all `cds` conditions from all expressions in `c`
264273
guards = {}
265-
mode = sympy.Or
266274
for cd in cds:
267275
# `BOTTOM` parent implies a guard that lives outside of
268276
# any iteration space, which corresponds to the placeholder None
@@ -279,7 +287,6 @@ def guard(clusters):
279287

280288
# Pull `cd` from any expr
281289
condition = guards.setdefault(k, [])
282-
mode = mode and cd.relation
283290
for e in exprs:
284291
try:
285292
condition.append(e.conditionals[cd])
@@ -296,7 +303,10 @@ def guard(clusters):
296303

297304
# Combination `mode` is And by default.
298305
# If all conditions are Or then Or combination `mode` is used.
299-
guards = {d: mode(*v, evaluate=False) for d, v in guards.items()}
306+
if mode == 'strict':
307+
guards = {d: v[0] for d, v in guards.items()}
308+
else:
309+
guards = {d: mode(*v, evaluate=False) for d, v in guards.items()}
300310

301311
# Construct a guarded Cluster
302312
processed.append(c.rebuild(exprs=exprs, guards=guards))

devito/ir/equations/equation.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -237,14 +237,6 @@ def __new__(cls, *args, **kwargs):
237237
cond = d.relation(cond, GuardFactor(d))
238238
conditionals[d] = cond
239239

240-
# Replace the ConditionalDimensions in `expr`
241-
for d, cond in conditionals.items():
242-
# Replace dimension with index
243-
index = d.index
244-
index = index - relational_min(cond, d.parent)
245-
shift = relational_shift(cond, d.parent)
246-
expr = uxreplace(expr, {d: IntDiv(index, d.symbolic_factor) + shift})
247-
248240
# Merge conditionals when possible. E.g if we have an implicit_dim
249241
# and there is a dimension with the same parent, we ca merged
250242
# its condition
@@ -254,14 +246,22 @@ def __new__(cls, *args, **kwargs):
254246
for cd in dict(conditionals):
255247
if cd.parent == d.parent and cd != d:
256248
cond = conditionals.pop(d)
257-
mode = cd.relation and d.relation
258-
if issubclass(mode, sympy.Or):
259-
conditionals[d] = cond
260-
conditionals.pop(cd)
249+
if d.relation == 'strict':
250+
conditionals[cd] = conditionals[d] = cond
261251
else:
252+
mode = cd.relation and d.relation
262253
conditionals[cd] = mode(cond, conditionals[cd])
263254
break
264255

256+
# Replace the ConditionalDimensions in `expr`
257+
for d, cond in conditionals.items():
258+
# Replace dimension with index
259+
index = d.index
260+
if d.condition is not None and d in expr.free_symbols:
261+
index = index - relational_min(cond, d.parent)
262+
shift = relational_shift(cond, d.parent)
263+
expr = uxreplace(expr, {d: IntDiv(index, d.symbolic_factor) + shift})
264+
265265
# Lower all Differentiable operations into SymPy operations
266266
rhs = diff2sympy(expr.rhs)
267267

devito/passes/clusters/asynchrony.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from collections import defaultdict
22

3-
from sympy import true
3+
from sympy import Mod, true
44

55
from devito.ir import (
66
Backward, Forward, GuardBoundNext, PrefetchUpdate, Queue, ReleaseLock, SyncArray,
@@ -78,7 +78,8 @@ def callback(self, clusters, prefix):
7878
d = self.key0(c0)
7979
if d is not dim:
8080
continue
81-
81+
if d in c0.guards and not c0.guards[d].has(Mod):
82+
continue
8283
protected = self._schedule_waitlocks(c0, d, clusters, locks, syncs)
8384
self._schedule_withlocks(c0, d, protected, locks, syncs)
8485

0 commit comments

Comments
 (0)