Skip to content

Commit 81e9f7a

Browse files
committed
compiler: catch corner case read after write
1 parent 3d5bb25 commit 81e9f7a

3 files changed

Lines changed: 32 additions & 6 deletions

File tree

devito/ir/clusters/algorithms.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ def callback(self, clusters, prefix, backlog=None, known_break=None):
165165
# Schedule Clusters over different IterationSpaces if this increases
166166
# parallelism
167167
for i in range(1, len(clusters)):
168-
if self._break_for_parallelism(scope, dim, i):
168+
if self._break_for_parallelism(scope, dim, i, prefix[:-1]):
169169
return self.callback(clusters[:i], prefix, clusters[i:] + backlog,
170170
candidates | known_break)
171171

@@ -204,7 +204,7 @@ def callback(self, clusters, prefix, backlog=None, known_break=None):
204204

205205
return processed + self.callback(backlog, prefix)
206206

207-
def _break_for_parallelism(self, scope, dim, timestamp):
207+
def _break_for_parallelism(self, scope, dim, timestamp, prev):
208208
candidates = dim._defines
209209

210210
# Do not fission for data locality reasons if there's enough potential
@@ -228,6 +228,12 @@ def _break_for_parallelism(self, scope, dim, timestamp):
228228
# Would break a dependence on storage
229229
return False
230230

231+
if any(dep.distance_mapper.get(d, 0) != 0 for d in candidates) and \
232+
prev and all(dep.distance_mapper.get(d.dim, -1) == 0 for d in prev) and \
233+
dep.read.function is not dep.write.function:
234+
# Cannot read/write with `dim` since all previous ones are dependent
235+
return True
236+
231237
if any(dep.is_carried(i) for i in candidates):
232238
test0 = dep.is_flow and dep.is_lex_negative
233239
test1 = dep.is_anti and dep.is_lex_positive

devito/ir/support/basic.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1105,6 +1105,7 @@ def d_flow_gen(self):
11051105
continue
11061106

11071107
distance = dependence.distance
1108+
11081109
try:
11091110
is_flow = distance > 0 or (r.lex_ge(w) and distance == 0)
11101111
except TypeError:

tests/test_dse.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,10 @@
99
get_params, skipif
1010
)
1111
from devito import ( # noqa
12-
NODE, Abs, ConditionalDimension, Constant, DefaultDimension, Derivative, Dimension,
13-
Eq, Function, Ge, Grid, Inc, Lt, Operator, SparseTimeFunction, SubDimension,
14-
TimeFunction, configuration, cos, dimensions, div, exp, first_derivative, floor, grad,
15-
norm, sin, solve, sqrt, switchconfig, transpose
12+
NODE, Abs, Buffer, ConditionalDimension, Constant, DefaultDimension, Derivative,
13+
Dimension, Eq, Function, Ge, Grid, Inc, Lt, Operator, SparseTimeFunction,
14+
SubDimension, TimeFunction, configuration, cos, dimensions, div, exp,
15+
first_derivative, floor, grad, norm, sin, solve, sqrt, switchconfig, transpose
1616
)
1717
from devito.exceptions import InvalidArgument, InvalidOperator
1818
from devito.ir import (
@@ -58,6 +58,25 @@ def test_scheduling_after_rewrite():
5858
assert all(trees[1].root.dim is tree.root.dim for tree in trees[1:])
5959

6060

61+
def test_scheduling_no_deriv():
62+
grid = Grid((11, 11, 11))
63+
x, y, z = grid.dimensions
64+
65+
image_vs = Function(name='image_vs', grid=grid, space_order=1, staggered=NODE)
66+
p_back_xy = TimeFunction(name='p_back_xy', grid=grid, staggered=(x, y),
67+
space_order=4, time_order=1, save=Buffer(1))
68+
69+
eqns = [Eq(image_vs, p_back_xy + image_vs),
70+
Eq(p_back_xy.backward, p_back_xy)]
71+
72+
op = Operator(eqns)
73+
assert_structure(
74+
op,
75+
['t,x0_blk0,y0_blk0,x,y,z', 't,x1_blk0,y1_blk0,x,y,z'],
76+
'tx0_blk0y0_blk0xyzx1_blk0y1_blk0xyz'
77+
)
78+
79+
6180
@pytest.mark.parametrize('expr,expected', [
6281
('2*fa[x] + fb[x]', '2*fa[x] + fb[x]'),
6382
('fa[x]**2', 'fa[x]*fa[x]'),

0 commit comments

Comments
 (0)