Skip to content

Commit 95ab6a0

Browse files
Update QR tests to avoid element-wise comparisons (#2785)
This PR proposes updating QR tests to avoid direct element-wise comparisons which became unstable with oneMKL 2026.0 due to sign and phase differences in otherwise valid QR results Since QR factorization is not unique, different MKL and NumPy versions may return results that differ by sign or complex phase while still representing a correct decomposition To make the tests more stable this PR proposes using invariant-based validation for `mode="raw"` and `mode="r"` based on the unitarity of the Q factor (Q^H Q = I) and the resulting QR identity R^H @ R = A^H @ A
1 parent db486b9 commit 95ab6a0

File tree

4 files changed

+107
-110
lines changed

4 files changed

+107
-110
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ Also, that release drops support for Python 3.9, making Python 3.10 the minimum
5353
* Changed `dpnp.partition` implementation to reuse `dpnp.sort` where it brings the performance benefit [#2766](https://github.com/IntelPython/dpnp/pull/2766)
5454
* `dpnp` uses pybind11 3.0.2 [#27734](https://github.com/IntelPython/dpnp/pull/2773)
5555
* Modified CMake files for the extension to explicitly mark DPC++ compiler and dpctl headers as system ones and so to suppress the build warning generated inside them [#2770](https://github.com/IntelPython/dpnp/pull/2770)
56+
* Updated QR tests to avoid element-wise comparisons for `raw` and `r` modes [#2785](https://github.com/IntelPython/dpnp/pull/2785)
5657

5758
### Deprecated
5859

dpnp/tests/qr_helper.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
import numpy
2+
3+
from .helper import factor_to_tol, has_support_aspect64
4+
5+
6+
def gram(x, xp):
7+
# Return Gram matrix: X^H @ X
8+
return xp.conjugate(x).swapaxes(-1, -2) @ x
9+
10+
11+
def get_R_from_raw(h, m, n, xp):
12+
# Get reduced R from NumPy-style raw QR:
13+
# R = triu((tril(h))^T), shape (..., k, n)
14+
k = min(m, n)
15+
rt = xp.tril(h)
16+
r = xp.swapaxes(rt, -1, -2)
17+
r = xp.triu(r[..., :m, :n])
18+
return r[..., :k, :]
19+
20+
21+
def check_qr(a_np, a_xp, mode, xp):
22+
# QR is not unique:
23+
# element-wise comparison with NumPy may differ by sign/phase.
24+
# To verify correctness use mode-dependent functional checks:
25+
# complete/reduced: check decomposition Q @ R = A
26+
# raw/r: check invariant R^H @ R = A^H @ A
27+
if mode in ("complete", "reduced"):
28+
res = xp.linalg.qr(a_xp, mode)
29+
assert xp.allclose(res.Q @ res.R, a_xp, atol=1e-5)
30+
31+
# Since QR satisfies A = Q @ R with orthonormal Q (Q^H @ Q = I),
32+
# validate correctness via the invariant R^H @ R == A^H @ A
33+
# for raw/r modes
34+
elif mode == "raw":
35+
_, tau_np = numpy.linalg.qr(a_np, mode=mode)
36+
h_xp, tau_xp = xp.linalg.qr(a_xp, mode=mode)
37+
38+
m, n = a_np.shape[-2], a_np.shape[-1]
39+
Rraw_xp = get_R_from_raw(h_xp, m, n, xp)
40+
41+
rtol = atol = factor_to_tol(Rraw_xp.dtype, 100)
42+
43+
# Use reduced QR as a reference:
44+
# reduced is validated via Q @ R == A
45+
exp_r = xp.linalg.qr(a_xp, mode="reduced").R
46+
assert xp.allclose(Rraw_xp, exp_r, atol=atol, rtol=rtol)
47+
48+
exp_xp = gram(a_xp, xp)
49+
50+
# Compare R^H @ R == A^H @ A
51+
assert xp.allclose(gram(Rraw_xp, xp), exp_xp, atol=atol, rtol=rtol)
52+
53+
assert tau_xp.shape == tau_np.shape
54+
if not has_support_aspect64(tau_xp.sycl_device):
55+
assert tau_xp.dtype.kind == tau_np.dtype.kind
56+
else:
57+
assert tau_xp.dtype == tau_np.dtype
58+
59+
else: # mode == "r"
60+
r_xp = xp.linalg.qr(a_xp, mode="r")
61+
62+
# Use reduced QR as a reference:
63+
# reduced is validated via Q @ R == A
64+
exp_r = xp.linalg.qr(a_xp, mode="reduced").R
65+
rtol = atol = factor_to_tol(exp_r.dtype, 100)
66+
67+
assert xp.allclose(r_xp, exp_r, atol=atol, rtol=rtol)
68+
69+
exp_xp = gram(a_xp, xp)
70+
71+
# Compare R^H @ R == A^H @ A
72+
assert xp.allclose(gram(r_xp, xp), exp_xp, atol=atol, rtol=rtol)

dpnp/tests/test_linalg.py

Lines changed: 14 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
has_support_aspect64,
2525
numpy_version,
2626
)
27+
from .qr_helper import check_qr
2728
from .third_party.cupy import testing
2829

2930

@@ -3584,7 +3585,7 @@ def test_error(self):
35843585

35853586

35863587
class TestQr:
3587-
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True))
3588+
@pytest.mark.parametrize("dtype", get_float_complex_dtypes())
35883589
@pytest.mark.parametrize(
35893590
"shape",
35903591
[
@@ -3610,60 +3611,27 @@ class TestQr:
36103611
"(2, 2, 4)",
36113612
],
36123613
)
3613-
@pytest.mark.parametrize("mode", ["r", "raw", "complete", "reduced"])
3614+
@pytest.mark.parametrize("mode", ["complete", "reduced", "r", "raw"])
36143615
def test_qr(self, dtype, shape, mode):
36153616
a = generate_random_numpy_array(shape, dtype, seed_value=81)
3616-
ia = dpnp.array(a)
3617+
ia = dpnp.array(a, dtype=dtype)
36173618

3618-
if mode == "r":
3619-
np_r = numpy.linalg.qr(a, mode)
3620-
dpnp_r = dpnp.linalg.qr(ia, mode)
3621-
else:
3622-
np_q, np_r = numpy.linalg.qr(a, mode)
3623-
3624-
# check decomposition
3625-
if mode in ("complete", "reduced"):
3626-
result = dpnp.linalg.qr(ia, mode)
3627-
dpnp_q, dpnp_r = result.Q, result.R
3628-
assert dpnp.allclose(
3629-
dpnp.matmul(dpnp_q, dpnp_r), ia, atol=1e-05
3630-
)
3631-
else: # mode=="raw"
3632-
dpnp_q, dpnp_r = dpnp.linalg.qr(ia, mode)
3633-
assert_dtype_allclose(dpnp_q, np_q, factor=24)
3634-
3635-
if mode in ("raw", "r"):
3636-
assert_dtype_allclose(dpnp_r, np_r, factor=24)
3619+
check_qr(a, ia, mode, dpnp)
36373620

3638-
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True))
3621+
@pytest.mark.parametrize("dtype", get_float_complex_dtypes())
36393622
@pytest.mark.parametrize(
36403623
"shape",
36413624
[(32, 32), (8, 16, 16)],
36423625
ids=["(32, 32)", "(8, 16, 16)"],
36433626
)
3644-
@pytest.mark.parametrize("mode", ["r", "raw", "complete", "reduced"])
3627+
@pytest.mark.parametrize("mode", ["complete", "reduced", "r", "raw"])
36453628
def test_qr_large(self, dtype, shape, mode):
36463629
a = generate_random_numpy_array(shape, dtype, seed_value=81)
36473630
ia = dpnp.array(a)
36483631

