Skip to content

Commit 644eacb

Browse files
kaushikcfdinducer
authored andcommitted
Mapper: allow 'expr' to be subclasses of ArrayOrNames
1 parent a59d3d3 commit 644eacb

5 files changed

Lines changed: 10 additions & 18 deletions

File tree

pytato/analysis/__init__.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,9 +76,7 @@ def rec(self, expr: ArrayOrNames) -> None: # type: ignore
7676
if id(expr) in self._visited_ids:
7777
return
7878

79-
# type-ignore reason: super().rec expects either 'Array' or
80-
# 'AbstractResultWithNamedArrays', passed 'ArrayOrNames'
81-
super().rec(expr) # type: ignore
79+
super().rec(expr)
8280
self._visited_ids.add(id(expr))
8381

8482
def map_index_lambda(self, expr: IndexLambda) -> None:

pytato/distributed.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -578,8 +578,7 @@ def rec(self, expr: ArrayOrNames) -> FrozenSet[Array]: # type: ignore[override]
578578
try:
579579
return self.cache[expr]
580580
except KeyError:
581-
# type-ignore reason: type not compatible with super.rec() type
582-
result: FrozenSet[Array] = super().rec(expr) # type: ignore[type-var]
581+
result: FrozenSet[Array] = super().rec(expr)
583582
self.cache[expr] = result
584583
return result
585584

pytato/partition.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@
7272

7373

7474
ArrayOrNames = Union[Array, AbstractResultWithNamedArrays]
75-
T = TypeVar("T", Array, AbstractResultWithNamedArrays)
75+
T = TypeVar("T", bound=ArrayOrNames)
7676
PartId = Hashable
7777

7878

pytato/target/python/numpy_like.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -549,8 +549,7 @@ def generate_numpy_like(expr: Union[Array, Mapping[str, Array], DictOfNamedArray
549549
numpy_backend=target.numpy_like_module_name_shorthand,
550550
numpy="np",
551551
vng=var_name_gen)
552-
# type-ignore-reason: https://github.com/inducer/pytato/issues/236
553-
result_var = cgen_mapper(expr) # type: ignore[type-var]
552+
result_var = cgen_mapper(expr)
554553

555554
lines = cgen_mapper.lines
556555
lines.append(ast.Return(ast.Name(result_var)))

pytato/transform/__init__.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -50,11 +50,11 @@
5050
if TYPE_CHECKING:
5151
from pytato.distributed import DistributedSendRefHolder, DistributedRecv
5252

53-
T = TypeVar("T", Array, AbstractResultWithNamedArrays)
53+
ArrayOrNames = Union[Array, AbstractResultWithNamedArrays]
54+
T = TypeVar("T", bound=ArrayOrNames)
5455
CombineT = TypeVar("CombineT") # used in CombineMapper
5556
CachedMapperT = TypeVar("CachedMapperT") # used in CachedMapper
5657
IndexOrShapeExpr = TypeVar("IndexOrShapeExpr")
57-
ArrayOrNames = Union[Array, AbstractResultWithNamedArrays]
5858
R = FrozenSet[Array]
5959

6060
__doc__ = """
@@ -378,7 +378,7 @@ def rec(self,
378378
except KeyError:
379379
result = Mapper.rec(self, expr,
380380
*args,
381-
**kwargs) # type: ignore[type-var]
381+
**kwargs)
382382
self._cache[key] = result
383383
return result
384384

@@ -554,8 +554,7 @@ def rec_idx_or_size_tuple(self, situp: Tuple[IndexOrShapeExpr, ...]
554554
def rec(self, expr: ArrayOrNames) -> CombineT: # type: ignore
555555
if expr in self.cache:
556556
return self.cache[expr]
557-
# type-ignore reason: type not compatible with super.rec() type
558-
result: CombineT = super().rec(expr) # type: ignore
557+
result: CombineT = super().rec(expr)
559558
self.cache[expr] = result
560559
return result
561560

@@ -961,9 +960,7 @@ def rec(self, expr: ArrayOrNames) -> None: # type: ignore
961960
if id(expr) in self._visited_ids:
962961
return
963962

964-
# type-ignore reason: super().rec expects either 'Array' or
965-
# 'AbstractResultWithNamedArrays', passed 'ArrayOrNames'
966-
super().rec(expr) # type: ignore
963+
super().rec(expr)
967964
self._visited_ids.add(id(expr))
968965

969966
# }}}
@@ -1061,8 +1058,7 @@ def __init__(self, nsuccessors: Mapping[Array, int]):
10611058
def rec(self, expr: ArrayOrNames) -> MPMSMaterializerAccumulator: # type: ignore
10621059
if expr in self.cache:
10631060
return self.cache[expr]
1064-
# type-ignore reason: type not compatible with super.rec() type
1065-
result: MPMSMaterializerAccumulator = super().rec(expr) # type: ignore
1061+
result: MPMSMaterializerAccumulator = super().rec(expr)
10661062
self.cache[expr] = result
10671063
return result
10681064

0 commit comments

Comments
 (0)