Skip to content

Commit 8b22efb

Browse files
authored
Merge pull request data-apis#379 from ev-br/add_eig
ENH: add compatibility shims for {eig,eigvals}
2 parents 43de4af + d4c6f5d commit 8b22efb

File tree

4 files changed

+91
-0
lines changed

4 files changed

+91
-0
lines changed

array_api_compat/common/_linalg.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,10 @@ class EighResult(NamedTuple):
3434
eigenvalues: Array
3535
eigenvectors: Array
3636

37+
class EigResult(NamedTuple):
38+
eigenvalues: Array
39+
eigenvectors: Array
40+
3741
class QRResult(NamedTuple):
3842
Q: Array
3943
R: Array

array_api_compat/numpy/linalg.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
cross = get_xp(np)(_linalg.cross)
2020
outer = get_xp(np)(_linalg.outer)
2121
EighResult = _linalg.EighResult
22+
EigResult = _linalg.EigResult
2223
QRResult = _linalg.QRResult
2324
SlogdetResult = _linalg.SlogdetResult
2425
SVDResult = _linalg.SVDResult
@@ -97,6 +98,85 @@ def solve(x1: Array, x2: Array, /) -> Array:
9798
return wrap(r.astype(result_t, copy=False))
9899

99100

101+
# Unlike numpy.linalg.eig, Array API version always returns complex results
102+
103+
def eig(x: Array, /) -> tuple[Array, Array]:
104+
try:
105+
from numpy.linalg._linalg import ( # type: ignore[attr-defined]
106+
_assert_stacked_square,
107+
_assert_finite,
108+
_commonType,
109+
_makearray,
110+
_raise_linalgerror_eigenvalues_nonconvergence,
111+
isComplexType,
112+
_complexType,
113+
)
114+
except ImportError:
115+
from numpy.linalg.linalg import ( # type: ignore[attr-defined]
116+
_assert_stacked_square,
117+
_assert_finite,
118+
_commonType,
119+
_makearray,
120+
_raise_linalgerror_eigenvalues_nonconvergence,
121+
isComplexType,
122+
_complexType,
123+
)
124+
from numpy.linalg import _umath_linalg
125+
126+
x, wrap = _makearray(x)
127+
_assert_stacked_square(x)
128+
_assert_finite(x)
129+
t, result_t = _commonType(x)
130+
131+
signature = 'D->DD' if isComplexType(t) else 'd->DD'
132+
with np.errstate(call=_raise_linalgerror_eigenvalues_nonconvergence,
133+
invalid='call', over='ignore', divide='ignore',
134+
under='ignore'):
135+
w, vt = _umath_linalg.eig(x, signature=signature)
136+
137+
result_t = _complexType(result_t)
138+
vt = vt.astype(result_t, copy=False)
139+
return EigResult(w.astype(result_t, copy=False), wrap(vt))
140+
141+
142+
def eigvals(x: Array, /) -> Array:
143+
try:
144+
from numpy.linalg._linalg import ( # type: ignore[attr-defined]
145+
_assert_stacked_square,
146+
_assert_finite,
147+
_commonType,
148+
_makearray,
149+
_raise_linalgerror_eigenvalues_nonconvergence,
150+
isComplexType,
151+
_complexType,
152+
)
153+
except ImportError:
154+
from numpy.linalg.linalg import ( # type: ignore[attr-defined]
155+
_assert_stacked_square,
156+
_assert_finite,
157+
_commonType,
158+
_makearray,
159+
_raise_linalgerror_eigenvalues_nonconvergence,
160+
isComplexType,
161+
_complexType,
162+
)
163+
from numpy.linalg import _umath_linalg
164+
165+
x, wrap = _makearray(x)
166+
_assert_stacked_square(x)
167+
_assert_finite(x)
168+
t, result_t = _commonType(x)
169+
170+
signature = 'D->D' if isComplexType(t) else 'd->D'
171+
with np.errstate(call=_raise_linalgerror_eigenvalues_nonconvergence,
172+
invalid='call', over='ignore', divide='ignore',
173+
under='ignore'):
174+
w = _umath_linalg.eigvals(x, signature=signature)
175+
176+
result_t = _complexType(result_t)
177+
return w.astype(result_t, copy=False)
178+
179+
100180
# These functions are completely new here. If the library already has them
101181
# (i.e., numpy 2.0), use the library version instead of our wrapper.
102182
if hasattr(np.linalg, "vector_norm"):

cupy-xfails.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,10 @@ array_api_tests/test_has_names.py::test_has_names[array_attribute-mT]
2424

2525
array_api_tests/test_linalg.py::test_solve
2626

27+
# 2025.12 support; {eig,eigvals} are new in CuPy 14
28+
array_api_tests/test_linalg.py::test_eig
29+
array_api_tests/test_linalg.py::test_eigvals
30+
2731
# We cannot modify array methods
2832
array_api_tests/test_operators_and_elementwise_functions.py::test_divide[__truediv__(x, s)]
2933
array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__floordiv__(x, s)]

dask-xfails.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,9 @@ array_api_tests/test_data_type_functions.py::TestBroadcastShapes::test_broadcast
136136
array_api_tests/test_data_type_functions.py::TestBroadcastShapes::test_empty
137137
array_api_tests/test_data_type_functions.py::TestBroadcastShapes::test_error
138138

139+
array_api_tests/test_linalg.py::test_eig
140+
array_api_tests/test_linalg.py::test_eigvals
141+
139142
# Stubs have a comment: (**note**: libraries may return ``NaN`` to match Python behavior.)
140143
array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity]
141144
array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity]

0 commit comments

Comments
 (0)