Skip to content

Commit 075ab3b

Browse files
committed
misc: Add flags to generic LazyVisitor, rewrite FindWithin
1 parent 4f5a6ec commit 075ab3b

1 file changed

Lines changed: 47 additions & 55 deletions

File tree

devito/ir/iet/visitors.py

Lines changed: 47 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
"""
66

77
from collections import OrderedDict
8-
from collections.abc import Callable, Iterable, Iterator, Sequence
8+
from collections.abc import Callable, Generator, Iterable, Iterator, Sequence
99
from itertools import chain, groupby
1010
from typing import Any, Generic, TypeVar
1111
import ctypes
@@ -59,38 +59,51 @@ def always_rebuild(self, o, *args, **kwargs):
5959
return o._rebuild(*new_ops, **okwargs)
6060

6161

62+
# Type variables for LazyVisitor
6263
YieldType = TypeVar('YieldType', covariant=True)
64+
FlagType = TypeVar('FlagType', covariant=True)
6365
ResultType = TypeVar('ResultType', covariant=True)
6466

67+
# Describes the return type of a LazyVisitor visit method which yields objects of
68+
# type YieldType and returns a FlagType (or NoneType)
69+
LazyVisit = Generator[YieldType, None, FlagType]
6570

66-
class LazyVisitor(GenericVisitor, Generic[YieldType, ResultType]):
71+
72+
class LazyVisitor(GenericVisitor, Generic[YieldType, ResultType, FlagType]):
6773

6874
"""
6975
A generic visitor that lazily yields results instead of flattening results
70-
from children at every step.
76+
from children at every step. Intermediate visit methods may return a flag
77+
of type FlagType in addition to yielding results; by default, the last flag
78+
returned by a child is the one propagated.
7179
7280
Subclass-defined visit methods should be generators.
7381
"""
7482

75-
def lookup_method(self, instance) -> Callable[..., Iterator[YieldType]]:
83+
def lookup_method(self, instance) \
84+
-> Callable[..., LazyVisit[YieldType, FlagType]]:
7685
return super().lookup_method(instance)
7786

78-
def _visit(self, o, *args, **kwargs) -> Iterator[YieldType]:
87+
def _visit(self, o, *args, **kwargs) -> LazyVisit[YieldType, FlagType]:
7988
meth = self.lookup_method(o)
80-
yield from meth(o, *args, **kwargs)
89+
flag = yield from meth(o, *args, **kwargs)
90+
return flag
8191

82-
def _post_visit(self, ret: Iterator[YieldType]) -> ResultType:
92+
def _post_visit(self, ret: LazyVisit[YieldType, FlagType]) -> ResultType:
8393
return list(ret)
8494

85-
def visit_object(self, o: object, **kwargs) -> Iterator[YieldType]:
95+
def visit_object(self, o: object, **kwargs) -> LazyVisit[YieldType, FlagType]:
8696
yield from ()
8797

88-
def visit_Node(self, o: Node, **kwargs) -> Iterator[YieldType]:
89-
yield from self._visit(o.children, **kwargs)
98+
def visit_Node(self, o: Node, **kwargs) -> LazyVisit[YieldType, FlagType]:
99+
flag = yield from self._visit(o.children, **kwargs)
100+
return flag
90101

91-
def visit_tuple(self, o: Sequence[Any], **kwargs) -> Iterator[YieldType]:
102+
def visit_tuple(self, o: Sequence[Any], **kwargs) -> LazyVisit[YieldType, FlagType]:
103+
flag: FlagType = None
92104
for i in o:
93-
yield from self._visit(i, **kwargs)
105+
flag = yield from self._visit(i, **kwargs)
106+
return flag
94107

95108
visit_list = visit_tuple
96109

@@ -1015,7 +1028,7 @@ def visit_Node(self, o, ret=None, parents=None, in_parent=False):
10151028
return ret
10161029

10171030

1018-
class FindSymbols(LazyVisitor[Any, list[Any]]):
1031+
class FindSymbols(LazyVisitor[Any, list[Any], None]):
10191032

10201033
"""
10211034
Find symbols in an Iteration/Expression tree.
@@ -1089,7 +1102,7 @@ def visit_Operator(self, o) -> Iterator[Any]:
10891102
yield from self._visit(i)
10901103

10911104

