@@ -249,10 +249,11 @@ def patch_lazy_xp_functions(
249249 """
250250 Test lazy execution of functions tagged with :func:`lazy_xp_function`.
251251
252- If ``xp==jax.numpy``, search for all functions which have been tagged with
253- :func:`lazy_xp_function` in the globals of the module that defines the current test,
254- as well as in the ``lazy_xp_modules`` list in the globals of the same module,
255- and wrap them with :func:`jax.jit`. Unwrap them at the end of the test.
252+ If ``xp==jax.numpy``, search for all functions and classes which have been tagged
253+ with :func:`lazy_xp_function` in the globals of the module that defines the current
254+ test, as well as in the ``lazy_xp_modules`` list in the globals of the same module,
255+ and wrap them with :func:`jax.jit`.
256+ Unwrap them at the end of the test.
256257
257258 If ``xp==dask.array``, wrap the functions with a decorator that disables
258259 ``compute()`` and ``persist()`` and ensures that exceptions and warnings are raised
@@ -296,18 +297,32 @@ def xp(request):
296297 the example above.
297298 """
298299 mod = cast (ModuleType , request .module )
299- mods = [mod , * cast (list [ModuleType ], getattr (mod , "lazy_xp_modules" , []))]
300-
301- to_revert : list [tuple [ModuleType , str , object ]] = []
302-
303- def temp_setattr (mod : ModuleType , name : str , func : object ) -> None :
300+ search_targets : list [ModuleType | type ] = [
301+ mod ,
302+ * cast (list [ModuleType ], getattr (mod , "lazy_xp_modules" , [])),
303+ ]
304+ # Also search for classes within the above modules which have had lazy_xp_function
305+ # applied to methods through ``lazy_xp_function((cls, method_name))`` syntax.
306+ # We might end up adding classes incidentally imported into modules, so using a
307+ # set here to cut down on potential redundancy.
308+ classes : set [type ] = set ()
309+ for target in search_targets :
310+ for obj_name in dir (target ):
311+ obj = getattr (target , obj_name )
312+ if isinstance (obj , type ) and isinstance (obj , Exception ):
313+ classes .add (obj )
314+ search_targets .extend (classes )
315+
316+ to_revert : list [tuple [ModuleType | type , str , object ]] = []
317+
318+ def temp_setattr (target : ModuleType | type , name : str , func : object ) -> None :
304319 """
305320 Variant of monkeypatch.setattr, which allows monkey-patching only selected
306321 parameters of a test so that pytest-run-parallel can run on the remainder.
307322 """
308- assert hasattr (mod , name )
309- to_revert .append ((mod , name , getattr (mod , name )))
310- setattr (mod , name , func )
323+ assert hasattr (target , name )
324+ to_revert .append ((target , name , getattr (target , name )))
325+ setattr (target , name , func )
311326
312327 if monkeypatch is not None :
313328 warnings .warn (
@@ -323,34 +338,34 @@ def temp_setattr(mod: ModuleType, name: str, func: object) -> None:
323338 temp_setattr = monkeypatch .setattr # type: ignore[assignment] # pyright: ignore[reportAssignmentType]
324339
325340 def iter_tagged () -> Iterator [
326- tuple [ModuleType , str , Callable [..., Any ], dict [str , Any ]]
341+ tuple [ModuleType | type , str , Callable [..., Any ], dict [str , Any ]]
327342 ]:
328- for mod in mods :
329- for name , func in mod .__dict__ .items ():
343+ for target in search_targets :
344+ for name , func in target .__dict__ .items ():
330345 tags : dict [str , Any ] | None = None
331346 with contextlib .suppress (AttributeError ):
332347 tags = func ._lazy_xp_function # pylint: disable=protected-access
333348 if tags is None :
334349 with contextlib .suppress (KeyError , TypeError ):
335350 tags = _ufuncs_tags [func ]
336351 if tags is not None :
337- yield mod , name , func , tags
352+ yield target , name , func , tags
338353
339354 if is_dask_namespace (xp ):
340- for mod , name , func , tags in iter_tagged ():
355+ for target , name , func , tags in iter_tagged ():
341356 n = tags ["allow_dask_compute" ]
342357 if n is True :
343358 n = 1_000_000
344359 elif n is False :
345360 n = 0
346361 wrapped = _dask_wrap (func , n )
347- temp_setattr (mod , name , wrapped )
362+ temp_setattr (target , name , wrapped )
348363
349364 elif is_jax_namespace (xp ):
350- for mod , name , func , tags in iter_tagged ():
365+ for target , name , func , tags in iter_tagged ():
351366 if tags ["jax_jit" ]:
352367 wrapped = jax_autojit (func )
353- temp_setattr (mod , name , wrapped )
368+ temp_setattr (target , name , wrapped )
354369
355370 # We can't just decorate patch_lazy_xp_functions with
356371 # @contextlib.contextmanager because it would not work with the
@@ -360,8 +375,8 @@ def revert_on_exit() -> Generator[None]:
360375 try :
361376 yield
362377 finally :
363- for mod , name , orig_func in to_revert :
364- setattr (mod , name , orig_func )
378+ for target , name , orig_func in to_revert :
379+ setattr (target , name , orig_func )
365380
366381 return revert_on_exit ()
367382
0 commit comments