-
Notifications
You must be signed in to change notification settings - Fork 42
Expand file tree
/
Copy pathlinalg.py
More file actions
208 lines (178 loc) · 6.36 KB
/
linalg.py
File metadata and controls
208 lines (178 loc) · 6.36 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
# pyright: reportAttributeAccessIssue=false
# pyright: reportUnknownArgumentType=false
# pyright: reportUnknownMemberType=false
# pyright: reportUnknownVariableType=false
from __future__ import annotations
import numpy as np
from .._internal import clone_module, get_xp
from ..common import _linalg
# Start from a clone of numpy.linalg's public namespace; the explicit imports
# and assignments below then override or extend those cloned names.
__all__ = clone_module("numpy.linalg", globals())
# These functions are in both the main and linalg namespaces
from ._aliases import matmul, matrix_transpose, tensordot, vecdot  # noqa: F401
from ._typing import Array

# Result containers come straight from the shared, namespace-agnostic code.
EigResult = _linalg.EigResult
EighResult = _linalg.EighResult
QRResult = _linalg.QRResult
SVDResult = _linalg.SVDResult
SlogdetResult = _linalg.SlogdetResult

# Bind each shared implementation to NumPy as its array namespace. The
# temporary binder is deleted afterwards so it does not leak into the module.
_bind_numpy = get_xp(np)
cholesky = _bind_numpy(_linalg.cholesky)
cross = _bind_numpy(_linalg.cross)
diagonal = _bind_numpy(_linalg.diagonal)
eigh = _bind_numpy(_linalg.eigh)
matrix_norm = _bind_numpy(_linalg.matrix_norm)
matrix_rank = _bind_numpy(_linalg.matrix_rank)
outer = _bind_numpy(_linalg.outer)
pinv = _bind_numpy(_linalg.pinv)
qr = _bind_numpy(_linalg.qr)
slogdet = _bind_numpy(_linalg.slogdet)
svd = _bind_numpy(_linalg.svd)
svdvals = _bind_numpy(_linalg.svdvals)
trace = _bind_numpy(_linalg.trace)
del _bind_numpy
# Note: the array API solve() treats x2 as a vector only when x2 is exactly
# 1-D; any other x2 is interpreted as a stack of matrices. np.linalg.solve
# historically accepted stacks of both matrices and vectors, which is
# ambiguous -- see https://github.com/numpy/numpy/issues/15349 and
# https://github.com/data-apis/array-api/issues/285.
# To work around this, the body below follows np.linalg.solve but dispatches
# to the vector kernel (solve1) only in the exactly-1-D case.
# It lives here rather than in common because it relies on NumPy internals.
# Also note that CuPy's solve() does not currently support broadcasting (see
# https://github.com/cupy/cupy/blob/main/cupy/cublas.py#L43).
def solve(x1: Array, x2: Array, /) -> Array:
    """Solve ``x1 @ result == x2`` for ``result`` with array API semantics.

    Parameters
    ----------
    x1 : Array
        Coefficient matrix, or a stack of them; must be at least 2-D with
        square trailing dimensions.
    x2 : Array
        Right-hand side. An exactly 1-D ``x2`` is treated as a single vector
        (broadcast against every matrix in ``x1``); anything else is treated
        as a (stack of) matrices.

    Returns
    -------
    Array
        The solution, cast to the common result dtype of ``x1`` and ``x2``
        and wrapped the same way ``x2`` was.

    Raises
    ------
    LinAlgError
        Raised by NumPy's validation helpers when ``x1`` is not a stack of
        square matrices. Also raised via ``_raise_linalgerror_singular`` when
        the kernel flags an "invalid" floating-point condition.
    """
    # NumPy 2 moved these private helpers to numpy.linalg._linalg; older
    # releases keep them in numpy.linalg.linalg.
    try:
        from numpy.linalg._linalg import (  # type: ignore[attr-defined]
            _assert_stacked_2d,
            _assert_stacked_square,
            _commonType,
            _makearray,
            _raise_linalgerror_singular,
            isComplexType,
        )
    except ImportError:
        from numpy.linalg.linalg import (  # type: ignore[attr-defined]
            _assert_stacked_2d,
            _assert_stacked_square,
            _commonType,
            _makearray,
            _raise_linalgerror_singular,
            isComplexType,
        )
    from numpy.linalg import _umath_linalg

    coeffs, _ = _makearray(x1)
    _assert_stacked_2d(coeffs)
    _assert_stacked_square(coeffs)
    rhs, wrap = _makearray(x2)
    compute_t, result_t = _commonType(coeffs, rhs)

    # The one divergence from np.linalg.solve: choose the kernel purely from
    # the rank of x2 (vector kernel only for exactly 1-D input).
    kernel: np.ufunc = (
        _umath_linalg.solve1 if rhs.ndim == 1 else _umath_linalg.solve
    )
    # Complex inputs need the complex loop; all other inputs use the double loop.
    loop = "DD->D" if isComplexType(compute_t) else "dd->d"

    # "invalid" conditions raised by the kernel are turned into a
    # LinAlgError; other floating-point warnings are silenced.
    with np.errstate(
        call=_raise_linalgerror_singular,
        invalid="call",
        over="ignore",
        divide="ignore",
        under="ignore",
    ):
        solution: Array = kernel(coeffs, rhs, signature=loop)
    return wrap(solution.astype(result_t, copy=False))