3649-
if mode == "r":
3650-
np_r = numpy.linalg.qr(a, mode)
3651-
dpnp_r = dpnp.linalg.qr(ia, mode)
3652-
else:
3653-
np_q, np_r = numpy.linalg.qr(a, mode)
3654-
3655-
# check decomposition
3656-
if mode in ("complete", "reduced"):
3657-
result = dpnp.linalg.qr(ia, mode)
3658-
dpnp_q, dpnp_r = result.Q, result.R
3659-
assert dpnp.allclose(dpnp.matmul(dpnp_q, dpnp_r), ia, atol=1e-5)
3660-
else: # mode=="raw"
3661-
dpnp_q, dpnp_r = dpnp.linalg.qr(ia, mode)
3662-
assert_allclose(dpnp_q, np_q, atol=1e-4)
3663-
if mode in ("raw", "r"):
3664-
assert_allclose(dpnp_r, np_r, atol=1e-4)
3632+
check_qr(a, ia, mode, dpnp)
36653633

3666-
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True))
3634+
@pytest.mark.parametrize("dtype", get_float_complex_dtypes())
36673635
@pytest.mark.parametrize(
36683636
"shape",
36693637
[(0, 0), (0, 2), (2, 0), (2, 0, 3), (2, 3, 0), (0, 2, 3)],
@@ -3676,65 +3644,22 @@ def test_qr_large(self, dtype, shape, mode):
36763644
"(0, 2, 3)",
36773645
],
36783646
)
3679-
@pytest.mark.parametrize("mode", ["r", "raw", "complete", "reduced"])
3647+
@pytest.mark.parametrize("mode", ["complete", "reduced", "r", "raw"])
36803648
def test_qr_empty(self, dtype, shape, mode):
36813649
a = numpy.empty(shape, dtype=dtype)
36823650
ia = dpnp.array(a)
36833651

