@@ -334,7 +334,7 @@ def f(self, y: Array) -> Array:
334334 def g (self , y : Array , z : Array ) -> Array :
335335 return self .f (y ) + self .f (z )
336336
337- def w (self , y : Array ) -> bool :
337+ def h (self , y : Array ) -> bool :
338338 return bool (self ._xp .any (y ))
339339
340340
@@ -361,6 +361,14 @@ def j(y: Array) -> "B":
361361 return B (y )
362362 return B (y + 1.0 )
363363
364+ @classmethod
365+ def w (cls , y : Array ) -> "B" :
366+ xp = array_namespace (y )
367+ y = xp .asarray (y )
368+ if bool (xp .any (y )):
369+ return B (y )
370+ return B (y + 1.0 )
371+
364372
365373@final
366374class eager :
@@ -369,9 +377,10 @@ class eager:
369377
370378
371379lazy_xp_function ((B , "g" ))
372- lazy_xp_function ((B , "w " ))
380+ lazy_xp_function ((B , "h " ))
373381lazy_xp_function ((B , "k" ))
374382lazy_xp_function ((B , "j" ))
383+ lazy_xp_function ((B , "w" ))
375384
376385
377386class TestLazyXpFunctionClasses :
@@ -393,9 +402,9 @@ def test_lazy_xp_function_classes(self, xp: ModuleType, library: Backend):
393402 with pytest .raises (
394403 TypeError , match = "Attempted boolean conversion of traced array"
395404 ):
396- assert bar .w (y )
405+ assert bar .h (y )
397406
398- assert foo .w (y )
407+ assert foo .h (y )
399408
400409 def test_static_methods_preserved (self , xp : ModuleType ):
401410 # Tests that static methods stay static methods when
@@ -418,6 +427,17 @@ def test_static_methods_wrapped(self, xp: ModuleType, library: Backend):
418427 else :
419428 assert isinstance (foo .j (x ), B )
420429
430+ @pytest .mark .skip_xp_backend (Backend .DASK , reason = "calls dask.compute()" )
431+ def test_class_methods_wrapped (self , xp : ModuleType , library : Backend ):
432+ x = xp .asarray ([1.1 , 2.2 , 3.3 ])
433+ if library .like (Backend .JAX ):
434+ with pytest .raises (
435+ TypeError , match = "Attempted boolean conversion of traced array"
436+ ):
437+ assert isinstance (B .w (x ), B )
438+ else :
439+ assert isinstance (B .w (x ), B )
440+
421441 def test_circumvention (self , xp : ModuleType ):
422442 x = xp .asarray ([1.0 , 2.0 ])
423443 y = eager .non_materializable5 (x )
0 commit comments