Skip to content

Commit a739cfe

Browse files
committed
TST: test that classmethods get wrapped
1 parent 0c12809 commit a739cfe

1 file changed

Lines changed: 24 additions & 4 deletions

File tree

tests/test_testing.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -334,7 +334,7 @@ def f(self, y: Array) -> Array:
334334
def g(self, y: Array, z: Array) -> Array:
335335
return self.f(y) + self.f(z)
336336

337-
def w(self, y: Array) -> bool:
337+
def h(self, y: Array) -> bool:
338338
return bool(self._xp.any(y))
339339

340340

@@ -361,6 +361,14 @@ def j(y: Array) -> "B":
361361
return B(y)
362362
return B(y + 1.0)
363363

364+
@classmethod
365+
def w(cls, y: Array) -> "B":
366+
xp = array_namespace(y)
367+
y = xp.asarray(y)
368+
if bool(xp.any(y)):
369+
return B(y)
370+
return B(y + 1.0)
371+
364372

365373
@final
366374
class eager:
@@ -369,9 +377,10 @@ class eager:
369377

370378

371379
lazy_xp_function((B, "g"))
372-
lazy_xp_function((B, "w"))
380+
lazy_xp_function((B, "h"))
373381
lazy_xp_function((B, "k"))
374382
lazy_xp_function((B, "j"))
383+
lazy_xp_function((B, "w"))
375384

376385

377386
class TestLazyXpFunctionClasses:
@@ -393,9 +402,9 @@ def test_lazy_xp_function_classes(self, xp: ModuleType, library: Backend):
393402
with pytest.raises(
394403
TypeError, match="Attempted boolean conversion of traced array"
395404
):
396-
assert bar.w(y)
405+
assert bar.h(y)
397406

398-
assert foo.w(y)
407+
assert foo.h(y)
399408

400409
def test_static_methods_preserved(self, xp: ModuleType):
401410
# Tests that static methods stay static methods when
@@ -418,6 +427,17 @@ def test_static_methods_wrapped(self, xp: ModuleType, library: Backend):
418427
else:
419428
assert isinstance(foo.j(x), B)
420429

430+
@pytest.mark.skip_xp_backend(Backend.DASK, reason="calls dask.compute()")
431+
def test_class_methods_wrapped(self, xp: ModuleType, library: Backend):
432+
x = xp.asarray([1.1, 2.2, 3.3])
433+
if library.like(Backend.JAX):
434+
with pytest.raises(
435+
TypeError, match="Attempted boolean conversion of traced array"
436+
):
437+
assert isinstance(B.w(x), B)
438+
else:
439+
assert isinstance(B.w(x), B)
440+
421441
def test_circumvention(self, xp: ModuleType):
422442
x = xp.asarray([1.0, 2.0])
423443
y = eager.non_materializable5(x)

0 commit comments

Comments
 (0)