Skip to content

Commit dc81277

Browse files
Apply remarks
1 parent e9973e6 commit dc81277

File tree

2 files changed

+12
-10
lines changed

2 files changed

+12
-10
lines changed

dpnp/tests/qr_helper.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import numpy
22

3-
from .helper import has_support_aspect64
3+
from .helper import factor_to_tol, has_support_aspect64
44

55

66
def gram(x, xp):
@@ -38,16 +38,17 @@ def check_qr(a_np, a_xp, mode, xp):
3838
m, n = a_np.shape[-2], a_np.shape[-1]
3939
Rraw_xp = get_R_from_raw(h_xp, m, n, xp)
4040

41+
rtol = atol = factor_to_tol(Rraw_xp.dtype, 100)
42+
4143
# Use reduced QR as a reference:
4244
# reduced is validated via Q @ R == A
43-
exp_res = xp.linalg.qr(a_xp, mode="reduced")
44-
exp_r = exp_res.R
45-
assert xp.allclose(Rraw_xp, exp_r, atol=1e-4, rtol=1e-4)
45+
exp_r = xp.linalg.qr(a_xp, mode="reduced").R
46+
assert xp.allclose(Rraw_xp, exp_r, atol=atol, rtol=rtol)
4647

4748
exp_xp = gram(a_xp, xp)
4849

4950
# Compare R^H @ R == A^H @ A
50-
assert xp.allclose(gram(Rraw_xp, xp), exp_xp, atol=1e-4, rtol=1e-4)
51+
assert xp.allclose(gram(Rraw_xp, xp), exp_xp, atol=atol, rtol=rtol)
5152

5253
assert tau_xp.shape == tau_np.shape
5354
if not has_support_aspect64(tau_xp.sycl_device):
@@ -60,11 +61,12 @@ def check_qr(a_np, a_xp, mode, xp):
6061

6162
# Use reduced QR as a reference:
6263
# reduced is validated via Q @ R == A
63-
exp_res = xp.linalg.qr(a_xp, mode="reduced")
64-
exp_r = exp_res.R
65-
assert xp.allclose(r_xp, exp_r, atol=1e-4, rtol=1e-4)
64+
exp_r = xp.linalg.qr(a_xp, mode="reduced").R
65+
rtol = atol = factor_to_tol(exp_r.dtype, 100)
66+
67+
assert xp.allclose(r_xp, exp_r, atol=atol, rtol=rtol)
6668

6769
exp_xp = gram(a_xp, xp)
6870

6971
# Compare R^H @ R == A^H @ A
70-
assert xp.allclose(gram(r_xp, xp), exp_xp, atol=1e-4, rtol=1e-4)
72+
assert xp.allclose(gram(r_xp, xp), exp_xp, atol=atol, rtol=rtol)

dpnp/tests/test_linalg.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3613,7 +3613,7 @@ class TestQr:
36133613
)
36143614
@pytest.mark.parametrize("mode", ["complete", "reduced", "r", "raw"])
36153615
def test_qr(self, dtype, shape, mode):
3616-
a = generate_random_numpy_array(shape, dtype, seed_value=None)
3616+
a = generate_random_numpy_array(shape, dtype, seed_value=81)
36173617
ia = dpnp.array(a, dtype=dtype)
36183618

36193619
check_qr(a, ia, mode, dpnp)

0 commit comments

Comments
 (0)