Skip to content

Commit 54c2b2a

Browse files
committed
compiler: Memoize FindNodes
1 parent 7d3debd commit 54c2b2a

2 files changed

Lines changed: 21 additions & 2 deletions

File tree

devito/ir/iet/visitors.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1172,8 +1172,13 @@ class FindNodes(LazyVisitor[Node, list[Node], None]):
11721172
def __init__(self, match: type, mode: str = 'type') -> None:
11731173
super().__init__()
11741174
self.match = match
1175+
self.mode = mode
11751176
self.rule = self.rules[mode]
11761177

1178+
@memoized_weak_meth(key=lambda i: (i.match, i.mode), freeze=tuple, thaw=list)
1179+
def visit(self, o, *args, **kwargs):
1180+
return super().visit(o, *args, **kwargs)
1181+
11771182
def visit_Node(self, o: Node, **kwargs) -> Iterator[Node]:
11781183
if self.rule(self.match, o):
11791184
yield o
@@ -1194,6 +1199,11 @@ def __init__(self, match: type, start: Node, stop: Node | None = None) -> None:
11941199
self.start = start
11951200
self.stop = stop
11961201

1202+
def visit(self, o, *args, **kwargs):
1203+
# `start` and `stop` are part of this visitor's state.
1204+
return GenericVisitor.visit(self, o, *args, **kwargs)
1205+
1206+
11971207
def visit_object(self, o: object, flag: bool = False) -> LazyVisit[Node, bool]:
11981208
yield from ()
11991209
return flag # noqa: B901

tests/test_visitors.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
from devito.ir.equations import DummyEq
77
from devito.ir.iet import (
88
Block, Call, Callable, Conditional, Expression, FindApplications, FindNodes,
9-
FindSections, FindSymbols, IsPerfectIteration, Iteration, MapNodes, Transformer,
10-
Uxreplace, printAST
9+
FindSections, FindSymbols, FindWithin, IsPerfectIteration, Iteration, MapNodes,
10+
Transformer, Uxreplace, printAST
1111
)
1212
from devito.types import Array, SpaceDimension, Symbol
1313

@@ -210,6 +210,15 @@ def test_find_sections(exprs, block1, block2, block3):
210210
assert len(found[2]) == 1
211211

212212

213+
def test_find_within_not_cached_like_findnodes(block3):
214+
expr0 = FindWithin(Expression, block3.nodes[0], block3.nodes[1]).visit(block3)
215+
expr1 = FindWithin(Expression, block3.nodes[1], block3.nodes[2]).visit(block3)
216+
217+
assert len(expr0) == 3
218+
assert len(expr1) == 3
219+
assert expr0 != expr1
220+
221+
213222
def test_is_perfect_iteration(block1, block2, block3, block4):
214223
checker = IsPerfectIteration()
215224

0 commit comments

Comments
 (0)