Skip to content

Commit e9ececb

Browse files
committed
support for dpnp.scipy.linalg.lu()
1 parent d545555 commit e9ececb

File tree

6 files changed

+899
-1
lines changed

6 files changed

+899
-1
lines changed

dpnp/scipy/linalg/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,10 @@
3535
3636
"""
3737

38-
from ._decomp_lu import lu_factor, lu_solve
38+
from ._decomp_lu import lu, lu_factor, lu_solve
3939

4040
__all__ = [
41+
"lu",
4142
"lu_factor",
4243
"lu_solve",
4344
]

dpnp/scipy/linalg/_decomp_lu.py

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,11 +46,153 @@
4646
)
4747

4848
from ._utils import (
49+
dpnp_lu,
4950
dpnp_lu_factor,
5051
dpnp_lu_solve,
5152
)
5253

5354

55+
def lu(a, permute_l=False, overwrite_a=False, check_finite=True,
56+
p_indices=False):
57+
"""
58+
Compute LU decomposition of a matrix with partial pivoting.
59+
60+
The decomposition satisfies::
61+
62+
A = P @ L @ U
63+
64+
where `P` is a permutation matrix, `L` is lower triangular with unit
65+
diagonal elements, and `U` is upper triangular. If `permute_l` is set to
66+
``True`` then `L` is returned already permuted and hence satisfying
67+
``A = L @ U``.
68+
69+
For full documentation refer to :obj:`scipy.linalg.lu`.
70+
71+
Parameters
72+
----------
73+
a : (..., M, N) {dpnp.ndarray, usm_ndarray}
74+
Input array to decompose.
75+
permute_l : bool, optional
76+
Perform the multiplication ``P @ L`` (Default: do not permute).
77+
78+
Default: ``False``.
79+
overwrite_a : {None, bool}, optional
80+
Whether to overwrite data in `a` (may increase performance).
81+
82+
Default: ``False``.
83+
check_finite : {None, bool}, optional
84+
Whether to check that the input matrix contains only finite numbers.
85+
Disabling may give a performance gain, but may result in problems
86+
(crashes, non-termination) if the inputs do contain infinities or NaNs.
87+
88+
Default: ``True``.
89+
p_indices : bool, optional
90+
If ``True`` the permutation information is returned as row indices
91+
instead of a permutation matrix.
92+
93+
Default: ``False``.
94+
95+
Returns
96+
-------
97+
**(If ``permute_l`` is ``False``)**
98+
99+
p : (..., M, M) dpnp.ndarray or (..., M) dpnp.ndarray
100+
If `p_indices` is ``False`` (default), the permutation matrix.
101+
The permutation matrix always has a real dtype (``float32`` or
102+
``float64``) even when `a` is complex, since it only contains
103+
0s and 1s.
104+
If `p_indices` is ``True``, a 1-D (or batched) array of row
105+
permutation indices such that ``A = L[p] @ U``.
106+
l : (..., M, K) dpnp.ndarray
107+
Lower triangular or trapezoidal matrix with unit diagonal.
108+
``K = min(M, N)``.
109+
u : (..., K, N) dpnp.ndarray
110+
Upper triangular or trapezoidal matrix.
111+
112+
**(If ``permute_l`` is ``True``)**
113+
114+
pl : (..., M, K) dpnp.ndarray
115+
Permuted ``L`` matrix: ``pl = P @ L``.
116+
``K = min(M, N)``.
117+
u : (..., K, N) dpnp.ndarray
118+
Upper triangular or trapezoidal matrix.
119+
120+
Notes
121+
-----
122+
Permutation matrices are costly since they are nothing but row reorder of
123+
``L`` and hence indices are strongly recommended to be used instead if the
124+
permutation is required. The relation in the 2D case then becomes simply
125+
``A = L[P, :] @ U``. In higher dimensions, it is better to use `permute_l`
126+
to avoid complicated indexing tricks.
127+
128+
In the 2D case, if one has the indices however, for some reason, the
129+
permutation matrix is still needed then it can be constructed by
130+
``dpnp.eye(M)[P, :]``.
131+
132+
Warning
133+
-------
134+
This function synchronizes in order to validate array elements
135+
when ``check_finite=True``, and also synchronizes to compute the
136+
permutation from LAPACK pivot indices.
137+
138+
See Also
139+
--------
140+
:obj:`dpnp.scipy.linalg.lu_factor` : LU factorize a matrix
141+
(compact representation).
142+
:obj:`dpnp.scipy.linalg.lu_solve` : Solve an equation system using
143+
the LU factorization of a matrix.
144+
145+
Examples
146+
--------
147+
>>> import dpnp as np
148+
>>> A = np.array([[2, 5, 8, 7], [5, 2, 2, 8],
149+
... [7, 5, 6, 6], [5, 4, 4, 8]])
150+
>>> p, l, u = np.scipy.linalg.lu(A)
151+
>>> np.allclose(A, p @ l @ u)
152+
array(True)
153+
154+
Retrieve the permutation as row indices with ``p_indices=True``:
155+
156+
>>> p, l, u = np.scipy.linalg.lu(A, p_indices=True)
157+
>>> p
158+
array([1, 3, 0, 2])
159+
>>> np.allclose(A, l[p] @ u)
160+
array(True)
161+
162+
Return the permuted ``L`` directly with ``permute_l=True``:
163+
164+
>>> pl, u = np.scipy.linalg.lu(A, permute_l=True)
165+
>>> np.allclose(A, pl @ u)
166+
array(True)
167+
168+
Non-square matrices are supported:
169+
170+
>>> B = np.array([[1, 2, 3], [4, 5, 6]])
171+
>>> p, l, u = np.scipy.linalg.lu(B)
172+
>>> np.allclose(B, p @ l @ u)
173+
array(True)
174+
175+
Batched input:
176+
177+
>>> C = np.random.randn(3, 2, 4, 4)
178+
>>> p, l, u = np.scipy.linalg.lu(C)
179+
>>> np.allclose(C, p @ l @ u)
180+
array(True)
181+
182+
"""
183+
184+
dpnp.check_supported_arrays_type(a)
185+
assert_stacked_2d(a)
186+
187+
return dpnp_lu(
188+
a,
189+
overwrite_a=overwrite_a,
190+
check_finite=check_finite,
191+
p_indices=p_indices,
192+
permute_l=permute_l,
193+
)
194+
195+
54196
def lu_factor(a, overwrite_a=False, check_finite=True):
55197
"""
56198
Compute the pivoted LU decomposition of `a` matrix.

0 commit comments

Comments
 (0)