Skip to content

Commit 69c51d8

Browse files
committed
Fix the test_linalg.py, add a test for scalar
1 parent 5de153d commit 69c51d8

File tree

1 file changed

+16
-7
lines changed

1 file changed

+16
-7
lines changed

dpnp/tests/test_linalg.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2621,7 +2621,9 @@ def _make_nonsingular_np(shape, dtype, order):
26212621
[(1, 1), (2, 2), (3, 3), (1, 5), (5, 1), (2, 5), (5, 2)],
26222622
)
26232623
@pytest.mark.parametrize("order", ["C", "F"])
2624-
@pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True, no_bool=True))
2624+
@pytest.mark.parametrize(
2625+
"dtype", get_all_dtypes(no_none=True, no_bool=True)
2626+
)
26252627
def test_lu_default(self, shape, order, dtype):
26262628
a_np = self._make_nonsingular_np(shape, dtype, order)
26272629
a_dp = dpnp.array(a_np, order=order)
@@ -2643,7 +2645,9 @@ def test_lu_default(self, shape, order, dtype):
26432645
[(1, 1), (2, 2), (3, 3), (1, 5), (5, 1), (2, 5), (5, 2)],
26442646
)
26452647
@pytest.mark.parametrize("order", ["C", "F"])
2646-
@pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True, no_bool=True))
2648+
@pytest.mark.parametrize(
2649+
"dtype", get_all_dtypes(no_none=True, no_bool=True)
2650+
)
26472651
def test_lu_permute_l(self, shape, order, dtype):
26482652
a_np = self._make_nonsingular_np(shape, dtype, order)
26492653
a_dp = dpnp.array(a_np, order=order)
@@ -2664,7 +2668,9 @@ def test_lu_permute_l(self, shape, order, dtype):
26642668
[(1, 1), (2, 2), (3, 3), (1, 5), (5, 1), (2, 5), (5, 2)],
26652669
)
26662670
@pytest.mark.parametrize("order", ["C", "F"])
2667-
@pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True, no_bool=True))
2671+
@pytest.mark.parametrize(
2672+
"dtype", get_all_dtypes(no_none=True, no_bool=True)
2673+
)
26682674
def test_lu_p_indices(self, shape, order, dtype):
26692675
a_np = self._make_nonsingular_np(shape, dtype, order)
26702676
a_dp = dpnp.array(a_np, order=order)
@@ -2678,10 +2684,7 @@ def test_lu_p_indices(self, shape, order, dtype):
26782684
assert U.shape == (k, n)
26792685
assert dpnp.issubdtype(p.dtype, dpnp.integer)
26802686

2681-
p_np = dpnp.asnumpy(p)
2682-
L_np = dpnp.asnumpy(L)
2683-
U_np = dpnp.asnumpy(U)
2684-
A_rec = L_np[p_np] @ U_np
2687+
A_rec = L[p] @ U
26852688
A_cast = a_dp.astype(L.dtype, copy=False)
26862689
assert dpnp.allclose(A_rec, A_cast, rtol=1e-6, atol=1e-6)
26872690

@@ -2852,6 +2855,12 @@ def test_check_finite_raises(self, bad):
28522855
a_dp = dpnp.array([[1.0, 2.0], [3.0, bad]], order="F")
28532856
assert_raises(ValueError, dpnp.scipy.linalg.lu, a_dp, check_finite=True)
28542857

2858+
@pytest.mark.parametrize("bad", [numpy.inf, -numpy.inf, numpy.nan])
2859+
def test_check_finite_raises_scalar(self, bad):
2860+
# Covers the 1x1 scalar fast path in dpnp_lu
2861+
a_dp = dpnp.array([[bad]])
2862+
assert_raises(ValueError, dpnp.scipy.linalg.lu, a_dp, check_finite=True)
2863+
28552864
def test_check_finite_disabled(self):
28562865
a_dp = dpnp.array([[1.0, numpy.nan], [3.0, 4.0]])
28572866
result = dpnp.scipy.linalg.lu(a_dp, check_finite=False)

0 commit comments

Comments
 (0)