Skip to content

Commit 7a6b5bd

Browse files
committed
Fix black formatting
1 parent 206f4ee commit 7a6b5bd

File tree

5 files changed

+24
-19
lines changed

5 files changed

+24
-19
lines changed

dpnp/scipy/linalg/_decomp_lu.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,9 @@
5252
)
5353

5454

55-
def lu(a, permute_l=False, overwrite_a=False, check_finite=True,
56-
p_indices=False):
55+
def lu(
56+
a, permute_l=False, overwrite_a=False, check_finite=True, p_indices=False
57+
):
5758
"""
5859
Compute LU decomposition of a matrix with partial pivoting.
5960

dpnp/scipy/linalg/_utils.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,9 @@ def _get_real_dtype(res_type):
137137
"""
138138

139139
if dpnp.issubdtype(res_type, dpnp.complexfloating):
140-
return dpnp.float32 if dpnp.dtype(res_type).itemsize <= 8 else dpnp.float64
140+
return (
141+
dpnp.float32 if dpnp.dtype(res_type).itemsize <= 8 else dpnp.float64
142+
)
141143
return res_type
142144

143145

@@ -451,12 +453,12 @@ def _pivots_to_permutation(piv, m):
451453
j = piv[..., i : i + 1]
452454

453455
# Gather the two values to be swapped.
454-
val_i = perm[..., i : i + 1].copy() # slice (free)
455-
val_j = dpnp.take_along_axis(perm, j, axis=-1) # gather
456+
val_i = perm[..., i : i + 1].copy() # slice (free)
457+
val_j = dpnp.take_along_axis(perm, j, axis=-1) # gather
456458

457459
# Perform the swap.
458-
perm[..., i : i + 1] = val_j # slice assign
459-
dpnp.put_along_axis(perm, j, val_i, axis=-1) # scatter
460+
perm[..., i : i + 1] = val_j # slice assign
461+
dpnp.put_along_axis(perm, j, val_i, axis=-1) # scatter
460462

461463
return perm
462464

dpnp/tests/test_linalg.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2743,7 +2743,9 @@ def test_p_is_permutation(self, dtype):
27432743

27442744
assert_allclose(P_np.sum(axis=0), numpy.ones(5, dtype=P_np.dtype))
27452745
assert_allclose(P_np.sum(axis=1), numpy.ones(5, dtype=P_np.dtype))
2746-
assert_allclose(P_np.T @ P_np, numpy.eye(5, dtype=P_np.dtype), atol=1e-15)
2746+
assert_allclose(
2747+
P_np.T @ P_np, numpy.eye(5, dtype=P_np.dtype), atol=1e-15
2748+
)
27472749

27482750
@pytest.mark.parametrize("dtype", get_float_complex_dtypes())
27492751
def test_modes_consistency(self, dtype):
@@ -2825,9 +2827,7 @@ def test_empty_p_indices(self, shape):
28252827
],
28262828
)
28272829
def test_strided(self, sl):
2828-
base = self._make_nonsingular_np(
2829-
(7, 7), dpnp.default_float_type(), "F"
2830-
)
2830+
base = self._make_nonsingular_np((7, 7), dpnp.default_float_type(), "F")
28312831
a_np = base[sl]
28322832
a_dp = dpnp.array(a_np)
28332833

@@ -2859,9 +2859,7 @@ def test_1d_input_raises(self):
28592859
@pytest.mark.parametrize("bad", [numpy.inf, -numpy.inf, numpy.nan])
28602860
def test_check_finite_raises(self, bad):
28612861
a_dp = dpnp.array([[1.0, 2.0], [3.0, bad]], order="F")
2862-
assert_raises(
2863-
ValueError, dpnp.scipy.linalg.lu, a_dp, check_finite=True
2864-
)
2862+
assert_raises(ValueError, dpnp.scipy.linalg.lu, a_dp, check_finite=True)
28652863

28662864
def test_check_finite_disabled(self):
28672865
a_dp = dpnp.array([[1.0, numpy.nan], [3.0, 4.0]])
@@ -2976,8 +2974,12 @@ def test_modes_consistency_batched(self, dtype):
29762974
a_np = self._make_nonsingular_nd_np((3, 4, 4), dtype, "F")
29772975
a_dp = dpnp.array(a_np, order="F")
29782976
A_cast = a_dp.astype(
2979-
dpnp.complex128 if dpnp.issubdtype(dtype, dpnp.complexfloating)
2980-
else dpnp.float64, copy=False
2977+
(
2978+
dpnp.complex128
2979+
if dpnp.issubdtype(dtype, dpnp.complexfloating)
2980+
else dpnp.float64
2981+
),
2982+
copy=False,
29812983
)
29822984

29832985
P, L, U = dpnp.scipy.linalg.lu(a_dp)
@@ -3057,9 +3059,7 @@ def test_singular_matrix(self):
30573059
def test_check_finite_raises(self):
30583060
a = dpnp.ones((2, 3, 3), dtype=dpnp.default_float_type(), order="F")
30593061
a[1, 0, 0] = dpnp.nan
3060-
assert_raises(
3061-
ValueError, dpnp.scipy.linalg.lu, a, check_finite=True
3062-
)
3062+
assert_raises(ValueError, dpnp.scipy.linalg.lu, a, check_finite=True)
30633063

30643064

30653065
class TestMatrixPower:

dpnp/tests/test_sycl_queue.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1665,6 +1665,7 @@ def test_lu_factor(self, data, device):
16651665
for param in result:
16661666
param_queue = param.sycl_queue
16671667
assert_sycl_queue_equal(param_queue, a.sycl_queue)
1668+
16681669
def test_lu(self, data, device):
16691670
a = dpnp.array(data, device=device)
16701671
result = dpnp.scipy.linalg.lu(a)

dpnp/tests/test_usm_type.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1538,6 +1538,7 @@ def test_lu(self, data, usm_type):
15381538
assert a.usm_type == usm_type
15391539
for param in result:
15401540
assert param.usm_type == a.usm_type
1541+
15411542
def test_lu_factor(self, data, usm_type):
15421543
a = dpnp.array(data, usm_type=usm_type)
15431544
result = dpnp.scipy.linalg.lu_factor(a)

0 commit comments

Comments
 (0)