1010import enum
1111import warnings
1212from collections .abc import Callable , Generator , Iterator , Sequence
13- from functools import wraps
14- from types import ModuleType
13+ from functools import update_wrapper , wraps
14+ from types import FunctionType , ModuleType
1515from typing import TYPE_CHECKING , Any , ParamSpec , TypeVar , cast
1616
1717from ._lib ._utils ._compat import is_dask_namespace , is_jax_namespace
@@ -48,8 +48,22 @@ class Deprecated(enum.Enum):
4848DEPRECATED = Deprecated .DEPRECATED
4949
5050
51+ def _clone_function (f ):
52+ """Returns a clone of an existing function."""
53+ f_new = FunctionType (
54+ f .__code__ ,
55+ f .__globals__ ,
56+ name = f .__name__ ,
57+ argdefs = f .__defaults__ ,
58+ closure = f .__closure__ ,
59+ )
60+ f_new .__kwdefaults__ = f .__kwdefaults__
61+ update_wrapper (f_new , f )
62+ return f_new
63+
64+
5165def lazy_xp_function (
52- func : Callable [..., Any ],
66+ func : Callable [..., Any ] | Tuple [ type , str ] ,
5367 * ,
5468 allow_dask_compute : bool | int = False ,
5569 jax_jit : bool = True ,
@@ -69,8 +83,9 @@ def lazy_xp_function(
6983
7084 Parameters
7185 ----------
72- func : callable
73- Function to be tested.
86+ func : callable | tuple[type, str]
87+ Function to be tested, or a tuple containing an (uninstantiated) class and a
88+ method name to specify a class method to be tested.
7489 allow_dask_compute : bool | int, optional
7590 Whether `func` is allowed to internally materialize the Dask graph, or maximum
7691 number of times it is allowed to do so. This is typically triggered by
@@ -209,6 +224,15 @@ def test_myfunc(xp):
209224 "jax_jit" : jax_jit ,
210225 }
211226
227+ if isinstance (func , tuple ):
228+ # Replace the method with a clone before adding tags
229+ # to avoid adding unwanted tags to a parent method when
230+ # the method was inherited from a parent class.
231+ cls , method_name = func
232+ method = getattr (cls , method_name )
233+ setattr (cls , method_name , _clone_function (method ))
234+ func = getattr (cls , method_name )
235+
212236 try :
213237 func ._lazy_xp_function = tags # type: ignore[attr-defined] # pylint: disable=protected-access # pyright: ignore[reportFunctionMemberAccess]
214238 except AttributeError : # @cython.vectorize
0 commit comments