11from collections .abc import Callable , Iterable , Iterator
2- from typing import Literal
2+ from itertools import chain
3+ from typing import Any , Literal
34
45import sympy
56
1112 'retrieve_terminals' , 'retrieve_symbols' , 'retrieve_dimensions' ,
1213 'retrieve_derivatives' , 'search' ]
1314
14- class Set (set [sympy .Basic ]):
15+
16+ class Set (set ):
1517
1618 @staticmethod
17- def wrap (obj : sympy . Basic ) -> set [ sympy . Basic ] :
19+ def wrap (obj ) -> set :
1820 return {obj }
1921
2022
21- class List (list [ sympy . Basic ] ):
23+ class List (list ):
2224
2325 @staticmethod
24- def wrap (obj : sympy . Basic ) -> list [ sympy . Basic ] :
26+ def wrap (obj ) -> list :
2527 return [obj ]
2628
27- def update (self , obj : sympy . Basic ) -> None :
29+ def update (self , obj : Iterable [ Any ] ) -> None :
2830 self .extend (obj )
2931
3032
@@ -35,48 +37,42 @@ def update(self, obj: sympy.Basic) -> None:
3537
3638
3739class Search :
38-
39- def __init__ (self , query : Callable [[sympy .Basic ], bool ],
40- order : Literal ['postorder' , 'preorder' ], deep : bool = False ) -> None :
40+ def __init__ (self , query : Callable [[Any ], bool ], deep : bool = False ) -> None :
4141 """
42- Search objects in an expression. This is much quicker than the more
43- general SymPy's find.
42+ Search objects in an expression. This is much quicker than the more general
43+ SymPy's find.
4444
4545 Parameters
4646 ----------
4747 query
4848 Any query from :mod:`queries`.
49- order : str
50- Either `preorder` or `postorder`, for the search order.
5149 deep : bool, optional
5250 If True, propagate the search within an Indexed's indices. Defaults to False.
5351 """
5452 self .query = query
55- self .order = order
5653 self .deep = deep
5754
58- def _next (self , expr ) -> Iterator [sympy . Basic ]:
55+ def _next (self , expr ) -> Iterator [Any ]:
5956 if self .deep and expr .is_Indexed :
6057 yield from expr .indices
6158 elif not q_leaf (expr ):
6259 yield from expr .args
6360
64- def visit (self , expr : sympy .Basic ) -> Iterator [sympy .Basic ]:
65- """Visit the expression in the specified order."""
66- if self .order == 'preorder' :
67- if self .query (expr ):
68- yield expr
69- for child in self ._next (expr ):
70- yield from self .visit (child )
71- else :
72- for child in self ._next (expr ):
73- yield from self .visit (child )
74- if self .query (expr ):
75- yield expr
61+ def visit_preorder (self , expr ) -> Iterator [Any ]:
62+ if self .query (expr ):
63+ yield expr
64+ for i in self ._next (expr ):
65+ yield from self .visit_preorder (i )
7666
67+ def visit_postorder (self , expr ) -> Iterator [Any ]:
68+ for i in self ._next (expr ):
69+ yield from self .visit_postorder (i )
70+ if self .query (expr ):
71+ yield expr
7772
78- def search (exprs : sympy .Basic | Iterable [sympy .Basic ],
79- query : type | Callable [[sympy .Basic ], bool ],
73+
74+ def search (exprs ,
75+ query : type | Callable [[Any ], bool ],
8076 mode : Literal ['all' , 'unique' ] = 'unique' ,
8177 visit : Literal ['dfs' , 'bfs' , 'bfs_first_hit' ] = 'dfs' ,
8278 deep : bool = False ) -> List | Set :
@@ -92,21 +88,21 @@ def search(exprs: sympy.Basic | Iterable[sympy.Basic],
9288
9389 # Search doesn't actually use a BFS (rather, a preorder DFS), but the terminology
9490 # is retained in this function's parameters for backwards compatibility
95- order = 'postorder' if visit == 'dfs' else 'preorder'
96- searcher = Search ( Q , order , deep )
91+ searcher = Search ( Q , deep )
92+ _visit = searcher . visit_postorder if visit == 'dfs' else searcher . visit_preorder
9793
9894 Collection = modes [mode ]
9995 found = Collection ()
10096 for e in as_tuple (exprs ):
10197 if not isinstance (e , sympy .Basic ):
10298 continue
10399
104- for i in searcher .visit (e ):
105- found .update (Collection .wrap (i ))
106-
107- if visit == 'bfs_first_hit' :
108- # Stop at the first hit for this outer expression
100+ if visit == 'bfs_first_hit' :
101+ for i in _visit (e ):
102+ found .update (Collection .wrap (i ))
109103 break
104+ else :
105+ found .update (_visit (e ))
110106
111107 return found
112108
0 commit comments