Skip to content

Commit bc1a944

Browse files
committed
Fix mypy issues from better pymbolic/pytools typing
1 parent 6b5eb21 commit bc1a944

2 files changed

Lines changed: 12 additions & 4 deletions

File tree

sumpy/expansion/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -283,7 +283,7 @@ def get_stored_mpole_coefficients_from_full(self,
283283
# }}}
284284

285285
@memoize_method
286-
def get_full_coefficient_identifiers(self) -> list[Hashable]:
286+
def get_full_coefficient_identifiers(self) -> Sequence[Hashable]:
287287
"""
288288
Returns identifiers for every coefficient in the complete expansion.
289289
"""

sumpy/kernel.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@
3131
import sympy as sp
3232

3333
import loopy as lp
34-
from pymbolic import var
34+
import pymbolic.primitives as prim
35+
from pymbolic import Expression, var
3536
from pymbolic.mapper import CSECachingMapperMixin, IdentityMapper
3637
from pymbolic.primitives import make_sym_vector
3738
from pytools import memoize_method
@@ -1084,7 +1085,7 @@ def replace_inner_kernel(self, new_inner_kernel):
10841085
mapper_method = "map_axis_target_derivative"
10851086

10861087

1087-
class _VectorIndexAdder(CSECachingMapperMixin, IdentityMapper):
1088+
class _VectorIndexAdder(CSECachingMapperMixin[Expression, []], IdentityMapper[[]]):
10881089
def __init__(self, vec_name, additional_indices):
10891090
self.vec_name = vec_name
10901091
self.additional_indices = additional_indices
@@ -1099,7 +1100,14 @@ def map_subscript(self, expr):
10991100
else:
11001101
return IdentityMapper.map_subscript(self, expr)
11011102

1102-
map_common_subexpression_uncached = IdentityMapper.map_common_subexpression
1103+
def map_common_subexpression_uncached(self,
1104+
expr: prim.CommonSubexpression) -> Expression:
1105+
result = self.rec(expr.child)
1106+
if result is expr.child:
1107+
return expr
1108+
1109+
return type(expr)(
1110+
result, expr.prefix, expr.scope, **expr.get_extra_properties())
11031111

11041112

11051113
class DirectionalDerivative(DerivativeBase):

0 commit comments

Comments
 (0)