@@ -334,6 +334,9 @@ def f(self, y: Array) -> Array:
334334 def g (self , y : Array , z : Array ) -> Array :
335335 return self .f (y ) + self .f (z )
336336
337+ def w (self , y : Array ) -> bool :
338+ return bool (self ._xp .any (y ))
339+
337340
338341class B (A ):
339342 @override
@@ -348,21 +351,27 @@ def f(self, y: Array) -> Array:
348351
349352
350353lazy_xp_function ((B , "g" ))
354+ lazy_xp_function ((B , "w" ))
351355
352356
353357class TestLazyXpFunctionClasses :
354358 def test_parent_method_not_tagged (self ):
355359 assert hasattr (B .g , "_lazy_xp_function" )
356360 assert not hasattr (A .g , "_lazy_xp_function" )
357361
358- def test_lazy_xp_function_classes (self , xp ):
362+ def test_lazy_xp_function_classes (self , xp : ModuleType , library : Backend ):
359363 x = xp .asarray ([1.1 , 2.2 , 3.3 ])
360364 y = xp .asarray ([1.0 , 2.0 , 3.0 ])
361- z = xp .asarray ([3.0 , 4.0 , 5.0 ])
362- foo = B (x )
363- observed = foo .g (y , z )
364- expected = xp .asarray (44.0 )[()]
365- xp_assert_close (observed , expected )
365+ foo = A (x )
366+ bar = B (x )
367+
368+ if library .like (Backend .JAX ):
369+ with pytest .raises (
370+ TypeError , match = "Attempted boolean conversion of traced array"
371+ ):
372+ assert bar .w (y )
373+
374+ assert foo .w (y )
366375
367376
368377def dask_raises (x : Array ) -> Array :
0 commit comments