Skip to content

Commit 49bfbdc

Browse files
committed
TST: make test laxy xp classes actually test function is wrapped
1 parent ffeb1f2 commit 49bfbdc

1 file changed

Lines changed: 15 additions & 6 deletions

File tree

tests/test_testing.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -334,6 +334,9 @@ def f(self, y: Array) -> Array:
334334
def g(self, y: Array, z: Array) -> Array:
335335
return self.f(y) + self.f(z)
336336

337+
def w(self, y: Array) -> bool:
338+
return bool(self._xp.any(y))
339+
337340

338341
class B(A):
339342
@override
@@ -348,21 +351,27 @@ def f(self, y: Array) -> Array:
348351

349352

350353
lazy_xp_function((B, "g"))
354+
lazy_xp_function((B, "w"))
351355

352356

353357
class TestLazyXpFunctionClasses:
354358
def test_parent_method_not_tagged(self):
355359
assert hasattr(B.g, "_lazy_xp_function")
356360
assert not hasattr(A.g, "_lazy_xp_function")
357361

358-
def test_lazy_xp_function_classes(self, xp):
362+
def test_lazy_xp_function_classes(self, xp: ModuleType, library: Backend):
359363
x = xp.asarray([1.1, 2.2, 3.3])
360364
y = xp.asarray([1.0, 2.0, 3.0])
361-
z = xp.asarray([3.0, 4.0, 5.0])
362-
foo = B(x)
363-
observed = foo.g(y, z)
364-
expected = xp.asarray(44.0)[()]
365-
xp_assert_close(observed, expected)
365+
foo = A(x)
366+
bar = B(x)
367+
368+
if library.like(Backend.JAX):
369+
with pytest.raises(
370+
TypeError, match="Attempted boolean conversion of traced array"
371+
):
372+
assert bar.w(y)
373+
374+
assert foo.w(y)
366375

367376

368377
def dask_raises(x: Array) -> Array:

0 commit comments

Comments
 (0)