3131import sympy as sp
3232
3333import loopy as lp
34- from pymbolic import var
34+ import pymbolic .primitives as prim
35+ from pymbolic import Expression , var
3536from pymbolic .mapper import CSECachingMapperMixin , IdentityMapper
3637from pymbolic .primitives import make_sym_vector
3738from pytools import memoize_method
@@ -1084,7 +1085,7 @@ def replace_inner_kernel(self, new_inner_kernel):
10841085 mapper_method = "map_axis_target_derivative"
10851086
10861087
1087- class _VectorIndexAdder (CSECachingMapperMixin , IdentityMapper ):
1088+ class _VectorIndexAdder (CSECachingMapperMixin [ Expression , []], IdentityMapper [[]] ):
10881089 def __init__ (self , vec_name , additional_indices ):
10891090 self .vec_name = vec_name
10901091 self .additional_indices = additional_indices
@@ -1099,7 +1100,14 @@ def map_subscript(self, expr):
10991100 else :
11001101 return IdentityMapper .map_subscript (self , expr )
11011102
1102- map_common_subexpression_uncached = IdentityMapper .map_common_subexpression
1103+ def map_common_subexpression_uncached (self ,
1104+ expr : prim .CommonSubexpression ) -> Expression :
1105+ result = self .rec (expr .child )
1106+ if result is expr .child :
1107+ return expr
1108+
1109+ return type (expr )(
1110+ result , expr .prefix , expr .scope , ** expr .get_extra_properties ())
11031111
11041112
11051113class DirectionalDerivative (DerivativeBase ):
0 commit comments