3684-
if mode == "r":
3685-
np_r = numpy.linalg.qr(a, mode)
3686-
dpnp_r = dpnp.linalg.qr(ia, mode)
3687-
else:
3688-
np_q, np_r = numpy.linalg.qr(a, mode)
3689-
3690-
if mode in ("complete", "reduced"):
3691-
result = dpnp.linalg.qr(ia, mode)
3692-
dpnp_q, dpnp_r = result.Q, result.R
3693-
else:
3694-
dpnp_q, dpnp_r = dpnp.linalg.qr(ia, mode)
3652+
check_qr(a, ia, mode, dpnp)
36953653

3696-
assert_dtype_allclose(dpnp_q, np_q)
3697-
3698-
assert_dtype_allclose(dpnp_r, np_r)
3699-
3700-
@pytest.mark.parametrize("mode", ["r", "raw", "complete", "reduced"])
3654+
@pytest.mark.parametrize("mode", ["complete", "reduced", "r", "raw"])
37013655
def test_qr_strides(self, mode):
37023656
a = generate_random_numpy_array((5, 5))
37033657
ia = dpnp.array(a)
37043658

37053659
# positive strides
3706-
if mode == "r":
3707-
np_r = numpy.linalg.qr(a[::2, ::2], mode)
3708-
dpnp_r = dpnp.linalg.qr(ia[::2, ::2], mode)
3709-
else:
3710-
np_q, np_r = numpy.linalg.qr(a[::2, ::2], mode)
3711-
3712-
if mode in ("complete", "reduced"):
3713-
result = dpnp.linalg.qr(ia[::2, ::2], mode)
3714-
dpnp_q, dpnp_r = result.Q, result.R
3715-
else:
3716-
dpnp_q, dpnp_r = dpnp.linalg.qr(ia[::2, ::2], mode)
3717-
3718-
assert_dtype_allclose(dpnp_q, np_q)
3719-
3720-
assert_dtype_allclose(dpnp_r, np_r)
3721-
3660+
check_qr(a[::2, ::2], ia[::2, ::2], mode, dpnp)
37223661
# negative strides
3723-
if mode == "r":
3724-
np_r = numpy.linalg.qr(a[::-2, ::-2], mode)
3725-
dpnp_r = dpnp.linalg.qr(ia[::-2, ::-2], mode)
3726-
else:
3727-
np_q, np_r = numpy.linalg.qr(a[::-2, ::-2], mode)
3728-
3729-
if mode in ("complete", "reduced"):
3730-
result = dpnp.linalg.qr(ia[::-2, ::-2], mode)
3731-
dpnp_q, dpnp_r = result.Q, result.R
3732-
else:
3733-
dpnp_q, dpnp_r = dpnp.linalg.qr(ia[::-2, ::-2], mode)
3734-
3735-
assert_dtype_allclose(dpnp_q, np_q)
3736-
3737-
assert_dtype_allclose(dpnp_r, np_r)
3662+
check_qr(a[::-2, ::-2], ia[::-2, ::-2], mode, dpnp)
37383663

37393664
def test_qr_errors(self):
37403665
a_dp = dpnp.array([[1, 2], [3, 5]], dtype="float32")

dpnp/tests/third_party/cupy/linalg_tests/test_decomposition.py

