1313from devito .finite_differences .differentiable import IndexDerivative
1414from devito .ir import Cluster , Scope , cluster_pass
1515from devito .symbolics import estimate_cost , q_leaf , q_terminal
16- from devito .symbolics .search import retrieve_ctemps
16+ from devito .symbolics .search import search
1717from devito .symbolics .manipulation import _uxreplace
1818from devito .tools import DAG , as_list , as_tuple , frozendict , extract_dtype
1919from devito .types import Eq , Symbol , Temp
@@ -31,6 +31,11 @@ class CTemp(Temp):
3131 ordering_of_classes .insert (ordering_of_classes .index ('Temp' ) + 1 , 'CTemp' )
3232
3333
34+ def retrieve_ctemps (exprs , mode = 'all' ):
35+ """Shorthand to retrieve the CTemps in `exprs`"""
36+ return search (exprs , lambda expr : isinstance (expr , CTemp ), mode , 'dfs' )
37+
38+
3439@cluster_pass
3540def cse (cluster , sregistry = None , options = None , ** kwargs ):
3641 """
@@ -229,12 +234,12 @@ def _compact(exprs, exclude):
229234 mapper = {e .lhs : e .rhs for e in candidates if q_leaf (e .rhs )}
230235
231236 # Find all the CTemps in expression right-hand-sides without removing duplicates
232- ctemps = retrieve_ctemps ([e .rhs for e in exprs ])
233- ctemp_count = Counter (ctemps )
237+ ctemps = retrieve_ctemps (e .rhs for e in exprs )
234238
235239 # If there are ctemps in the expressions, then add any to the mapper which only
236240 # appear once
237241 if ctemps :
242+ ctemp_count = Counter (ctemps )
238243 mapper .update ({e .lhs : e .rhs for e in candidates
239244 if ctemp_count [e .lhs ] == 1 })
240245
0 commit comments