Skip to content

Commit fef0335

Browse files
mlouboutFabioLuporini
authored andcommitted
compiler: catch corner case read after write
1 parent 6ec492d commit fef0335

3 files changed

Lines changed: 33 additions & 5 deletions

File tree

devito/ir/clusters/algorithms.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -231,8 +231,15 @@ def _break_for_parallelism(self, scope, dim, timestamp):
231231
if any(dep.is_carried(i) for i in candidates):
232232
test0 = dep.is_flow and dep.is_lex_negative
233233
test1 = dep.is_anti and dep.is_lex_positive
234+
if test0:
235+
# If the same access pair is not a flow under logical distance,
236+
# the dep is a buffer/modulo-aliasing artifact and fission is OK
237+
ldist = dep.source.distance(dep.sink, logical=True)
238+
real_flow = (ldist > 0) or \
239+
(ldist == 0 and dep.sink.lex_ge(dep.source))
240+
if not real_flow:
241+
test0 = real_flow
234242
if test0 or test1:
235-
# Would break a data dependence
236243
return False
237244

238245
test = test or (bool(dep.cause & candidates) and not dep.is_lex_equal)

devito/ir/support/basic.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1105,6 +1105,7 @@ def d_flow_gen(self):
11051105
continue
11061106

11071107
distance = dependence.distance
1108+
11081109
try:
11091110
is_flow = distance > 0 or (r.lex_ge(w) and distance == 0)
11101111
except TypeError:

tests/test_dse.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,10 @@
99
get_params, skipif
1010
)
1111
from devito import ( # noqa
12-
NODE, Abs, ConditionalDimension, Constant, DefaultDimension, Derivative, Dimension,
13-
Eq, Function, Ge, Grid, Inc, Lt, Operator, SparseTimeFunction, SubDimension,
14-
TimeFunction, configuration, cos, dimensions, div, exp, first_derivative, floor, grad,
15-
norm, sin, solve, sqrt, switchconfig, transpose
12+
NODE, Abs, Buffer, ConditionalDimension, Constant, DefaultDimension, Derivative,
13+
Dimension, Eq, Function, Ge, Grid, Inc, Lt, Operator, SparseTimeFunction,
14+
SubDimension, TimeFunction, configuration, cos, dimensions, div, exp,
15+
first_derivative, floor, grad, norm, sin, solve, sqrt, switchconfig, transpose
1616
)
1717
from devito.exceptions import InvalidArgument, InvalidOperator
1818
from devito.ir import (
@@ -58,6 +58,26 @@ def test_scheduling_after_rewrite():
5858
assert all(trees[1].root.dim is tree.root.dim for tree in trees[1:])
5959

6060

61+
def test_scheduling_no_deriv():
62+
grid = Grid((11, 11, 11))
63+
x, y, z = grid.dimensions
64+
65+
image_vs = Function(name='image_vs', grid=grid, space_order=1, staggered=NODE)
66+
p_back_xy = TimeFunction(name='p_back_xy', grid=grid, staggered=(x, y),
67+
space_order=4, time_order=1, save=Buffer(1))
68+
69+
eqns = [Eq(image_vs, p_back_xy + image_vs),
70+
Eq(p_back_xy.backward, p_back_xy)]
71+
72+
op = Operator(eqns)
73+
74+
assert_structure(
75+
op,
76+
['t,x0_blk0,y0_blk0,x,y,z', 't,x1_blk0,y1_blk0,x,y,z'],
77+
'tx0_blk0y0_blk0xyzx1_blk0y1_blk0xyz'
78+
)
79+
80+
6181
@pytest.mark.parametrize('expr,expected', [
6282
('2*fa[x] + fb[x]', '2*fa[x] + fb[x]'),
6383
('fa[x]**2', 'fa[x]*fa[x]'),

0 commit comments

Comments
 (0)