|
46 | 46 | ) |
47 | 47 |
|
48 | 48 | from ._utils import ( |
| 49 | + dpnp_lu, |
49 | 50 | dpnp_lu_factor, |
50 | 51 | dpnp_lu_solve, |
51 | 52 | ) |
52 | 53 |
|
53 | 54 |
|
| 55 | +def lu( |
| 56 | + a, permute_l=False, overwrite_a=False, check_finite=True, p_indices=False |
| 57 | +): |
| 58 | + """ |
| 59 | + Compute LU decomposition of a matrix with partial pivoting. |
| 60 | +
|
| 61 | + The decomposition satisfies:: |
| 62 | +
|
| 63 | + A = P @ L @ U |
| 64 | +
|
| 65 | + where `P` is a permutation matrix, `L` is lower triangular with unit |
| 66 | + diagonal elements, and `U` is upper triangular. If `permute_l` is set to |
| 67 | + ``True`` then `L` is returned already permuted and hence satisfying |
| 68 | + ``A = L @ U``. |
| 69 | +
|
| 70 | + For full documentation refer to :obj:`scipy.linalg.lu`. |
| 71 | +
|
| 72 | + Parameters |
| 73 | + ---------- |
| 74 | + a : (..., M, N) {dpnp.ndarray, usm_ndarray} |
| 75 | + Input array to decompose. |
| 76 | + permute_l : bool, optional |
| 77 | + Perform the multiplication ``P @ L`` (Default: do not permute). |
| 78 | +
|
| 79 | + Default: ``False``. |
| 80 | + overwrite_a : {None, bool}, optional |
| 81 | + Whether to overwrite data in `a` (may increase performance). |
| 82 | +
|
| 83 | + Default: ``False``. |
| 84 | + check_finite : {None, bool}, optional |
| 85 | + Whether to check that the input matrix contains only finite numbers. |
| 86 | + Disabling may give a performance gain, but may result in problems |
| 87 | + (crashes, non-termination) if the inputs do contain infinities or NaNs. |
| 88 | +
|
| 89 | + Default: ``True``. |
| 90 | + p_indices : bool, optional |
| 91 | + If ``True`` the permutation information is returned as row indices |
| 92 | + instead of a permutation matrix. |
| 93 | +
|
| 94 | + Default: ``False``. |
| 95 | +
|
| 96 | + Returns |
| 97 | + ------- |
| 98 | + **(If ``permute_l`` is ``False``)** |
| 99 | +
|
| 100 | + p : (..., M, M) dpnp.ndarray or (..., M) dpnp.ndarray |
| 101 | + If `p_indices` is ``False`` (default), the permutation matrix. |
| 102 | + The permutation matrix always has a real dtype (``float32`` or |
| 103 | + ``float64``) even when `a` is complex, since it only contains |
| 104 | + 0s and 1s. |
| 105 | + If `p_indices` is ``True``, a 1-D (or batched) array of row |
| 106 | + permutation indices such that ``A = L[p] @ U``. |
| 107 | + l : (..., M, K) dpnp.ndarray |
| 108 | + Lower triangular or trapezoidal matrix with unit diagonal. |
| 109 | + ``K = min(M, N)``. |
| 110 | + u : (..., K, N) dpnp.ndarray |
| 111 | + Upper triangular or trapezoidal matrix. |
| 112 | +
|
| 113 | + **(If ``permute_l`` is ``True``)** |
| 114 | +
|
| 115 | + pl : (..., M, K) dpnp.ndarray |
| 116 | + Permuted ``L`` matrix: ``pl = P @ L``. |
| 117 | + ``K = min(M, N)``. |
| 118 | + u : (..., K, N) dpnp.ndarray |
| 119 | + Upper triangular or trapezoidal matrix. |
| 120 | +
|
| 121 | + Notes |
| 122 | + ----- |
| 123 | + Permutation matrices are costly since they are nothing but row reorder of |
| 124 | + ``L`` and hence indices are strongly recommended to be used instead if the |
| 125 | + permutation is required. The relation in the 2D case then becomes simply |
| 126 | + ``A = L[P, :] @ U``. In higher dimensions, it is better to use `permute_l` |
| 127 | + to avoid complicated indexing tricks. |
| 128 | +
|
| 129 | + In the 2D case, if one has the indices however, for some reason, the |
| 130 | + permutation matrix is still needed then it can be constructed by |
| 131 | + ``dpnp.eye(M)[P, :]``. |
| 132 | +
|
| 133 | + Warning |
| 134 | + ------- |
| 135 | + This function synchronizes in order to validate array elements |
| 136 | + when ``check_finite=True``, and also synchronizes to compute the |
| 137 | + permutation from LAPACK pivot indices. |
| 138 | +
|
| 139 | + See Also |
| 140 | + -------- |
| 141 | + :obj:`dpnp.scipy.linalg.lu_factor` : LU factorize a matrix |
| 142 | + (compact representation). |
| 143 | + :obj:`dpnp.scipy.linalg.lu_solve` : Solve an equation system using |
| 144 | + the LU factorization of a matrix. |
| 145 | +
|
| 146 | + Examples |
| 147 | + -------- |
| 148 | + >>> import dpnp as np |
| 149 | + >>> A = np.array([[2, 5, 8, 7], [5, 2, 2, 8], |
| 150 | + ... [7, 5, 6, 6], [5, 4, 4, 8]]) |
| 151 | + >>> p, l, u = np.scipy.linalg.lu(A) |
| 152 | + >>> np.allclose(A, p @ l @ u) |
| 153 | + array(True) |
| 154 | +
|
| 155 | + Retrieve the permutation as row indices with ``p_indices=True``: |
| 156 | +
|
| 157 | + >>> p, l, u = np.scipy.linalg.lu(A, p_indices=True) |
| 158 | + >>> p |
| 159 | + array([1, 3, 0, 2]) |
| 160 | + >>> np.allclose(A, l[p] @ u) |
| 161 | + array(True) |
| 162 | +
|
| 163 | + Return the permuted ``L`` directly with ``permute_l=True``: |
| 164 | +
|
| 165 | + >>> pl, u = np.scipy.linalg.lu(A, permute_l=True) |
| 166 | + >>> np.allclose(A, pl @ u) |
| 167 | + array(True) |
| 168 | +
|
| 169 | + Non-square matrices are supported: |
| 170 | +
|
| 171 | + >>> B = np.array([[1, 2, 3], [4, 5, 6]]) |
| 172 | + >>> p, l, u = np.scipy.linalg.lu(B) |
| 173 | + >>> np.allclose(B, p @ l @ u) |
| 174 | + array(True) |
| 175 | +
|
| 176 | + Batched input: |
| 177 | +
|
| 178 | + >>> C = np.random.randn(3, 2, 4, 4) |
| 179 | + >>> p, l, u = np.scipy.linalg.lu(C) |
| 180 | + >>> np.allclose(C, p @ l @ u) |
| 181 | + array(True) |
| 182 | +
|
| 183 | + """ |
| 184 | + |
| 185 | + dpnp.check_supported_arrays_type(a) |
| 186 | + assert_stacked_2d(a) |
| 187 | + |
| 188 | + return dpnp_lu( |
| 189 | + a, |
| 190 | + overwrite_a=overwrite_a, |
| 191 | + check_finite=check_finite, |
| 192 | + p_indices=p_indices, |
| 193 | + permute_l=permute_l, |
| 194 | + ) |
| 195 | + |
| 196 | + |
54 | 197 | def lu_factor(a, overwrite_a=False, check_finite=True): |
55 | 198 | """ |
56 | 199 | Compute the pivoted LU decomposition of `a` matrix. |
@@ -180,13 +323,13 @@ def lu_solve(lu_and_piv, b, trans=0, overwrite_b=False, check_finite=True): |
180 | 323 |
|
181 | 324 | """ |
182 | 325 |
|
183 | | - lu, piv = lu_and_piv |
184 | | - dpnp.check_supported_arrays_type(lu, piv, b) |
185 | | - assert_stacked_2d(lu) |
186 | | - assert_stacked_square(lu) |
| 326 | + lu_matrix, piv = lu_and_piv |
| 327 | + dpnp.check_supported_arrays_type(lu_matrix, piv, b) |
| 328 | + assert_stacked_2d(lu_matrix) |
| 329 | + assert_stacked_square(lu_matrix) |
187 | 330 |
|
188 | 331 | return dpnp_lu_solve( |
189 | | - lu, |
| 332 | + lu_matrix, |
190 | 333 | piv, |
191 | 334 | b, |
192 | 335 | trans=trans, |
|
0 commit comments