1111import warnings
1212from collections .abc import Callable , Generator , Iterator , Sequence
1313from functools import update_wrapper , wraps
14+ from inspect import getattr_static
1415from types import FunctionType , ModuleType
1516from typing import TYPE_CHECKING , Any , ParamSpec , TypeVar , cast
1617
@@ -229,10 +230,12 @@ def test_myfunc(xp):
229230 # the method was inherited from a parent class.
230231 # Note: can't just accept an unbound method `cls.method_name` because in
231232 # case of inheritance it would be impossible to attribute it to the child class.
233+ # This also makes it so tagged methods will appear in their class's ``__dict__``
234+ # and thus findable by ``iter_tagged_modules`` below.
232235 cls , method_name = func
233236 # The method might be a staticmethod or classmethod so we need to do a dance
234237 # to ensure that this is preserved.
235- raw_attr = cls . __dict__ . get ( method_name )
238+ raw_attr = getattr_static ( cls , method_name )
236239 method = getattr (cls , method_name )
237240 cloned_method = _clone_function (method )
238241
@@ -322,8 +325,7 @@ def xp(request):
322325 # set here to cut down on potential redundancy.
323326 classes : set [type ] = set ()
324327 for target in search_targets :
325- for obj_name in dir (target ):
326- obj = getattr (target , obj_name )
328+ for obj in target .__dict__ .values ():
327329 if isinstance (obj , type ):
328330 classes .add (obj )
329331 search_targets .extend (classes )
@@ -336,7 +338,10 @@ def temp_setattr(target: ModuleType | type, name: str, func: object) -> None:
336338 parameters of a test so that pytest-run-parallel can run on the remainder.
337339 """
338340 assert hasattr (target , name )
339- to_revert .append ((target , name , getattr (target , name )))
341+ # Need getattr_static because the attr could be a staticmethod or other
342+ # descriptor and we don't want that to be stripped away.
343+ original = getattr_static (target , name )
344+ to_revert .append ((target , name , original ))
340345 setattr (target , name , func )
341346
342347 if monkeypatch is not None :
@@ -353,33 +358,57 @@ def temp_setattr(target: ModuleType | type, name: str, func: object) -> None:
353358 temp_setattr = monkeypatch .setattr # type: ignore[assignment] # pyright: ignore[reportAssignmentType]
354359
355360 def iter_tagged () -> Iterator [
356- tuple [ModuleType | type , str , Callable [..., Any ], dict [str , Any ]]
361+ tuple [ModuleType | type , str , Any , Callable [..., Any ], dict [str , Any ]]
357362 ]:
358363 for target in search_targets :
359- for name , func in target .__dict__ .items ():
364+ for name , attr in target .__dict__ .items ():
365+ # attr might be a staticmethod or classmethod. If so we need
366+ # to peel it back and wrap the underlying function and later
367+ # make sure not to accidentally replace it with a regular
368+ # method.
369+ func : Any = (
370+ attr .__func__
371+ if isinstance (attr , (staticmethod , classmethod ))
372+ else attr
373+ )
360374 tags : dict [str , Any ] | None = None
361375 with contextlib .suppress (AttributeError ):
362376 tags = func ._lazy_xp_function # pylint: disable=protected-access
363377 if tags is None :
364378 with contextlib .suppress (KeyError , TypeError ):
365379 tags = _ufuncs_tags [func ]
366380 if tags is not None :
367- yield target , name , func , tags
381+ # put attr, and func in the outputs so we can later tell
382+ # if this was a staticmethod or classmethod.
383+ yield target , name , attr , func , tags
368384
385+ wrapped : Any
369386 if is_dask_namespace (xp ):
370- for target , name , func , tags in iter_tagged ():
387+ for target , name , attr , func , tags in iter_tagged ():
371388 n = tags ["allow_dask_compute" ]
372389 if n is True :
373390 n = 1_000_000
374391 elif n is False :
375392 n = 0
376393 wrapped = _dask_wrap (func , n )
394+ # If we're dealing with a staticmethod or classmethod, make
395+ # sure things stay that way.
396+ if isinstance (attr , staticmethod ):
397+ wrapped = staticmethod (wrapped )
398+ elif isinstance (attr , classmethod ):
399+ wrapped = classmethod (wrapped )
377400 temp_setattr (target , name , wrapped )
378401
379402 elif is_jax_namespace (xp ):
380- for target , name , func , tags in iter_tagged ():
403+ for target , name , attr , func , tags in iter_tagged ():
381404 if tags ["jax_jit" ]:
382405 wrapped = jax_autojit (func )
406+ # If we're dealing with a staticmethod or classmethod, make
407+ # sure things stay that way.
408+ if isinstance (attr , staticmethod ):
409+ wrapped = staticmethod (wrapped )
410+ elif isinstance (attr , classmethod ):
411+ wrapped = classmethod (wrapped )
383412 temp_setattr (target , name , wrapped )
384413
385414 # We can't just decorate patch_lazy_xp_functions with
0 commit comments