Skip to content

Commit cfcf7d6

Browse files
committed
TST: fix two CUDA test failures
1 parent cf4d7d9 commit cfcf7d6

1 file changed

Lines changed: 3 additions & 2 deletions

File tree

tests/test_funcs.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -521,6 +521,7 @@ def test_complex(self, xp: ModuleType):
521521
expect = xp.asarray([[1.0, -1.0j], [1.0j, 1.0]], dtype=xp.complex128)
522522
xp_assert_close(actual, expect)
523523

524+
@pytest.mark.xfail_xp_backend(Backend.JAX_GPU, reason="jax#32296")
524525
@pytest.mark.xfail_xp_backend(Backend.JAX, reason="jax#32296")
525526
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="sparse#877")
526527
def test_empty(self, xp: ModuleType):
@@ -989,14 +990,14 @@ def test_device(self, xp: ModuleType, device: Device, equal_nan: bool):
989990
assert get_device(res) == device
990991

991992
def test_array_on_device_with_scalar(self, xp: ModuleType, device: Device):
992-
a = xp.asarray([0.01, 0.5, 0.8, 0.9, 1.00001], device=device)
993+
a = xp.asarray([0.01, 0.5, 0.8, 0.9, 1.00001], device=device, dtype=xp.float64)
993994
b = 1
994995
res = isclose(a, b)
995996
assert get_device(res) == device
996997
xp_assert_equal(res, xp.asarray([False, False, False, False, True]))
997998

998999
a = 0.1
999-
b = xp.asarray([0.01, 0.5, 0.8, 0.9, 0.100001], device=device)
1000+
b = xp.asarray([0.01, 0.5, 0.8, 0.9, 0.100001], device=device, dtype=xp.float64)
10001001
res = isclose(a, b)
10011002
assert get_device(res) == device
10021003
xp_assert_equal(res, xp.asarray([False, False, False, False, True]))

0 commit comments

Comments
 (0)