Skip to content

Commit 01ac205

Browse files
committed
ENH: Add better lazy_xp_function support for class methods
1 parent 7333098 commit 01ac205

1 file changed

Lines changed: 29 additions & 5 deletions

File tree

src/array_api_extra/testing.py

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@
1010
import enum
1111
import warnings
1212
from collections.abc import Callable, Generator, Iterator, Sequence
13-
from functools import wraps
14-
from types import ModuleType
13+
from functools import update_wrapper, wraps
14+
from types import FunctionType, ModuleType
1515
from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar, cast
1616

1717
from ._lib._utils._compat import is_dask_namespace, is_jax_namespace
@@ -48,8 +48,22 @@ class Deprecated(enum.Enum):
4848
DEPRECATED = Deprecated.DEPRECATED
4949

5050

51+
def _clone_function(f):
52+
"""Returns a clone of an existing function."""
53+
f_new = FunctionType(
54+
f.__code__,
55+
f.__globals__,
56+
name=f.__name__,
57+
argdefs=f.__defaults__,
58+
closure=f.__closure__,
59+
)
60+
f_new.__kwdefaults__ = f.__kwdefaults__
61+
update_wrapper(f_new, f)
62+
return f_new
63+
64+
5165
def lazy_xp_function(
52-
func: Callable[..., Any],
66+
func: Callable[..., Any] | Tuple[type, str],
5367
*,
5468
allow_dask_compute: bool | int = False,
5569
jax_jit: bool = True,
@@ -69,8 +83,9 @@ def lazy_xp_function(
6983
7084
Parameters
7185
----------
72-
func : callable
73-
Function to be tested.
86+
func : callable | tuple[type, str]
87+
Function to be tested, or a tuple containing an (uninstantiated) class and a
88+
method name to specify a class method to be tested.
7489
allow_dask_compute : bool | int, optional
7590
Whether `func` is allowed to internally materialize the Dask graph, or maximum
7691
number of times it is allowed to do so. This is typically triggered by
@@ -209,6 +224,15 @@ def test_myfunc(xp):
209224
"jax_jit": jax_jit,
210225
}
211226

227+
if isinstance(func, tuple):
228+
# Replace the method with a clone before adding tags
229+
# to avoid adding unwanted tags to a parent method when
230+
# the method was inherited from a parent class.
231+
cls, method_name = func
232+
method = getattr(cls, method_name)
233+
setattr(cls, method_name, _clone_function(method))
234+
func = getattr(cls, method_name)
235+
212236
try:
213237
func._lazy_xp_function = tags # type: ignore[attr-defined] # pylint: disable=protected-access # pyright: ignore[reportFunctionMemberAccess]
214238
except AttributeError: # @cython.vectorize

0 commit comments

Comments
 (0)