Skip to content

Commit 2168009

Browse files
steppilucascolley
authored andcommitted
MAINT: fix monkeypatching of staticmethods and classmethods
1 parent 9107fe2 commit 2168009

1 file changed

Lines changed: 38 additions & 9 deletions

File tree

src/array_api_extra/testing.py

Lines changed: 38 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import warnings
1212
from collections.abc import Callable, Generator, Iterator, Sequence
1313
from functools import update_wrapper, wraps
14+
from inspect import getattr_static
1415
from types import FunctionType, ModuleType
1516
from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar, cast
1617

@@ -229,10 +230,12 @@ def test_myfunc(xp):
229230
# the method was inherited from a parent class.
230231
# Note: can't just accept an unbound method `cls.method_name` because in
231232
# case of inheritance it would be impossible to attribute it to the child class.
233+
# This also makes it so tagged methods will appear in their class's ``__dict__``
234+
# and thus findable by ``iter_tagged_modules`` below.
232235
cls, method_name = func
233236
# The method might be a staticmethod or classmethod so we need to do a dance
234237
# to ensure that this is preserved.
235-
raw_attr = cls.__dict__.get(method_name)
238+
raw_attr = getattr_static(cls, method_name)
236239
method = getattr(cls, method_name)
237240
cloned_method = _clone_function(method)
238241

@@ -322,8 +325,7 @@ def xp(request):
322325
# set here to cut down on potential redundancy.
323326
classes: set[type] = set()
324327
for target in search_targets:
325-
for obj_name in dir(target):
326-
obj = getattr(target, obj_name)
328+
for obj in target.__dict__.values():
327329
if isinstance(obj, type):
328330
classes.add(obj)
329331
search_targets.extend(classes)
@@ -336,7 +338,10 @@ def temp_setattr(target: ModuleType | type, name: str, func: object) -> None:
336338
parameters of a test so that pytest-run-parallel can run on the remainder.
337339
"""
338340
assert hasattr(target, name)
339-
to_revert.append((target, name, getattr(target, name)))
341+
# Need getattr_static because the attr could be a staticmethod or other
342+
# descriptor and we don't want that to be stripped away.
343+
original = getattr_static(target, name)
344+
to_revert.append((target, name, original))
340345
setattr(target, name, func)
341346

342347
if monkeypatch is not None:
@@ -353,33 +358,57 @@ def temp_setattr(target: ModuleType | type, name: str, func: object) -> None:
353358
temp_setattr = monkeypatch.setattr # type: ignore[assignment] # pyright: ignore[reportAssignmentType]
354359

355360
def iter_tagged() -> Iterator[
356-
tuple[ModuleType | type, str, Callable[..., Any], dict[str, Any]]
361+
tuple[ModuleType | type, str, Any, Callable[..., Any], dict[str, Any]]
357362
]:
358363
for target in search_targets:
359-
for name, func in target.__dict__.items():
364+
for name, attr in target.__dict__.items():
365+
# attr might be a staticmethod or classmethod. If so we need
366+
# to peel it back and wrap the underlying function and later
367+
# make sure not to accidentally replace it with a regular
368+
# method.
369+
func: Any = (
370+
attr.__func__
371+
if isinstance(attr, (staticmethod, classmethod))
372+
else attr
373+
)
360374
tags: dict[str, Any] | None = None
361375
with contextlib.suppress(AttributeError):
362376
tags = func._lazy_xp_function # pylint: disable=protected-access
363377
if tags is None:
364378
with contextlib.suppress(KeyError, TypeError):
365379
tags = _ufuncs_tags[func]
366380
if tags is not None:
367-
yield target, name, func, tags
381+
# put attr, and func in the outputs so we can later tell
382+
# if this was a staticmethod or classmethod.
383+
yield target, name, attr, func, tags
368384

385+
wrapped: Any
369386
if is_dask_namespace(xp):
370-
for target, name, func, tags in iter_tagged():
387+
for target, name, attr, func, tags in iter_tagged():
371388
n = tags["allow_dask_compute"]
372389
if n is True:
373390
n = 1_000_000
374391
elif n is False:
375392
n = 0
376393
wrapped = _dask_wrap(func, n)
394+
# If we're dealing with a staticmethod or classmethod, make
395+
# sure things stay that way.
396+
if isinstance(attr, staticmethod):
397+
wrapped = staticmethod(wrapped)
398+
elif isinstance(attr, classmethod):
399+
wrapped = classmethod(wrapped)
377400
temp_setattr(target, name, wrapped)
378401

379402
elif is_jax_namespace(xp):
380-
for target, name, func, tags in iter_tagged():
403+
for target, name, attr, func, tags in iter_tagged():
381404
if tags["jax_jit"]:
382405
wrapped = jax_autojit(func)
406+
# If we're dealing with a staticmethod or classmethod, make
407+
# sure things stay that way.
408+
if isinstance(attr, staticmethod):
409+
wrapped = staticmethod(wrapped)
410+
elif isinstance(attr, classmethod):
411+
wrapped = classmethod(wrapped)
383412
temp_setattr(target, name, wrapped)
384413

385414
# We can't just decorate patch_lazy_xp_functions with

0 commit comments

Comments
 (0)