1092-
class FindNodes(LazyVisitor[Node, list[Node]]):
1105+
class FindNodes(LazyVisitor[Node, list[Node], None]):
10931106

10941107
"""
10951108
Find all instances of given type.
@@ -1123,78 +1136,57 @@ def visit_Node(self, o: Node, **kwargs) -> Iterator[Node]:
11231136
yield from self._visit(i, **kwargs)
11241137

11251138

1126-
class FindWithin(FindNodes):
1139+
class FindWithin(FindNodes, LazyVisitor[Node, list[Node], bool]):
11271140

11281141
"""
11291142
Like FindNodes, but given an additional parameter `within=(start, stop)`,
11301143
it starts collecting matching nodes only after `start` is found, and stops
11311144
collecting matching nodes after `stop` is found.
11321145
"""
11331146

1134-
# Sentinel values to signal the start/end of a matching window
1135-
SET_FLAG = object()
1136-
UNSET_FLAG = object()
1137-
11381147
def __init__(self, match: type, start: Node, stop: Node | None = None) -> None:
11391148
super().__init__(match)
11401149
self.start = start
11411150
self.stop = stop
11421151

1143-
def _post_visit(self, ret: Iterator[Node | object]) -> list[Node]:
1144-
return super()._post_visit(i for i in ret
1145-
if i not in (self.SET_FLAG, self.UNSET_FLAG))
1146-
1147-
def visit_object(self, o: object, flag: bool = False) -> Iterator[Node | object]:
1148-
yield self.SET_FLAG if flag else self.UNSET_FLAG
1152+
def visit_object(self, o: object, flag: bool = False) -> LazyVisit[Node, bool]:
1153+
yield from ()
1154+
return flag
11491155

1150-
def visit_tuple(self, o: Sequence[Any],
1151-
flag: bool = False) -> Iterator[Node | object]:
1156+
def visit_tuple(self, o: Sequence[Any], flag: bool = False) -> LazyVisit[Node, bool]:
11521157
for el in o:
1153-
for i in self._visit(el, flag=flag):
1154-
# New flag state is yielded at the end of child results
1155-
if i is self.SET_FLAG:
1156-
flag = True
1157-
continue
1158-
if i is self.UNSET_FLAG:
1159-
flag = False
1160-
continue
1161-
1162-
# Regular object
1163-
yield i
1158+
# Yield results from visiting this element, and update the flag
1159+
flag = yield from self._visit(el, flag=flag)
11641160

1165-
yield self.SET_FLAG if flag else self.UNSET_FLAG
1161+
return flag
11661162

11671163
visit_list = visit_tuple
11681164

1169-
def visit_Node(self, o: Node, flag: bool = False) -> Iterator[Node | object]:
1165+
def visit_Node(self, o: Node, flag: bool = False) -> LazyVisit[Node, bool]:
11701166
flag = flag or (o is self.start)
11711167

11721168
if flag and self.rule(self.match, o):
11731169
yield o
11741170

11751171
for child in o.children:
1176-
for i in self._visit(child, flag=flag):
1177-
# New flag state is yielded at the end of child results
1178-
if i is self.SET_FLAG:
1179-
flag = True
1180-
continue
1181-
if i is self.UNSET_FLAG:
1182-
if flag:
1183-
yield self.UNSET_FLAG
1184-
return
1185-
continue
1186-
1187-
# Regular object
1188-
yield i
1172+
# Yield results from this child and retrieve its flag
1173+
nflag = yield from self._visit(child, flag=flag)
1174+
1175+
# If we started collecting outside of here and the child found a stop,
1176+
# don't visit the rest of the children
1177+
if flag and not nflag:
1178+
return False
1179+
flag = nflag
11891180

1181+
# Update the flag if we found a stop
11901182
flag &= (o is not self.stop)
1191-
yield self.SET_FLAG if flag else self.UNSET_FLAG
1183+
return flag
11921184

11931185

11941186
ApplicationType = TypeVar('ApplicationType')
11951187

11961188

1197-
class FindApplications(LazyVisitor[ApplicationType, set[ApplicationType]]):
1189+
class FindApplications(LazyVisitor[ApplicationType, set[ApplicationType], None]):
11981190

11991191
"""
12001192
Find all SymPy applied functions (aka, `Application`s). The user may refine

0 commit comments

Comments
 (0)