|
5 | 5 | """ |
6 | 6 |
|
7 | 7 | from collections import OrderedDict |
8 | | -from collections.abc import Callable, Iterable, Iterator, Sequence |
| 8 | +from collections.abc import Callable, Generator, Iterable, Iterator, Sequence |
9 | 9 | from itertools import chain, groupby |
10 | 10 | from typing import Any, Generic, TypeVar |
11 | 11 | import ctypes |
@@ -59,38 +59,51 @@ def always_rebuild(self, o, *args, **kwargs): |
59 | 59 | return o._rebuild(*new_ops, **okwargs) |
60 | 60 |
|
61 | 61 |
|
| 62 | +# Type variables for LazyVisitor |
62 | 63 | YieldType = TypeVar('YieldType', covariant=True) |
| 64 | +FlagType = TypeVar('FlagType', covariant=True) |
63 | 65 | ResultType = TypeVar('ResultType', covariant=True) |
64 | 66 |
|
| 67 | +# Describes the return type of a LazyVisitor visit method which yields objects of |
| 68 | +# type YieldType and returns a FlagType (or NoneType) |
| 69 | +LazyVisit = Generator[YieldType, None, FlagType] |
65 | 70 |
|
66 | | -class LazyVisitor(GenericVisitor, Generic[YieldType, ResultType]): |
| 71 | + |
| 72 | +class LazyVisitor(GenericVisitor, Generic[YieldType, ResultType, FlagType]): |
67 | 73 |
|
68 | 74 | """ |
69 | 75 | A generic visitor that lazily yields results instead of flattening results |
70 | | - from children at every step. |
| 76 | + from children at every step. Intermediate visit methods may return a flag |
| 77 | + of type FlagType in addition to yielding results; by default, the last flag |
| 78 | + returned by a child is the one propagated. |
71 | 79 |
|
72 | 80 | Subclass-defined visit methods should be generators. |
73 | 81 | """ |
74 | 82 |
|
75 | | - def lookup_method(self, instance) -> Callable[..., Iterator[YieldType]]: |
| 83 | + def lookup_method(self, instance) \ |
| 84 | + -> Callable[..., LazyVisit[YieldType, FlagType]]: |
76 | 85 | return super().lookup_method(instance) |
77 | 86 |
|
78 | | - def _visit(self, o, *args, **kwargs) -> Iterator[YieldType]: |
| 87 | + def _visit(self, o, *args, **kwargs) -> LazyVisit[YieldType, FlagType]: |
79 | 88 | meth = self.lookup_method(o) |
80 | | - yield from meth(o, *args, **kwargs) |
| 89 | + flag = yield from meth(o, *args, **kwargs) |
| 90 | + return flag |
81 | 91 |
|
82 | | - def _post_visit(self, ret: Iterator[YieldType]) -> ResultType: |
| 92 | + def _post_visit(self, ret: LazyVisit[YieldType, FlagType]) -> ResultType: |
83 | 93 | return list(ret) |
84 | 94 |
|
85 | | - def visit_object(self, o: object, **kwargs) -> Iterator[YieldType]: |
| 95 | + def visit_object(self, o: object, **kwargs) -> LazyVisit[YieldType, FlagType]: |
86 | 96 | yield from () |
87 | 97 |
|
88 | | - def visit_Node(self, o: Node, **kwargs) -> Iterator[YieldType]: |
89 | | - yield from self._visit(o.children, **kwargs) |
| 98 | + def visit_Node(self, o: Node, **kwargs) -> LazyVisit[YieldType, FlagType]: |
| 99 | + flag = yield from self._visit(o.children, **kwargs) |
| 100 | + return flag |
90 | 101 |
|
91 | | - def visit_tuple(self, o: Sequence[Any], **kwargs) -> Iterator[YieldType]: |
| 102 | + def visit_tuple(self, o: Sequence[Any], **kwargs) -> LazyVisit[YieldType, FlagType]: |
| 103 | + flag: FlagType = None |
92 | 104 | for i in o: |
93 | | - yield from self._visit(i, **kwargs) |
| 105 | + flag = yield from self._visit(i, **kwargs) |
| 106 | + return flag |
94 | 107 |
|
95 | 108 | visit_list = visit_tuple |
96 | 109 |
|
@@ -1015,7 +1028,7 @@ def visit_Node(self, o, ret=None, parents=None, in_parent=False): |
1015 | 1028 | return ret |
1016 | 1029 |
|
1017 | 1030 |
|
1018 | | -class FindSymbols(LazyVisitor[Any, list[Any]]): |
| 1031 | +class FindSymbols(LazyVisitor[Any, list[Any], None]): |
1019 | 1032 |
|
1020 | 1033 | """ |
1021 | 1034 | Find symbols in an Iteration/Expression tree. |
@@ -1089,7 +1102,7 @@ def visit_Operator(self, o) -> Iterator[Any]: |
1089 | 1102 | yield from self._visit(i) |
1090 | 1103 |
|
1091 | 1104 |
|
1092 | | -class FindNodes(LazyVisitor[Node, list[Node]]): |
| 1105 | +class FindNodes(LazyVisitor[Node, list[Node], None]): |
1093 | 1106 |
|
1094 | 1107 | """ |
1095 | 1108 | Find all instances of given type. |
@@ -1123,78 +1136,57 @@ def visit_Node(self, o: Node, **kwargs) -> Iterator[Node]: |
1123 | 1136 | yield from self._visit(i, **kwargs) |
1124 | 1137 |
|
1125 | 1138 |
|
1126 | | -class FindWithin(FindNodes): |
| 1139 | +class FindWithin(FindNodes, LazyVisitor[Node, list[Node], bool]): |
1127 | 1140 |
|
1128 | 1141 | """ |
1129 | 1142 | Like FindNodes, but given an additional parameter `within=(start, stop)`, |
1130 | 1143 | it starts collecting matching nodes only after `start` is found, and stops |
1131 | 1144 | collecting matching nodes after `stop` is found. |
1132 | 1145 | """ |
1133 | 1146 |
|
1134 | | - # Sentinel values to signal the start/end of a matching window |
1135 | | - SET_FLAG = object() |
1136 | | - UNSET_FLAG = object() |
1137 | | - |
1138 | 1147 | def __init__(self, match: type, start: Node, stop: Node | None = None) -> None: |
1139 | 1148 | super().__init__(match) |
1140 | 1149 | self.start = start |
1141 | 1150 | self.stop = stop |
1142 | 1151 |
|
1143 | | - def _post_visit(self, ret: Iterator[Node | object]) -> list[Node]: |
1144 | | - return super()._post_visit(i for i in ret |
1145 | | - if i not in (self.SET_FLAG, self.UNSET_FLAG)) |
1146 | | - |
1147 | | - def visit_object(self, o: object, flag: bool = False) -> Iterator[Node | object]: |
1148 | | - yield self.SET_FLAG if flag else self.UNSET_FLAG |
| 1152 | + def visit_object(self, o: object, flag: bool = False) -> LazyVisit[Node, bool]: |
| 1153 | + yield from () |
| 1154 | + return flag |
1149 | 1155 |
|
1150 | | - def visit_tuple(self, o: Sequence[Any], |
1151 | | - flag: bool = False) -> Iterator[Node | object]: |
| 1156 | + def visit_tuple(self, o: Sequence[Any], flag: bool = False) -> LazyVisit[Node, bool]: |
1152 | 1157 | for el in o: |
1153 | | - for i in self._visit(el, flag=flag): |
1154 | | - # New flag state is yielded at the end of child results |
1155 | | - if i is self.SET_FLAG: |
1156 | | - flag = True |
1157 | | - continue |
1158 | | - if i is self.UNSET_FLAG: |
1159 | | - flag = False |
1160 | | - continue |
1161 | | - |
1162 | | - # Regular object |
1163 | | - yield i |
| 1158 | + # Yield results from visiting this element, and update the flag |
| 1159 | + flag = yield from self._visit(el, flag=flag) |
1164 | 1160 |
|
1165 | | - yield self.SET_FLAG if flag else self.UNSET_FLAG |
| 1161 | + return flag |
1166 | 1162 |
|
1167 | 1163 | visit_list = visit_tuple |
1168 | 1164 |
|
1169 | | - def visit_Node(self, o: Node, flag: bool = False) -> Iterator[Node | object]: |
| 1165 | + def visit_Node(self, o: Node, flag: bool = False) -> LazyVisit[Node, bool]: |
1170 | 1166 | flag = flag or (o is self.start) |
1171 | 1167 |
|
1172 | 1168 | if flag and self.rule(self.match, o): |
1173 | 1169 | yield o |
1174 | 1170 |
|
1175 | 1171 | for child in o.children: |
1176 | | - for i in self._visit(child, flag=flag): |
1177 | | - # New flag state is yielded at the end of child results |
1178 | | - if i is self.SET_FLAG: |
1179 | | - flag = True |
1180 | | - continue |
1181 | | - if i is self.UNSET_FLAG: |
1182 | | - if flag: |
1183 | | - yield self.UNSET_FLAG |
1184 | | - return |
1185 | | - continue |
1186 | | - |
1187 | | - # Regular object |
1188 | | - yield i |
| 1172 | + # Yield results from this child and retrieve its flag |
| 1173 | + nflag = yield from self._visit(child, flag=flag) |
| 1174 | + |
| 1175 | + # If we started collecting outside of here and the child found a stop, |
| 1176 | + # don't visit the rest of the children |
| 1177 | + if flag and not nflag: |
| 1178 | + return False |
| 1179 | + flag = nflag |
1189 | 1180 |
|
| 1181 | + # Update the flag if we found a stop |
1190 | 1182 | flag &= (o is not self.stop) |
1191 | | - yield self.SET_FLAG if flag else self.UNSET_FLAG |
| 1183 | + return flag |
1192 | 1184 |
|
1193 | 1185 |
|
1194 | 1186 | ApplicationType = TypeVar('ApplicationType') |
1195 | 1187 |
|
1196 | 1188 |
|
1197 | | -class FindApplications(LazyVisitor[ApplicationType, set[ApplicationType]]): |
| 1189 | +class FindApplications(LazyVisitor[ApplicationType, set[ApplicationType], None]): |
1198 | 1190 |
|
1199 | 1191 | """ |
1200 | 1192 | Find all SymPy applied functions (aka, `Application`s). The user may refine |
|
0 commit comments