Skip to content

Commit ef708e5

Browse files
committed
compiler: fix various corner case of multi buffering
1 parent 91266de commit ef708e5

4 files changed

Lines changed: 23 additions & 21 deletions

File tree

devito/ir/equations/equation.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,14 @@ def __new__(cls, *args, **kwargs):
237237
cond = d.relation(cond, GuardFactor(d))
238238
conditionals[d] = cond
239239

240+
# Replace the ConditionalDimensions in `expr`
241+
for d, cond in conditionals.items():
242+
# Replace dimension with index
243+
index = d.index
244+
index = index - relational_min(cond, d.parent)
245+
shift = relational_shift(cond, d.parent)
246+
expr = uxreplace(expr, {d: IntDiv(index, d.symbolic_factor) + shift})
247+
240248
# Merge conditionals when possible. E.g if we have an implicit_dim
241249
# and there is a dimension with the same parent, we ca merged
242250
# its condition
@@ -247,19 +255,13 @@ def __new__(cls, *args, **kwargs):
247255
if cd.parent == d.parent and cd != d:
248256
cond = conditionals.pop(d)
249257
mode = cd.relation and d.relation
250-
conditionals[cd] = mode(cond, conditionals[cd])
258+
if issubclass(mode, sympy.Or):
259+
conditionals[d] = cond
260+
conditionals.pop(cd)
261+
else:
262+
conditionals[cd] = mode(cond, conditionals[cd])
251263
break
252264

253-
conditionals = frozendict(conditionals)
254-
255-
# Replace the ConditionalDimensions in `expr`
256-
for d, cond in conditionals.items():
257-
# Replace dimension with index
258-
index = d.index
259-
index = index - relational_min(cond, d.parent)
260-
shift = relational_shift(cond, d.parent)
261-
expr = uxreplace(expr, {d: IntDiv(index, d.symbolic_factor) + shift})
262-
263265
# Lower all Differentiable operations into SymPy operations
264266
rhs = diff2sympy(expr.rhs)
265267

devito/passes/clusters/buffering.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from itertools import chain
44

55
import numpy as np
6-
from sympy import S, simplify
6+
from sympy import Mod, S, simplify
77

88
from devito.exceptions import CompilationError
99
from devito.ir import (
@@ -203,7 +203,7 @@ def callback(self, clusters, prefix):
203203
guards = c.guards
204204

205205
properties = c.properties.sequentialize(d)
206-
if not isinstance(d, BufferDimension):
206+
if not isinstance(d, BufferDimension) and c.guards[d].has(Mod):
207207
properties = properties.prefetchable(d)
208208
# `c` may be a HaloTouch Cluster, so with no vision of the `bdims`
209209
properties = properties.parallelize(v.bdims).affine(v.bdims)
@@ -377,7 +377,12 @@ def generate_buffers(clusters, key, sregistry, options, **kwargs):
377377
buffer, = buffers
378378
xd = buffer.indices[dim]
379379
else:
380-
size = infer_buffer_size(f, dim, clusters)
380+
if len({c.guards[dim.root] for c in clusters}) > 1:
381+
# Multiple clusters with different guards,
382+
# will lead to conflicts in asynchrony with multiple (modulo) slots
383+
size = 1
384+
else:
385+
size = infer_buffer_size(f, dim, clusters)
381386

382387
if async_degree is not None:
383388
if async_degree < size:

devito/symbolics/extended_sympy.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
)
1919
from devito.types import Symbol
2020
from devito.types.basic import Basic
21-
from devito.types.relational import Ge
2221

2322
__all__ = ['CondEq', 'CondNe', 'BitwiseNot', 'BitwiseXor', 'BitwiseAnd', # noqa
2423
'LeftShift', 'RightShift', 'IntDiv', 'Terminal', 'CallFromPointer',
@@ -48,11 +47,6 @@ def canonical(self):
4847
def negated(self):
4948
return CondNe(*self.args, evaluate=False)
5049

51-
@property
52-
def _as_min(self):
53-
from devito.symbolics.extended_dtypes import INT
54-
return INT(Ge(*self.args))
55-
5650

5751
class CondNe(sympy.Ne):
5852

devito/types/relational.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -320,4 +320,5 @@ def _(expr, s):
320320
def _(expr, s):
321321
if isinstance(expr.lhs, sympy.Mod):
322322
return 0
323-
return expr._as_min
323+
from devito.symbolics.extended_dtypes import INT
324+
return INT(Ge(*expr.args))

0 commit comments

Comments
 (0)