Skip to content

Commit b258e19

Browse files
committed
TYP: add pyrefly
1 parent 1e1f1ab commit b258e19

9 files changed

Lines changed: 128 additions & 17 deletions

File tree

lefthook.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,9 @@ pre-commit:
3939
- name: mypy
4040
glob: "*.{py,pyi}"
4141
run: pixi {run} mypy
42+
- name: pyrefly
43+
glob: "*.{py, pyi}"
44+
run: pixi {run} pyrefly
4245
- name: typos
4346
stage_fixed: true
4447
run: pixi {run} typos

pixi.lock

Lines changed: 61 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ actionlint = ">=1.7.12,<2"
108108
blacken-docs = ">=1.20.0,<2"
109109
pytest = ">=9.0.2,<10"
110110
validate-pyproject = ">=0.25,<0.26"
111+
pyrefly = ">=0.61.1,<0.62"
111112
# NOTE: don't add cupy, jax, pytorch, or sparse here,
112113
# as they slow down mypy and are not portable across target OSs
113114

@@ -117,6 +118,7 @@ hooks = { cmd = "lefthook install", description = "Install pre-commit hooks" }
117118
pre-commit = { cmd = "lefthook run pre-commit", description = "Run pre-commit checks" }
118119
pylint = { cmd = "pylint array_api_extra", cwd = "src", description = "Lint with pylint" }
119120
mypy = { cmd = "mypy", description = "Type check with mypy" }
121+
pyrefly = { cmd = "pyrefly check", description = "Type check with pyrefly" }
120122
pyright = { cmd = "basedpyright", description = "Type check with basedpyright" }
121123
ruff-check = { cmd = "ruff check --fix", description = "Lint with ruff" }
122124
ruff-format = { cmd = "ruff format", description = "Format with ruff" }
@@ -257,7 +259,7 @@ run.source = ["array_api_extra"]
257259
# mypy
258260

259261
[tool.mypy]
260-
files = ["src", "tests"]
262+
files = ["src", "tests", "vendor_tests"]
261263
python_version = "3.11"
262264
warn_unused_configs = true
263265
strict = true
@@ -273,10 +275,46 @@ ignore_missing_imports = true
273275
module = ["tests/*"]
274276
disable_error_code = ["no-untyped-def"] # test(...) without -> None
275277

278+
[[tool.mypy.overrides]]
279+
module = ["vendor_tests/*"]
280+
disable_error_code = ["no-untyped-def"] # test(...) without -> None
281+
282+
[[tool.mypy.overrides]]
283+
module = ["vendor_tests/array_api_compat/*"]
284+
ignore_errors = true
285+
286+
# pyrefly
287+
288+
[tool.pyrefly.errors]
289+
# Redundant with mypy checks
290+
missing-import = false
291+
# extra checks from scipy/scipy-stubs
292+
implicit-abstract-class = "error"
293+
implicitly-defined-attribute = "error"
294+
missing-override-decorator = "error"
295+
missing-source = "ignore"
296+
not-required-key-access = "error"
297+
open-unpacking = "error"
298+
unannotated-attribute = "error"
299+
unannotated-parameter = "error"
300+
unannotated-return = "error"
301+
untyped-import = "error"
302+
unused-ignore = "error"
303+
variance-mismatch = "error"
304+
305+
[[tool.pyrefly.sub-config]]
306+
matches = "tests/*.py"
307+
errors = { unannotated-return = false }
308+
309+
[[tool.pyrefly.sub-config]]
310+
matches = "vendor_tests/*.py"
311+
errors = { unannotated-return = false }
312+
276313
# pyright
277314

278315
[tool.basedpyright]
279-
include = ["src", "tests"]
316+
include = ["src", "tests", "vendor_tests"]
317+
exclude = ["vendor_tests/array_api_compat"]
280318
pythonVersion = "3.11"
281319
pythonPlatform = "All"
282320
typeCheckingMode = "all"

src/array_api_extra/_lib/_at.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ class _AtOp(Enum):
3737
MAX = "max"
3838