# Unlike numpy.linalg.eig, Array API version always returns complex results
def eig(x: Array, /) -> tuple[Array, Array]:
    """Compute eigenvalues and right eigenvectors of a (stack of) square matrices.

    This follows ``numpy.linalg.eig``, except that both outputs always have a
    complex dtype. That holds even for real input whose eigenvalues are all
    real.

    Parameters
    ----------
    x : Array
        Input of shape ``(..., M, M)``.

    Returns
    -------
    EigResult
        ``(eigenvalues, eigenvectors)``, both cast to the complex counterpart
        of the input's common result type. Only the eigenvectors are wrapped
        the way ``x`` was.

    Raises
    ------
    LinAlgError
        If ``x`` has fewer than two dimensions or non-square trailing
        dimensions. Also raised if it contains infs/NaNs, or if the eigenvalue
        computation does not converge.
    """
    # NumPy 2 moved these private helpers to numpy.linalg._linalg; older
    # releases keep them in numpy.linalg.linalg.
    try:
        from numpy.linalg._linalg import (  # type: ignore[attr-defined]
            _assert_finite,
            _assert_stacked_2d,
            _assert_stacked_square,
            _commonType,
            _complexType,
            _makearray,
            _raise_linalgerror_eigenvalues_nonconvergence,
            isComplexType,
        )
    except ImportError:
        from numpy.linalg.linalg import (  # type: ignore[attr-defined]
            _assert_finite,
            _assert_stacked_2d,
            _assert_stacked_square,
            _commonType,
            _complexType,
            _makearray,
            _raise_linalgerror_eigenvalues_nonconvergence,
            isComplexType,
        )
    from numpy.linalg import _umath_linalg

    x, wrap = _makearray(x)
    # Check the rank before squareness, as solve() does. Older NumPy's
    # _assert_stacked_square unpacks shape[-2:] unguarded, so 0-D/1-D input
    # would otherwise surface as a bare ValueError instead of LinAlgError.
    _assert_stacked_2d(x)
    _assert_stacked_square(x)
    _assert_finite(x)
    t, result_t = _commonType(x)

    # Real input still uses a complex-output loop ('d->DD').
    signature = 'D->DD' if isComplexType(t) else 'd->DD'
    with np.errstate(call=_raise_linalgerror_eigenvalues_nonconvergence,
                     invalid='call', over='ignore', divide='ignore',
                     under='ignore'):
        w, vt = _umath_linalg.eig(x, signature=signature)

    # Unlike numpy.linalg.eig, never downcast to real: always return complex.
    result_t = _complexType(result_t)
    vt = vt.astype(result_t, copy=False)
    return EigResult(w.astype(result_t, copy=False), wrap(vt))
def eigvals(x: Array, /) -> Array:
    """Compute the eigenvalues of a (stack of) square matrices.

    Unlike ``numpy.linalg.eigvals``, the result always has a complex dtype.
    That holds even for real input whose eigenvalues are all real.

    Parameters
    ----------
    x : Array
        Input of shape ``(..., M, M)``.

    Returns
    -------
    Array
        The eigenvalues, cast to the complex counterpart of the input's common
        result type. The result is not re-wrapped.

    Raises
    ------
    LinAlgError
        If ``x`` has fewer than two dimensions or non-square trailing
        dimensions. Also raised if it contains infs/NaNs, or if the eigenvalue
        computation does not converge.
    """
    # NumPy 2 moved these private helpers to numpy.linalg._linalg; older
    # releases keep them in numpy.linalg.linalg.
    try:
        from numpy.linalg._linalg import (  # type: ignore[attr-defined]
            _assert_finite,
            _assert_stacked_2d,
            _assert_stacked_square,
            _commonType,
            _complexType,
            _makearray,
            _raise_linalgerror_eigenvalues_nonconvergence,
            isComplexType,
        )
    except ImportError:
        from numpy.linalg.linalg import (  # type: ignore[attr-defined]
            _assert_finite,
            _assert_stacked_2d,
            _assert_stacked_square,
            _commonType,
            _complexType,
            _makearray,
            _raise_linalgerror_eigenvalues_nonconvergence,
            isComplexType,
        )
    from numpy.linalg import _umath_linalg

    x, wrap = _makearray(x)
    # Check the rank before squareness, as solve() does. Older NumPy's
    # _assert_stacked_square unpacks shape[-2:] unguarded, so 0-D/1-D input
    # would otherwise surface as a bare ValueError instead of LinAlgError.
    _assert_stacked_2d(x)
    _assert_stacked_square(x)
    _assert_finite(x)
    t, result_t = _commonType(x)

    # Real input still uses a complex-output loop ('d->D').
    signature = 'D->D' if isComplexType(t) else 'd->D'
    with np.errstate(call=_raise_linalgerror_eigenvalues_nonconvergence,
                     invalid='call', over='ignore', divide='ignore',
                     under='ignore'):
        w = _umath_linalg.eigvals(x, signature=signature)

    # Never downcast to real: the result is always complex.
    result_t = _complexType(result_t)
    return w.astype(result_t, copy=False)
# vector_norm is new in the array API. Prefer NumPy's own implementation when
# this NumPy provides one (NumPy 2.0+); otherwise bind the shared generic
# version to NumPy.
try:
    vector_norm = np.linalg.vector_norm
except AttributeError:
    vector_norm = get_xp(np)(_linalg.vector_norm)
# Names defined or overridden in this module. They are exported in addition
# to the cloned numpy.linalg names and the shared helpers in _linalg.__all__.
_all = [
    "LinAlgError",
    "cond",
    "det",
    "eig",
    "eigvals",
    "eigvalsh",
    "inv",
    "lstsq",
    "matrix_power",
    "multi_dot",
    "norm",
    "solve",
    "tensorinv",
    "tensorsolve",
    "vector_norm",
]
# Deduplicate across the three sources and publish in a stable sorted order.
__all__ = sorted({*__all__, *_linalg.__all__, *_all})


def __dir__() -> list[str]:
    """Limit ``dir()`` on this module to its public API (``__all__``)."""
    return __all__