Skip to content

Commit 74675c7

Browse files
committed
more pyright fixes
1 parent 636639d commit 74675c7

5 files changed

Lines changed: 29 additions & 129 deletions

File tree

.basedpyright/baseline.json

Lines changed: 0 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -11653,14 +11653,6 @@
1165311653
"lineCount": 1
1165411654
}
1165511655
},
11656-
{
11657-
"code": "reportPrivateUsage",
11658-
"range": {
11659-
"startColumn": 53,
11660-
"endColumn": 83,
11661-
"lineCount": 1
11662-
}
11663-
},
1166411656
{
1166511657
"code": "reportPrivateUsage",
1166611658
"range": {
@@ -11765,30 +11757,6 @@
1176511757
"lineCount": 1
1176611758
}
1176711759
},
11768-
{
11769-
"code": "reportArgumentType",
11770-
"range": {
11771-
"startColumn": 21,
11772-
"endColumn": 60,
11773-
"lineCount": 1
11774-
}
11775-
},
11776-
{
11777-
"code": "reportArgumentType",
11778-
"range": {
11779-
"startColumn": 49,
11780-
"endColumn": 67,
11781-
"lineCount": 2
11782-
}
11783-
},
11784-
{
11785-
"code": "reportArgumentType",
11786-
"range": {
11787-
"startColumn": 53,
11788-
"endColumn": 68,
11789-
"lineCount": 1
11790-
}
11791-
},
1179211760
{
1179311761
"code": "reportUnknownMemberType",
1179411762
"range": {
@@ -11829,14 +11797,6 @@
1182911797
"lineCount": 1
1183011798
}
1183111799
},
11832-
{
11833-
"code": "reportArgumentType",
11834-
"range": {
11835-
"startColumn": 16,
11836-
"endColumn": 58,
11837-
"lineCount": 1
11838-
}
11839-
},
1184011800
{
1184111801
"code": "reportPrivateUsage",
1184211802
"range": {
@@ -11885,22 +11845,6 @@
1188511845
"lineCount": 1
1188611846
}
1188711847
},
11888-
{
11889-
"code": "reportArgumentType",
11890-
"range": {
11891-
"startColumn": 12,
11892-
"endColumn": 51,
11893-
"lineCount": 1
11894-
}
11895-
},
11896-
{
11897-
"code": "reportArgumentType",
11898-
"range": {
11899-
"startColumn": 45,
11900-
"endColumn": 63,
11901-
"lineCount": 2
11902-
}
11903-
},
1190411848
{
1190511849
"code": "reportUnknownMemberType",
1190611850
"range": {
@@ -13045,14 +12989,6 @@
1304512989
"lineCount": 1
1304612990
}
1304712991
},
13048-
{
13049-
"code": "reportPrivateUsage",
13050-
"range": {
13051-
"startColumn": 53,
13052-
"endColumn": 83,
13053-
"lineCount": 1
13054-
}
13055-
},
1305612992
{
1305712993
"code": "reportArgumentType",
1305812994
"range": {
@@ -17283,14 +17219,6 @@
1728317219
"lineCount": 1
1728417220
}
1728517221
},
17286-
{
17287-
"code": "reportImplicitOverride",
17288-
"range": {
17289-
"startColumn": 8,
17290-
"endColumn": 24,
17291-
"lineCount": 1
17292-
}
17293-
},
1729417222
{
1729517223
"code": "reportPrivateUsage",
1729617224
"range": {
@@ -17299,22 +17227,6 @@
1729917227
"lineCount": 1
1730017228
}
1730117229
},
17302-
{
17303-
"code": "reportImplicitOverride",
17304-
"range": {
17305-
"startColumn": 8,
17306-
"endColumn": 22,
17307-
"lineCount": 1
17308-
}
17309-
},
17310-
{
17311-
"code": "reportImplicitOverride",
17312-
"range": {
17313-
"startColumn": 8,
17314-
"endColumn": 23,
17315-
"lineCount": 1
17316-
}
17317-
},
1731817230
{
1731917231
"code": "reportUnannotatedClassAttribute",
1732017232
"range": {
@@ -17331,14 +17243,6 @@
1733117243
"lineCount": 1
1733217244
}
1733317245
},
17334-
{
17335-
"code": "reportImplicitOverride",
17336-
"range": {
17337-
"startColumn": 8,
17338-
"endColumn": 24,
17339-
"lineCount": 1
17340-
}
17341-
},
1734217246
{
1734317247
"code": "reportUnannotatedClassAttribute",
1734417248
"range": {
@@ -17347,14 +17251,6 @@
1734717251
"lineCount": 1
1734817252
}
1734917253
},
17350-
{
17351-
"code": "reportImplicitOverride",
17352-
"range": {
17353-
"startColumn": 8,
17354-
"endColumn": 24,
17355-
"lineCount": 1
17356-
}
17357-
},
1735817254
{
1735917255
"code": "reportPrivateUsage",
1736017256
"range": {

arraycontext/context.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -587,7 +587,7 @@ def compile(self, f: Callable[..., Any]) -> Callable[..., Any]:
587587
def outline(self,
588588
f: Callable[..., Any],
589589
*,
590-
id: Hashable | None = None) -> Callable[..., Any]:
590+
id: Hashable | None = None) -> Callable[..., Any]: # pyright: ignore[reportUnusedParameter]
591591
"""
592592
Returns a drop-in-replacement for *f*. The behavior of the returned
593593
callable is specific to the derived class.

arraycontext/impl/pytato/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,7 @@ def outline(self,
238238
f: Callable[..., Any],
239239
*,
240240
id: Hashable | None = None,
241-
tags: frozenset[Tag] = frozenset()
241+
tags: frozenset[Tag] = frozenset() # pyright: ignore[reportCallInDefaultInitializer]
242242
) -> Callable[..., Any]:
243243
from pytato.tags import FunctionIdentifier
244244

@@ -976,6 +976,7 @@ def compile(self, f: Callable[..., Any]) -> Callable[..., Any]:
976976
from .compile import LazilyJAXCompilingFunctionCaller
977977
return LazilyJAXCompilingFunctionCaller(self, f)
978978

979+
@override
979980
def transform_dag(self, dag: pytato.DictOfNamedArrays
980981
) -> pytato.DictOfNamedArrays:
981982
import pytato as pt

arraycontext/impl/pytato/outline.py

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -45,15 +45,14 @@
4545
from arraycontext.context import (
4646
Array,
4747
ArrayOrContainer,
48-
ArrayOrContainerTc,
4948
ArrayT,
5049
)
5150
from arraycontext.impl.pytato import _BasePytatoArrayContext
5251

5352

5453
def _get_arg_id_to_arg(args: tuple[object, ...],
5554
kwargs: Mapping[str, object]
56-
) -> immutabledict[tuple[object, ...], object]:
55+
) -> immutabledict[tuple[object, ...], pt.Array]:
5756
"""
5857
Helper for :meth:`OulinedCall.__call__`. Extracts mappings from argument id
5958
to argument values. See
@@ -104,7 +103,7 @@ def _get_output_arg_id_str(arg_id: tuple[object, ...]) -> str:
104103

105104

106105
def _get_arg_id_to_placeholder(
107-
arg_id_to_arg: Mapping[tuple[object, ...], object],
106+
arg_id_to_arg: Mapping[tuple[object, ...], pt.Array],
108107
prefix: str | None = None) -> immutabledict[tuple[object, ...], pt.Placeholder]:
109108
"""
110109
Helper for :meth:`OulinedCall.__call__`. Constructs a :class:`pytato.Placeholder`
@@ -122,25 +121,25 @@ def _get_arg_id_to_placeholder(
122121

123122
def _call_with_placeholders(
124123
f: Callable[..., object],
125-
args: tuple[object],
124+
args: tuple[object, ...],
126125
kwargs: Mapping[str, object],
127126
arg_id_to_placeholder: Mapping[tuple[object, ...], pt.Placeholder]) -> object:
128127
"""
129128
Construct placeholders analogous to *args* and *kwargs* and call *f*.
130129
"""
131130
def get_placeholder_replacement(
132-
arg: ArrayOrContainerTc | Scalar | None, key: tuple[object, ...]
133-
) -> ArrayOrContainerTc | Scalar | None:
131+
arg: ArrayOrContainer | Scalar | None, key: tuple[object, ...]
132+
) -> ArrayOrContainer | Scalar | None:
134133
if arg is None:
135134
return None
136135
elif np.isscalar(arg):
137136
return cast(Scalar, arg)
138137
elif isinstance(arg, pt.Array):
139-
return cast(ArrayOrContainerTc, arg_id_to_placeholder[key])
138+
return arg_id_to_placeholder[key]
140139
elif is_array_container_type(arg.__class__):
141-
def _rec_to_placeholder(keys: tuple[object, ...], ary: ArrayT) -> ArrayT:
142-
result = get_placeholder_replacement(ary, key + keys)
143-
return cast(ArrayT, result)
140+
def _rec_to_placeholder(
141+
keys: tuple[object, ...], ary: Array) -> Array:
142+
return cast("Array", get_placeholder_replacement(ary, key + keys))
144143

145144
return rec_keyed_map_array_container(_rec_to_placeholder, arg)
146145
else:
@@ -176,7 +175,7 @@ def _unpack_container(key: tuple[object, ...], ary: ArrayT) -> ArrayT:
176175

177176
def _pack_output(
178177
output_template: ArrayOrContainer,
179-
unpacked_output: Array | immutabledict[str, Array]
178+
unpacked_output: pt.Array | immutabledict[str, pt.Array]
180179
) -> ArrayOrContainer:
181180
"""
182181
Pack *unpacked_output* into array containers according to *output_template*.
@@ -187,7 +186,7 @@ def _pack_output(
187186
elif is_array_container_type(output_template.__class__):
188187
assert isinstance(unpacked_output, immutabledict)
189188

190-
def _pack_into_container(key: tuple[object, ...], ary: Array) -> Array:
189+
def _pack_into_container(key: tuple[object, ...], ary: Array) -> Array: # pyright: ignore[reportUnusedParameter]
191190
key_str = _get_output_arg_id_str(key)
192191
return unpacked_output[key_str]
193192

@@ -262,9 +261,6 @@ def __call__(self, *args: object, **kwargs: object) -> ArrayOrContainer:
262261
call_site_output = func_def(**call_bindings)
263262

264263
assert isinstance(call_site_output, pt.Array | immutabledict)
265-
# FIXME: pt.Array is not an actx Array
266-
return _pack_output(cast("Array | immutabledict[str, Array]", output),
267-
cast("Array | immutabledict[str, Array]", call_site_output))
268-
264+
return _pack_output(output, call_site_output)
269265

270266
# vim: foldmethod=marker

test/test_arraycontext.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,10 @@
2424
"""
2525

2626
import logging
27+
from collections.abc import Callable
2728
from dataclasses import dataclass
2829
from functools import partial
30+
from typing import TypeAlias
2931

3032
import numpy as np
3133
import pytest
@@ -34,6 +36,7 @@
3436
from pytools.tag import Tag
3537

3638
from arraycontext import (
39+
ArrayContext,
3740
BcastUntilActxArray,
3841
EagerJAXArrayContext,
3942
NumpyArrayContext,
@@ -58,6 +61,9 @@
5861
logger = logging.getLogger(__name__)
5962

6063

64+
ArrayContextFactory: TypeAlias = Callable[[], ArrayContext]
65+
66+
6167
# {{{ array context fixture
6268

6369
class _PyOpenCLArrayContextForTests(PyOpenCLArrayContext):
@@ -1166,15 +1172,16 @@ def my_rhs(scale, vel):
11661172
np.testing.assert_allclose(result.v, 3.14*v_x)
11671173

11681174

1169-
def test_actx_compile_with_outlined_function(actx_factory):
1175+
def test_actx_compile_with_outlined_function(actx_factory: ArrayContextFactory):
11701176
actx = actx_factory()
11711177
rng = np.random.default_rng()
11721178

11731179
@actx.outline
1174-
def outlined_scale_and_orthogonalize(alpha, vel):
1180+
def outlined_scale_and_orthogonalize(alpha: float, vel: Velocity2D) -> Velocity2D:
11751181
return scale_and_orthogonalize(alpha, vel)
11761182

1177-
def multi_scale_and_orthogonalize(alpha, vel1, vel2):
1183+
def multi_scale_and_orthogonalize(
1184+
alpha: float, vel1: Velocity2D, vel2: Velocity2D) -> np.ndarray:
11781185
return make_obj_array([
11791186
outlined_scale_and_orthogonalize(alpha, vel1),
11801187
outlined_scale_and_orthogonalize(alpha, vel2)])
@@ -1193,10 +1200,10 @@ def multi_scale_and_orthogonalize(alpha, vel1, vel2):
11931200

11941201
result1 = actx.to_numpy(scaled_speed1)
11951202
result2 = actx.to_numpy(scaled_speed2)
1196-
np.testing.assert_allclose(result1.u, -3.14*v1_y)
1197-
np.testing.assert_allclose(result1.v, 3.14*v1_x)
1198-
np.testing.assert_allclose(result2.u, -3.14*v2_y)
1199-
np.testing.assert_allclose(result2.v, 3.14*v2_x)
1203+
np.testing.assert_allclose(result1.u, -3.14*v1_y) # pyright: ignore[reportAttributeAccessIssue]
1204+
np.testing.assert_allclose(result1.v, 3.14*v1_x) # pyright: ignore[reportAttributeAccessIssue]
1205+
np.testing.assert_allclose(result2.u, -3.14*v2_y) # pyright: ignore[reportAttributeAccessIssue]
1206+
np.testing.assert_allclose(result2.v, 3.14*v2_x) # pyright: ignore[reportAttributeAccessIssue]
12001207

12011208
# }}}
12021209

0 commit comments

Comments
 (0)