4141from collections .abc import Mapping
4242from typing import TYPE_CHECKING , Any , cast
4343
44+ from typing_extensions import override
45+
4446import pytools
4547from pytato .analysis import get_num_call_sites
4648from pytato .array import (
47- AbstractResultWithNamedArrays ,
4849 Array ,
4950 Axis as PtAxis ,
5051 DataWrapper ,
5859from pytato .transform import (
5960 ArrayOrNames ,
6061 CopyMapper ,
61- Deduplicator ,
6262 TransformMapperCache ,
63+ deduplicate ,
6364)
6465from pytools import UniqueNameGenerator , memoize_method
6566
@@ -95,6 +96,7 @@ def __init__(
9596 self .vng = UniqueNameGenerator ()
9697 self .seen_inputs : set [str ] = set ()
9798
99+ @override
98100 def map_data_wrapper (self , expr : DataWrapper ) -> Array :
99101 if expr .name is not None :
100102 if expr .name in self .seen_inputs :
@@ -114,13 +116,16 @@ def map_data_wrapper(self, expr: DataWrapper) -> Array:
114116 axes = expr .axes ,
115117 tags = expr .tags )
116118
119+ @override
117120 def map_size_param (self , expr : SizeParam ) -> Array :
118121 raise NotImplementedError
119122
123+ @override
120124 def map_placeholder (self , expr : Placeholder ) -> Array :
121125 raise ValueError ("Placeholders cannot appear in"
122126 " DatawrapperToBoundPlaceholderMapper." )
123127
128+ @override
124129 def map_function_definition (
125130 self , expr : FunctionDefinition ) -> FunctionDefinition :
126131 raise ValueError ("Function definitions cannot appear in"
@@ -131,7 +136,7 @@ def map_function_definition(
131136# definitions can't contain non-argument placeholders
132137def _normalize_pt_expr (
133138 expr : DictOfNamedArrays
134- ) -> tuple [Array | AbstractResultWithNamedArrays , Mapping [str , Any ]]:
139+ ) -> tuple [DictOfNamedArrays , Mapping [str , Any ]]:
135140 """
136141 Returns ``(normalized_expr, bound_arguments)``. *normalized_expr* is a
137142 normalized form of *expr*, with all instances of
@@ -141,7 +146,7 @@ def _normalize_pt_expr(
141146 Deterministic naming of placeholders permits more effective caching of
142147 equivalent graphs.
143148 """
144- expr = Deduplicator () (expr )
149+ expr = deduplicate (expr )
145150
146151 if get_num_call_sites (expr ):
147152 raise NotImplementedError (
@@ -150,7 +155,7 @@ def _normalize_pt_expr(
150155
151156 normalize_mapper = _DatawrapperToBoundPlaceholderMapper ()
152157 normalized_expr = normalize_mapper (expr )
153- assert isinstance (normalized_expr , AbstractResultWithNamedArrays )
158+ assert isinstance (normalized_expr , DictOfNamedArrays )
154159 return normalized_expr , normalize_mapper .bound_arguments
155160
156161
@@ -188,6 +193,7 @@ def __init__(self, actx: ArrayContext) -> None:
188193 super ().__init__ ()
189194 self .actx = actx
190195
196+ @override
191197 def map_data_wrapper (self , expr : DataWrapper ) -> Array :
192198 import numpy as np
193199
@@ -220,6 +226,7 @@ def __init__(self, actx: ArrayContext) -> None:
220226 super ().__init__ ()
221227 self .actx = actx
222228
229+ @override
223230 def map_data_wrapper (self , expr : DataWrapper ) -> Array :
224231 import numpy as np
225232
0 commit comments