Skip to content

Commit 7f3320f

Browse files
committed
ENH: make patch_lazy_xp_functions check classes within modules
1 parent 82016e3 commit 7f3320f

1 file changed

Lines changed: 37 additions & 22 deletions

File tree

src/array_api_extra/testing.py

Lines changed: 37 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -249,10 +249,11 @@ def patch_lazy_xp_functions(
249249
"""
250250
Test lazy execution of functions tagged with :func:`lazy_xp_function`.
251251
252-
If ``xp==jax.numpy``, search for all functions which have been tagged with
253-
:func:`lazy_xp_function` in the globals of the module that defines the current test,
254-
as well as in the ``lazy_xp_modules`` list in the globals of the same module,
255-
and wrap them with :func:`jax.jit`. Unwrap them at the end of the test.
252+
If ``xp==jax.numpy``, search for all functions and classes which have been tagged
253+
with :func:`lazy_xp_function` in the globals of the module that defines the current
254+
test, as well as in the ``lazy_xp_modules`` list in the globals of the same module,
255+
and wrap them with :func:`jax.jit`.
256+
Unwrap them at the end of the test.
256257
257258
If ``xp==dask.array``, wrap the functions with a decorator that disables
258259
``compute()`` and ``persist()`` and ensures that exceptions and warnings are raised
@@ -296,18 +297,32 @@ def xp(request):
296297
the example above.
297298
"""
298299
mod = cast(ModuleType, request.module)
299-
mods = [mod, *cast(list[ModuleType], getattr(mod, "lazy_xp_modules", []))]
300-
301-
to_revert: list[tuple[ModuleType, str, object]] = []
302-
303-
def temp_setattr(mod: ModuleType, name: str, func: object) -> None:
300+
search_targets: list[ModuleType | type] = [
301+
mod,
302+
*cast(list[ModuleType], getattr(mod, "lazy_xp_modules", [])),
303+
]
304+
# Also search for classes within the above modules which have had lazy_xp_function
305+
# applied to methods through ``lazy_xp_function((cls, method_name))`` syntax.
306+
# We might end up adding classes incidentally imported into modules, so using a
307+
# set here to cut down on potential redundancy.
308+
classes: set[type] = set()
309+
for target in search_targets:
310+
for obj_name in dir(target):
311+
obj = getattr(target, obj_name)
312+
if isinstance(obj, type) and isinstance(obj, Exception):
313+
classes.add(obj)
314+
search_targets.extend(classes)
315+
316+
to_revert: list[tuple[ModuleType | type, str, object]] = []
317+
318+
def temp_setattr(target: ModuleType | type, name: str, func: object) -> None:
304319
"""
305320
Variant of monkeypatch.setattr, which allows monkey-patching only selected
306321
parameters of a test so that pytest-run-parallel can run on the remainder.
307322
"""
308-
assert hasattr(mod, name)
309-
to_revert.append((mod, name, getattr(mod, name)))
310-
setattr(mod, name, func)
323+
assert hasattr(target, name)
324+
to_revert.append((target, name, getattr(target, name)))
325+
setattr(target, name, func)
311326

312327
if monkeypatch is not None:
313328
warnings.warn(
@@ -323,34 +338,34 @@ def temp_setattr(mod: ModuleType, name: str, func: object) -> None:
323338
temp_setattr = monkeypatch.setattr # type: ignore[assignment] # pyright: ignore[reportAssignmentType]
324339

325340
def iter_tagged() -> Iterator[
326-
tuple[ModuleType, str, Callable[..., Any], dict[str, Any]]
341+
tuple[ModuleType | type, str, Callable[..., Any], dict[str, Any]]
327342
]:
328-
for mod in mods:
329-
for name, func in mod.__dict__.items():
343+
for target in search_targets:
344+
for name, func in target.__dict__.items():
330345
tags: dict[str, Any] | None = None
331346
with contextlib.suppress(AttributeError):
332347
tags = func._lazy_xp_function # pylint: disable=protected-access
333348
if tags is None:
334349
with contextlib.suppress(KeyError, TypeError):
335350
tags = _ufuncs_tags[func]
336351
if tags is not None:
337-
yield mod, name, func, tags
352+
yield target, name, func, tags
338353

339354
if is_dask_namespace(xp):
340-
for mod, name, func, tags in iter_tagged():
355+
for target, name, func, tags in iter_tagged():
341356
n = tags["allow_dask_compute"]
342357
if n is True:
343358
n = 1_000_000
344359
elif n is False:
345360
n = 0
346361
wrapped = _dask_wrap(func, n)
347-
temp_setattr(mod, name, wrapped)
362+
temp_setattr(target, name, wrapped)
348363

349364
elif is_jax_namespace(xp):
350-
for mod, name, func, tags in iter_tagged():
365+
for target, name, func, tags in iter_tagged():
351366
if tags["jax_jit"]:
352367
wrapped = jax_autojit(func)
353-
temp_setattr(mod, name, wrapped)
368+
temp_setattr(target, name, wrapped)
354369

355370
# We can't just decorate patch_lazy_xp_functions with
356371
# @contextlib.contextmanager because it would not work with the
@@ -360,8 +375,8 @@ def revert_on_exit() -> Generator[None]:
360375
try:
361376
yield
362377
finally:
363-
for mod, name, orig_func in to_revert:
364-
setattr(mod, name, orig_func)
378+
for target, name, orig_func in to_revert:
379+
setattr(target, name, orig_func)
365380

366381
return revert_on_exit()
367382

0 commit comments

Comments
 (0)