Skip to content

Commit 82016e3

Browse files
committed
MAINT: Fix typing issues
1 parent 86461cc commit 82016e3

2 files changed

Lines changed: 24 additions & 20 deletions

File tree

src/array_api_extra/testing.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ class Deprecated(enum.Enum):
4848
DEPRECATED = Deprecated.DEPRECATED
4949

5050

51-
def _clone_function(f):
51+
def _clone_function(f: Callable[..., Any]) -> Callable[..., Any]:
5252
"""Returns a clone of an existing function."""
5353
f_new = FunctionType(
5454
f.__code__,
@@ -58,12 +58,11 @@ def _clone_function(f):
5858
closure=f.__closure__,
5959
)
6060
f_new.__kwdefaults__ = f.__kwdefaults__
61-
update_wrapper(f_new, f)
62-
return f_new
61+
return update_wrapper(f_new, f)
6362

6463

6564
def lazy_xp_function(
66-
func: Callable[..., Any] | Tuple[type, str],
65+
func: Callable[..., Any] | tuple[type, str],
6766
*,
6867
allow_dask_compute: bool | int = False,
6968
jax_jit: bool = True,
@@ -231,12 +230,14 @@ def test_myfunc(xp):
231230
cls, method_name = func
232231
method = getattr(cls, method_name)
233232
setattr(cls, method_name, _clone_function(method))
234-
func = getattr(cls, method_name)
233+
f = getattr(cls, method_name)
234+
else:
235+
f = func
235236

236237
try:
237-
func._lazy_xp_function = tags # type: ignore[attr-defined] # pylint: disable=protected-access # pyright: ignore[reportFunctionMemberAccess]
238+
f._lazy_xp_function = tags # pylint: disable=protected-access # pyright: ignore[reportFunctionMemberAccess]
238239
except AttributeError: # @cython.vectorize
239-
_ufuncs_tags[func] = tags
240+
_ufuncs_tags[f] = tags
240241

241242

242243
def patch_lazy_xp_functions(

tests/test_testing.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
from collections.abc import Callable, Iterator
22
from types import ModuleType
3-
from typing import cast
3+
from typing import Any, cast
44

55
import numpy as np
66
import pytest
7+
from typing_extensions import override
78

89
from array_api_extra._lib._backends import Backend
910
from array_api_extra._lib._testing import (
@@ -322,32 +323,34 @@ def test_lazy_xp_function_cython_ufuncs(xp: ModuleType, library: Backend):
322323

323324

324325
class A:
325-
def __init__(self, x):
326+
def __init__(self, x: Array):
326327
xp = array_namespace(x)
327-
self._xp = xp
328-
self.x = np.asarray(x)
328+
self._xp: ModuleType = xp
329+
self.x: Any = np.asarray(x)
329330

330-
def f(self, y):
331-
y = np.asarray(y)
332-
return self._xp.asarray(np.matmul(self.x, y))
331+
def f(self, y: Array) -> Array:
332+
return self._xp.asarray(np.matmul(self.x, np.asarray(y)))
333333

334-
def g(self, y, z):
334+
def g(self, y: Array, z: Array) -> Array:
335335
return self.f(y) + self.f(z)
336336

337337

338338
class B(A):
339-
def __init__(self, x):
339+
@override
340+
def __init__(self, x: Array): # pyright: ignore[reportMissingSuperCall]
340341
xp = array_namespace(x)
341-
self._xp = xp
342-
self.x = xp.asarray(x)
342+
self._xp: ModuleType = xp
343+
self.x: Any = xp.asarray(x)
343344

344-
def f(self, y):
345+
@override
346+
def f(self, y: Array) -> Array:
345347
return self._xp.matmul(self.x, y)
346348

347349

348350
lazy_xp_function((B, "g"))
349351

350-
def test_lazy_xp_function_class_inheritance(xp: ModuleType):
352+
353+
def test_lazy_xp_function_class_inheritance():
351354
assert hasattr(B.g, "_lazy_xp_function")
352355
assert not hasattr(A.g, "_lazy_xp_function")
353356

0 commit comments

Comments
 (0)