11import numpy
22
3- from .helper import has_support_aspect64
3+ from .helper import factor_to_tol , has_support_aspect64
44
55
66def gram (x , xp ):
@@ -38,16 +38,17 @@ def check_qr(a_np, a_xp, mode, xp):
3838 m , n = a_np .shape [- 2 ], a_np .shape [- 1 ]
3939 Rraw_xp = get_R_from_raw (h_xp , m , n , xp )
4040
41+ rtol = atol = factor_to_tol (Rraw_xp .dtype , 100 )
42+
4143 # Use reduced QR as a reference:
4244 # reduced is validated via Q @ R == A
43- exp_res = xp .linalg .qr (a_xp , mode = "reduced" )
44- exp_r = exp_res .R
45- assert xp .allclose (Rraw_xp , exp_r , atol = 1e-4 , rtol = 1e-4 )
45+ exp_r = xp .linalg .qr (a_xp , mode = "reduced" ).R
46+ assert xp .allclose (Rraw_xp , exp_r , atol = atol , rtol = rtol )
4647
4748 exp_xp = gram (a_xp , xp )
4849
4950 # Compare R^H @ R == A^H @ A
50- assert xp .allclose (gram (Rraw_xp , xp ), exp_xp , atol = 1e-4 , rtol = 1e-4 )
51+ assert xp .allclose (gram (Rraw_xp , xp ), exp_xp , atol = atol , rtol = rtol )
5152
5253 assert tau_xp .shape == tau_np .shape
5354 if not has_support_aspect64 (tau_xp .sycl_device ):
@@ -60,11 +61,12 @@ def check_qr(a_np, a_xp, mode, xp):
6061
6162 # Use reduced QR as a reference:
6263 # reduced is validated via Q @ R == A
63- exp_res = xp .linalg .qr (a_xp , mode = "reduced" )
64- exp_r = exp_res .R
65- assert xp .allclose (r_xp , exp_r , atol = 1e-4 , rtol = 1e-4 )
64+ exp_r = xp .linalg .qr (a_xp , mode = "reduced" ).R
65+ rtol = atol = factor_to_tol (exp_r .dtype , 100 )
66+
67+ assert xp .allclose (r_xp , exp_r , atol = atol , rtol = rtol )
6668
6769 exp_xp = gram (a_xp , xp )
6870
6971 # Compare R^H @ R == A^H @ A
70- assert xp .allclose (gram (r_xp , xp ), exp_xp , atol = 1e-4 , rtol = 1e-4 )
72+ assert xp .allclose (gram (r_xp , xp ), exp_xp , atol = atol , rtol = rtol )
0 commit comments