3939
# @override from Python 3.12
40-
def __str__(self) -> str: # pyright: ignore[reportImplicitOverride]
40+
def __str__(self) -> str: # pyright: ignore[reportImplicitOverride] # pyrefly: ignore[missing-override-decorator]
4141
"""
4242
Return string representation (useful for pytest logs).
4343

src/array_api_extra/_lib/_lazy.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
P = ParamSpec("P")
3131

3232

33-
@overload
33+
@overload # pyrefly: ignore[invalid-param-spec]
3434
def lazy_apply( # type: ignore[valid-type]
3535
func: Callable[P, Array | ArrayLike],
3636
*args: Array | complex | None,
@@ -42,7 +42,7 @@ def lazy_apply( # type: ignore[valid-type]
4242
) -> Array: ... # numpydoc ignore=GL08
4343

4444

45-
@overload
45+
@overload # pyrefly: ignore[invalid-param-spec]
4646
def lazy_apply( # type: ignore[valid-type]
4747
func: Callable[P, Sequence[Array | ArrayLike]],
4848
*args: Array | complex | None,
@@ -54,7 +54,7 @@ def lazy_apply( # type: ignore[valid-type]
5454
) -> tuple[Array, ...]: ... # numpydoc ignore=GL08
5555

5656

57-
def lazy_apply( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
57+
def lazy_apply( # type: ignore[valid-type] # pyrefly: ignore[invalid-param-spec] # numpydoc ignore=GL07,SA04
5858
func: Callable[P, Array | ArrayLike | Sequence[Array | ArrayLike]],
5959
*args: Array | complex | None,
6060
shape: tuple[int | None, ...] | Sequence[tuple[int | None, ...]] | None = None,
@@ -240,7 +240,7 @@ def lazy_apply( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
240240
if is_dask_namespace(xp):
241241
import dask
242242

243-
metas: list[Array] = [arg._meta for arg in array_args] # pylint: disable=protected-access # pyright: ignore[reportAttributeAccessIssue]
243+
metas: list[Array] = [arg._meta for arg in array_args] # pylint: disable=protected-access # pyright: ignore[reportAttributeAccessIssue] # pyrefly: ignore[missing-attribute]
244244
meta_xp = array_namespace(*metas)
245245

246246
wrapped = dask.delayed( # type: ignore[attr-defined] # pyright: ignore[reportPrivateImportUsage]

src/array_api_extra/testing.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,7 @@ def test_myfunc(xp):
259259
f = func
260260

261261
try:
262-
f._lazy_xp_function = tags # pylint: disable=protected-access # pyright: ignore[reportFunctionMemberAccess]
262+
f._lazy_xp_function = tags # pylint: disable=protected-access # pyright: ignore[reportFunctionMemberAccess] # pyrefly: ignore[missing-attribute]
263263
except AttributeError: # @cython.vectorize
264264
_ufuncs_tags[f] = tags
265265

@@ -461,7 +461,7 @@ class CountingDaskScheduler(SchedulerGetCallable):
461461
max_count: int
462462
msg: str
463463

464-
def __init__(self, max_count: int, msg: str): # numpydoc ignore=GL08
464+
def __init__(self, max_count: int, msg: str) -> None: # numpydoc ignore=GL08
465465
self.count = 0
466466
self.max_count = max_count
467467
self.msg = msg

tests/test_lazy.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,10 @@ def test_lazy_apply_none_shape_in_args(xp: ModuleType, library: Backend):
219219
mxp = np if library is Backend.DASK else xp
220220
int_type = xp.asarray(0).dtype
221221

222-
ctx: contextlib.AbstractContextManager[object]
222+
ctx: (
223+
contextlib.AbstractContextManager[object]
224+
| contextlib.AbstractContextManager[None]
225+
)
223226
if library.like(Backend.JAX):
224227
ctx = pytest.raises(ValueError, match="Output shape must be fully known")
225228
elif library is Backend.ARRAY_API_STRICTEST:
Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
"""This file is a hook imported by `src/array_api_extra/_lib/_compat.py`."""
22
# pyright: reportUnknownParameterType=false, reportMissingParameterType=false
33

4-
from .array_api_compat import * # noqa: F403
4+
from types import ModuleType
5+
from typing import Any
6+
7+
from .array_api_compat import * # noqa: F403 # pyright: ignore[reportAssignmentType]
58
from .array_api_compat import array_namespace as array_namespace_compat
69

710

811
# Let unit tests check with `is` that we are picking up the function from this module
912
# and not from the original array_api_compat module.
10-
def array_namespace(*xs, **kwargs): # numpydoc ignore=GL08
11-
return array_namespace_compat(*xs, **kwargs)
13+
def array_namespace(*xs: Any | complex | None, **kwargs) -> ModuleType: # type: ignore[no-redef] # numpydoc ignore=GL08
14+
return array_namespace_compat(*xs, **kwargs) # pyright: ignore[reportUnknownArgumentType]

vendor_tests/test_vendor.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
# pyright: reportAttributeAccessIssue=false
22

3-
from typing import Any
3+
from typing import Any, cast
44

55
import array_api_strict as xp
66
from numpy.testing import assert_array_equal
77

8+
from vendor_tests.array_api_compat.common._typing import Array
9+
810

911
def test_vendor_compat():
10-
from ._array_api_compat_vendor import ( # type: ignore[attr-defined]
12+
from ._array_api_compat_vendor import (
1113
array_namespace,
1214
device,
1315
is_array_api_obj,
@@ -35,6 +37,7 @@ def test_vendor_compat():
3537
to_device(x, device(x))
3638
assert is_array_api_obj(x)
3739
assert is_array_api_strict_namespace(xp)
40+
x = cast(Array, x)
3841
assert not is_cupy_array(x)
3942
assert not is_cupy_namespace(xp)
4043
assert not is_dask_array(x)
@@ -56,8 +59,8 @@ def test_vendor_extra():
5659
from .array_api_extra import atleast_nd
5760

5861
x = xp.asarray(1)
59-
y = atleast_nd(x, ndim=0)
60-
assert_array_equal(y, x) # pyright: ignore[reportUnknownArgumentType]
62+
y = atleast_nd(x, ndim=0) # type: ignore[arg-type]
63+
assert_array_equal(y, x)
6164

6265

6366
def test_vendor_extra_testing():

0 commit comments

Comments
 (0)