-
Notifications
You must be signed in to change notification settings - Fork 42
Expand file tree
/
Copy pathlinalg.py
More file actions
62 lines (50 loc) · 1.8 KB
/
linalg.py
File metadata and controls
62 lines (50 loc) · 1.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
from cupy.linalg import * # noqa: F403
# https://github.com/cupy/cupy/issues/9749
from cupy.linalg import lstsq # noqa: F401
# Collect the names that ``from cupy.linalg import *`` binds, without exec.
#
# ``cupy.linalg`` defines ``__all__`` only in cupy>=14. The star-import rule is:
# use the module's ``__all__`` if it has one, otherwise take every name in the
# module namespace that does not start with an underscore. Applying that rule
# directly works for both cupy<14 and cupy>=14.
import cupy.linalg as _cupy_linalg

_cupy_linalg_all = getattr(_cupy_linalg, '__all__', None)
if _cupy_linalg_all is None:
    # cupy<14: no __all__, so mirror the star-import fallback rule.
    _cupy_linalg_all = [
        name for name in vars(_cupy_linalg) if not name.startswith('_')
    ]
# lstsq is imported explicitly (cupy/cupy#9749), so always export it.
# A possible duplicate entry is harmless: the final export list is
# deduplicated when ``__all__`` is built.
linalg_all = list(_cupy_linalg_all) + ['lstsq']
# Keep the module namespace free of these temporaries.
del _cupy_linalg, _cupy_linalg_all

try:
    # cupy 14 exports it, cupy 13 does not
    from cupy.linalg import annotations  # noqa: F401
except ImportError:
    pass
else:
    linalg_all += ['annotations']
from ..common import _linalg
from .._internal import get_xp
import cupy as cp
# These functions are in both the main and linalg namespaces
from ._aliases import matmul, matrix_transpose, tensordot, vecdot # noqa: F401
# ``get_xp(cp)`` yields a decorator that adapts the array-library-agnostic
# helpers in ``common._linalg`` to cupy (presumably by supplying ``cp`` as
# their ``xp`` argument; see ``_internal.get_xp``). It is built once here
# and applied to every helper below.
_wrap_cupy = get_xp(cp)

cross = _wrap_cupy(_linalg.cross)
outer = _wrap_cupy(_linalg.outer)

# Result containers are re-exported unchanged; they need no wrapping.
EighResult = _linalg.EighResult
QRResult = _linalg.QRResult
SlogdetResult = _linalg.SlogdetResult
SVDResult = _linalg.SVDResult

# Decompositions returning the result containers above.
eigh = _wrap_cupy(_linalg.eigh)
qr = _wrap_cupy(_linalg.qr)
slogdet = _wrap_cupy(_linalg.slogdet)
svd = _wrap_cupy(_linalg.svd)

# Remaining helpers wrapped for cupy.
cholesky = _wrap_cupy(_linalg.cholesky)
matrix_rank = _wrap_cupy(_linalg.matrix_rank)
pinv = _wrap_cupy(_linalg.pinv)
matrix_norm = _wrap_cupy(_linalg.matrix_norm)
svdvals = _wrap_cupy(_linalg.svdvals)
diagonal = _wrap_cupy(_linalg.diagonal)
trace = _wrap_cupy(_linalg.trace)

# Do not leak the helper decorator as a module attribute.
del _wrap_cupy
# ``vector_norm`` is new in the array API and is not part of older cupy.linalg
# namespaces. If the installed cupy already provides it, use cupy's own
# implementation; otherwise fall back to the generic helper from
# ``common._linalg``, wrapped for cupy.
if hasattr(cp.linalg, 'vector_norm'):
    vector_norm = cp.linalg.vector_norm
else:
    vector_norm = get_xp(cp)(_linalg.vector_norm)
# Public API: cupy.linalg's own exports plus the array-API helpers listed in
# ``common._linalg.__all__``. The two sources can name the same function
# (cupy 14 ships its own ``__all__``; cupy 13 does not), so merge them into a
# set to drop duplicates and expose the names in alphabetical order.
__all__ = sorted({*linalg_all, *_linalg.__all__})
def __dir__() -> list[str]:
    """Return the names that ``dir()`` should report for this module.

    A module-level ``__dir__`` (PEP 562) limits ``dir(module)`` to the
    public API listed in the module's ``__all__``.
    """
    public_api = __all__
    return public_api