@@ -521,6 +521,7 @@ def test_complex(self, xp: ModuleType):
521521 expect = xp .asarray ([[1.0 , - 1.0j ], [1.0j , 1.0 ]], dtype = xp .complex128 )
522522 xp_assert_close (actual , expect )
523523
524+ @pytest .mark .xfail_xp_backend (Backend .JAX_GPU , reason = "jax#32296" )
524525 @pytest .mark .xfail_xp_backend (Backend .JAX , reason = "jax#32296" )
525526 @pytest .mark .xfail_xp_backend (Backend .SPARSE , reason = "sparse#877" )
526527 def test_empty (self , xp : ModuleType ):
@@ -989,14 +990,14 @@ def test_device(self, xp: ModuleType, device: Device, equal_nan: bool):
989990 assert get_device (res ) == device
990991
991992 def test_array_on_device_with_scalar (self , xp : ModuleType , device : Device ):
992- a = xp .asarray ([0.01 , 0.5 , 0.8 , 0.9 , 1.00001 ], device = device )
993+ a = xp .asarray ([0.01 , 0.5 , 0.8 , 0.9 , 1.00001 ], device = device , dtype = xp . float64 )
993994 b = 1
994995 res = isclose (a , b )
995996 assert get_device (res ) == device
996997 xp_assert_equal (res , xp .asarray ([False , False , False , False , True ]))
997998
998999 a = 0.1
999- b = xp .asarray ([0.01 , 0.5 , 0.8 , 0.9 , 0.100001 ], device = device )
1000+ b = xp .asarray ([0.01 , 0.5 , 0.8 , 0.9 , 0.100001 ], device = device , dtype = xp . float64 )
10001001 res = isclose (a , b )
10011002 assert get_device (res ) == device
10021003 xp_assert_equal (res , xp .asarray ([False , False , False , False , True ]))
0 commit comments