Lines changed: 20 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,9 @@
1212
# from cupy.cuda import runtime
1313
# from cupy.linalg import _util
1414
from dpnp.tests.helper import (
15-
LTS_VERSION,
1615
has_support_aspect64,
17-
is_lts_driver,
1816
)
17+
from dpnp.tests.qr_helper import check_qr
1918
from dpnp.tests.third_party.cupy import testing
2019
from dpnp.tests.third_party.cupy.testing import _condition
2120

@@ -169,7 +168,6 @@ def test_decomposition(self, dtype):
169168
)
170169
)
171170
class TestQRDecomposition(unittest.TestCase):
172-
173171
@testing.for_dtypes("fdFD")
174172
def check_mode(self, array, mode, dtype):
175173
# if runtime.is_hip and driver.get_build_version() < 307:
@@ -178,22 +176,29 @@ def check_mode(self, array, mode, dtype):
178176

179177
a_cpu = numpy.asarray(array, dtype=dtype)
180178
a_gpu = cupy.asarray(array, dtype=dtype)
181-
result_gpu = cupy.linalg.qr(a_gpu, mode=mode)
179+
# QR is not unique:
180+
# element-wise comparison with NumPy may differ by sign/phase.
181+
# To verify correctness use mode-dependent functional checks:
182+
# complete/reduced: check decomposition Q @ R = A
183+
# raw/r: check invariant R^H @ R = A^H @ A
184+
185+
# result_gpu = cupy.linalg.qr(a_gpu, mode=mode)
182186
if (
183187
mode != "raw"
184188
or numpy.lib.NumpyVersion(numpy.__version__) >= "1.22.0rc1"
185189
):
186-
result_cpu = numpy.linalg.qr(a_cpu, mode=mode)
187-
self._check_result(result_cpu, result_gpu)
188-
189-
def _check_result(self, result_cpu, result_gpu):
190-
if isinstance(result_cpu, tuple):
191-
for b_cpu, b_gpu in zip(result_cpu, result_gpu):
192-
assert b_cpu.dtype == b_gpu.dtype
193-
testing.assert_allclose(b_cpu, b_gpu, atol=1e-4)
194-
else:
195-
assert result_cpu.dtype == result_gpu.dtype
196-
testing.assert_allclose(result_cpu, result_gpu, atol=1e-4)
190+
# result_cpu = numpy.linalg.qr(a_cpu, mode=mode)
191+
# self._check_result(result_cpu, result_gpu, a_gpu, mode)
192+
check_qr(a_cpu, a_gpu, mode, cupy)
193+
194+
# def _check_result(self, result_cpu, result_gpu):
195+
# if isinstance(result_cpu, tuple):
196+
# for b_cpu, b_gpu in zip(result_cpu, result_gpu):
197+
# assert b_cpu.dtype == b_gpu.dtype
198+
# testing.assert_allclose(b_cpu, b_gpu, atol=1e-4)
199+
# else:
200+
# assert result_cpu.dtype == result_gpu.dtype
201+
# testing.assert_allclose(result_cpu, result_gpu, atol=1e-4)
197202

198203
@testing.fix_random()
199204
@_condition.repeat(3, 10)
@@ -202,19 +207,13 @@ def test_mode(self):
202207
self.check_mode(numpy.random.randn(3, 3), mode=self.mode)
203208
self.check_mode(numpy.random.randn(5, 4), mode=self.mode)
204209

205-
@pytest.mark.skipif(
206-
is_lts_driver(version=LTS_VERSION.V1_6), reason="SAT-8375"
207-
)
208210
@testing.with_requires("numpy>=1.22")
209211
@testing.fix_random()
210212
def test_mode_rank3(self):
211213
self.check_mode(numpy.random.randn(3, 2, 4), mode=self.mode)
212214
self.check_mode(numpy.random.randn(4, 3, 3), mode=self.mode)
213215
self.check_mode(numpy.random.randn(2, 5, 4), mode=self.mode)
214216

215-
@pytest.mark.skipif(
216-
is_lts_driver(version=LTS_VERSION.V1_6), reason="SAT-8375"
217-
)
218217
@testing.with_requires("numpy>=1.22")
219218
@testing.fix_random()
220219
def test_mode_rank4(self):

0 commit comments

Comments
 (0)