Skip to content

Commit e692ac3

Browse files
committed
fix some pyright errors
1 parent 186c25f commit e692ac3

4 files changed

Lines changed: 23 additions & 16 deletions

File tree

arraycontext/impl/pytato/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
from typing import TYPE_CHECKING, Any
5959

6060
import numpy as np
61+
from typing_extensions import override
6162

6263
from pytools import memoize_method
6364
from pytools.tag import Tag, ToTagSetConvertible, normalize_tags
@@ -232,6 +233,7 @@ def get_target(self):
232233

233234
# }}}
234235

236+
@override
235237
def outline(self,
236238
f: Callable[..., Any],
237239
*,
@@ -594,8 +596,7 @@ def _to_frozen(key: tuple[Any, ...], ary) -> TaggableCLArray:
594596
pt_dict_of_named_arrays = pt.make_dict_of_named_arrays(
595597
key_to_pt_arrays)
596598

597-
pt_dict_of_named_arrays = pt.transform.Deduplicator()(
598-
pt_dict_of_named_arrays)
599+
pt_dict_of_named_arrays = pt.deduplicate(pt_dict_of_named_arrays)
599600

600601
# FIXME: Remove this if/when _normalize_pt_expr gets support for functions
601602
pt_dict_of_named_arrays = pt.tag_all_calls_to_be_inlined(

arraycontext/impl/pytato/compile.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -189,8 +189,7 @@ def _to_input_for_compiled(
189189
"""
190190
from arraycontext.impl.pyopencl.taggable_cl_array import to_tagged_cl_array
191191
if isinstance(ary, pt.Array):
192-
dag = pt.transform.Deduplicator()(
193-
pt.make_dict_of_named_arrays({"_actx_out": ary}))
192+
dag = pt.deduplicate(pt.make_dict_of_named_arrays({"_actx_out": ary}))
194193
# Transform the DAG to give metadata inference a chance to do its job
195194
return actx.transform_dag(dag)["_actx_out"].expr
196195
elif isinstance(ary, TaggableCLArray):

arraycontext/impl/pytato/outline.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,9 @@ def get_placeholder_replacement(
133133
return arg_id_to_placeholder[key]
134134
elif is_array_container_type(arg.__class__):
135135
def _rec_to_placeholder(keys: tuple[Any, ...], ary: pt.Array) -> pt.Array:
136-
return get_placeholder_replacement(ary, key + keys)
136+
result = get_placeholder_replacement(ary, key + keys)
137+
assert isinstance(result, pt.Array)
138+
return result
137139

138140
return rec_keyed_map_array_container(_rec_to_placeholder, arg)
139141
else:
@@ -206,9 +208,8 @@ def __call__(self, *args: Any, **kwargs: Any) -> ArrayOrContainer:
206208

207209
prefixed_output = _call_with_placeholders(
208210
self.f, args, kwargs, arg_id_to_prefixed_placeholder)
209-
unpacked_prefixed_output = pt.transform.Deduplicator()(
210-
pt.make_dict_of_named_arrays(
211-
_unpack_output(prefixed_output)))
211+
unpacked_prefixed_output = pt.deduplicate(
212+
pt.make_dict_of_named_arrays(_unpack_output(prefixed_output)))
212213

213214
prefixed_placeholders = frozenset(
214215
arg_id_to_prefixed_placeholder.values())
@@ -225,9 +226,8 @@ def __call__(self, *args: Any, **kwargs: Any) -> ArrayOrContainer:
225226
arg_id_to_placeholder = _get_arg_id_to_placeholder(arg_id_to_arg)
226227

227228
output = _call_with_placeholders(self.f, args, kwargs, arg_id_to_placeholder)
228-
unpacked_output = pt.transform.Deduplicator()(
229-
pt.make_dict_of_named_arrays(
230-
_unpack_output(output)))
229+
unpacked_output = pt.deduplicate(
230+
pt.make_dict_of_named_arrays(_unpack_output(output)))
231231
if len(unpacked_output) == 1 and "_" in unpacked_output:
232232
ret_type = pt.function.ReturnType.ARRAY
233233
else:

arraycontext/impl/pytato/utils.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,11 @@
4141
from collections.abc import Mapping
4242
from typing import TYPE_CHECKING, Any, cast
4343

44+
from typing_extensions import override
45+
4446
import pytools
4547
from pytato.analysis import get_num_call_sites
4648
from pytato.array import (
47-
AbstractResultWithNamedArrays,
4849
Array,
4950
Axis as PtAxis,
5051
DataWrapper,
@@ -58,8 +59,8 @@
5859
from pytato.transform import (
5960
ArrayOrNames,
6061
CopyMapper,
61-
Deduplicator,
6262
TransformMapperCache,
63+
deduplicate,
6364
)
6465
from pytools import UniqueNameGenerator, memoize_method
6566

@@ -95,6 +96,7 @@ def __init__(
9596
self.vng = UniqueNameGenerator()
9697
self.seen_inputs: set[str] = set()
9798

99+
@override
98100
def map_data_wrapper(self, expr: DataWrapper) -> Array:
99101
if expr.name is not None:
100102
if expr.name in self.seen_inputs:
@@ -114,13 +116,16 @@ def map_data_wrapper(self, expr: DataWrapper) -> Array:
114116
axes=expr.axes,
115117
tags=expr.tags)
116118

119+
@override
117120
def map_size_param(self, expr: SizeParam) -> Array:
118121
raise NotImplementedError
119122

123+
@override
120124
def map_placeholder(self, expr: Placeholder) -> Array:
121125
raise ValueError("Placeholders cannot appear in"
122126
" DatawrapperToBoundPlaceholderMapper.")
123127

128+
@override
124129
def map_function_definition(
125130
self, expr: FunctionDefinition) -> FunctionDefinition:
126131
raise ValueError("Function definitions cannot appear in"
@@ -131,7 +136,7 @@ def map_function_definition(
131136
# definitions can't contain non-argument placeholders
132137
def _normalize_pt_expr(
133138
expr: DictOfNamedArrays
134-
) -> tuple[Array | AbstractResultWithNamedArrays, Mapping[str, Any]]:
139+
) -> tuple[DictOfNamedArrays, Mapping[str, Any]]:
135140
"""
136141
Returns ``(normalized_expr, bound_arguments)``. *normalized_expr* is a
137142
normalized form of *expr*, with all instances of
@@ -141,7 +146,7 @@ def _normalize_pt_expr(
141146
Deterministic naming of placeholders permits more effective caching of
142147
equivalent graphs.
143148
"""
144-
expr = Deduplicator()(expr)
149+
expr = deduplicate(expr)
145150

146151
if get_num_call_sites(expr):
147152
raise NotImplementedError(
@@ -150,7 +155,7 @@ def _normalize_pt_expr(
150155

151156
normalize_mapper = _DatawrapperToBoundPlaceholderMapper()
152157
normalized_expr = normalize_mapper(expr)
153-
assert isinstance(normalized_expr, AbstractResultWithNamedArrays)
158+
assert isinstance(normalized_expr, DictOfNamedArrays)
154159
return normalized_expr, normalize_mapper.bound_arguments
155160

156161

@@ -188,6 +193,7 @@ def __init__(self, actx: ArrayContext) -> None:
188193
super().__init__()
189194
self.actx = actx
190195

196+
@override
191197
def map_data_wrapper(self, expr: DataWrapper) -> Array:
192198
import numpy as np
193199

@@ -220,6 +226,7 @@ def __init__(self, actx: ArrayContext) -> None:
220226
super().__init__()
221227
self.actx = actx
222228

229+
@override
223230
def map_data_wrapper(self, expr: DataWrapper) -> Array:
224231
import numpy as np
225232

0 commit comments

Comments
 (0)