Skip to content

Commit 5f054ff

Browse files
committed
ci: add cse ordering test
1 parent 19a2857 commit 5f054ff

2 files changed

Lines changed: 25 additions & 2 deletions

File tree

devito/passes/clusters/cse.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,7 @@ def _toposort(exprs):
265265
dag = DAG(exprs)
266266

267267
for e0 in exprs:
268-
if not isinstance(e0.lhs, CTemp):
268+
if not search(e0, CTemp):
269269
continue
270270

271271
for e1 in exprs:
@@ -279,7 +279,7 @@ def choose_element(queue, scheduled):
279279
first = sorted(tmps, key=lambda i: i.lhs.name).pop(0)
280280
queue.remove(first)
281281
else:
282-
first = queue.popleft()
282+
first = queue.pop()
283283
return first
284284

285285
processed = dag.topological_sort(choose_element)

tests/test_cse.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,3 +246,26 @@ def test_advanced_algo(exprs, expected):
246246

247247
assert len(processed) == len(expected)
248248
assert all(str(i.rhs) == j for i, j in zip(processed, expected))
249+
250+
251+
def test_advanced_algo_order():
252+
"""
253+
Test that smartsort/advanced doesn't break equation order.
254+
"""
255+
grid = Grid((3, 3, 3))
256+
u = TimeFunction(name="u", grid=grid, space_order=2)
257+
v = TimeFunction(name="v", grid=grid, space_order=2)
258+
259+
eq0 = DummyEq(indexify(diffify(Eq(u.forward, u.dx).evaluate)))
260+
eq1 = DummyEq(indexify(diffify(Eq(v, u.dx).evaluate)))
261+
eq_b = DummyEq(indexify(diffify(Eq(v.forward, v + u.forward).evaluate)))
262+
263+
counter = generator()
264+
make = lambda _: CTemp(name='r%d' % counter(), dtype=np.float32).indexify()
265+
processed = _cse([eq0, eq1, eq_b], make, mode='advanced')
266+
267+
# Three input equation and 2 CTemps
268+
assert len(processed) == 5
269+
assert processed[0].lhs.name == 'r1'
270+
# eq_b has to be last
271+
assert processed[-1] == eq_b

0 commit comments

Comments
 (0)