-
Notifications
You must be signed in to change notification settings - Fork 6
Expand file tree
/
Copy pathsvd.jl
More file actions
108 lines (105 loc) · 4.24 KB
/
svd.jl
File metadata and controls
108 lines (105 loc) · 4.24 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
# All cuSOLVER-backed SVD algorithms accepted by the `svd_full!`, `svd_compact!` and
# `svd_vals!` methods in this file (member order is irrelevant for a `Union`).
const CUSOLVER_SVDAlgorithm = Union{
    CUSOLVER_Jacobi,
    CUSOLVER_QRIteration,
    CUSOLVER_SVDPolar
}
# CUSOLVER SVD implementation
#
# Full SVD of the GPU matrix `A` into the preallocated `USVᴴ = (U, S, Vᴴ)`, using the
# cuSOLVER routine selected by `alg`. `A` is handed to an in-place backend wrapper and
# may be overwritten. `S` is treated as a dense `m × n` matrix: on return it holds the
# singular values on its diagonal and zeros elsewhere.
#
# Throws `ArgumentError` if `alg` is a `CUSOLVER_QRIteration` carrying keyword
# arguments, or if `alg` is not one of the supported algorithms.
function MatrixAlgebraKit.svd_full!(A::CuMatrix, USVᴴ, alg::CUSOLVER_SVDAlgorithm)
    check_input(svd_full!, A, USVᴴ)
    U, S, Vᴴ = USVᴴ
    fill!(S, zero(eltype(S)))
    m, n = size(A)
    minmn = min(m, n)
    if minmn == 0
        # Empty input: there is nothing to factorize, and `view(S, 1:minmn, 1)` below
        # would throw a `BoundsError` when `n == 0`. Identity `U`/`Vᴴ` together with the
        # zero `S` set above is a valid full SVD, and it already satisfies the gauge
        # convention applied further down.
        fill!(U, zero(eltype(U)))
        diagview(U) .= one(eltype(U))
        fill!(Vᴴ, zero(eltype(Vᴴ)))
        diagview(Vᴴ) .= one(eltype(Vᴴ))
        return USVᴴ
    end
    # The backends write the singular values into a vector. Use the first column of `S`
    # as that buffer, then move the values onto the diagonal afterwards.
    svals = view(S, 1:minmn, 1)
    if alg isa CUSOLVER_QRIteration
        isempty(alg.kwargs) ||
            throw(ArgumentError("CUSOLVER_QRIteration does not accept any keyword arguments"))
        YACUSOLVER.gesvd!(A, svals, U, Vᴴ)
    elseif alg isa CUSOLVER_SVDPolar
        YACUSOLVER.Xgesvdp!(A, svals, U, Vᴴ; alg.kwargs...)
    elseif alg isa CUSOLVER_Jacobi
        YACUSOLVER.gesvdj!(A, svals, U, Vᴴ; alg.kwargs...)
    else
        throw(ArgumentError("Unsupported SVD algorithm"))
    end
    # Copy the values onto the diagonal, then clear the rest of column 1. `S[1, 1]` is
    # both a source and a destination, so it is deliberately excluded from the clearing.
    diagview(S) .= svals
    view(S, 2:minmn, 1) .= zero(eltype(S))
    # Gauge fixing: scale each singular vector by a unit phase so that its entry of
    # largest magnitude becomes real and non-negative. Paired vectors (`j <= minmn`) get
    # conjugate phases, so `U * S * Vᴴ` is unchanged. The remaining columns of `U`
    # (when `m > n`) or rows of `Vᴴ` (when `n > m`) only meet zero rows/columns of `S`,
    # so they can be rescaled independently.
    # TODO: make this controllable using a `gaugefix` keyword argument
    for j in 1:max(m, n)
        if j <= minmn
            u = view(U, :, j)
            v = view(Vᴴ, j, :)
            s = conj(sign(_argmaxabs(u)))
            u .*= s
            v .*= conj(s)
        elseif j <= m
            u = view(U, :, j)
            s = conj(sign(_argmaxabs(u)))
            u .*= s
        else
            v = view(Vᴴ, j, :)
            s = conj(sign(_argmaxabs(v)))
            v .*= s
        end
    end
    return USVᴴ
end
# Compact (thin) SVD of the GPU matrix `A` into the preallocated `USVᴴ = (U, S, Vᴴ)`,
# using the cuSOLVER routine selected by `alg`.
#
# The singular values are written straight into `S.diag`. `S` is presumably a
# `Diagonal` (TODO: confirm against `check_input`). `A` is handed to an in-place
# backend wrapper and may be overwritten.
#
# Afterwards, each column of `U` is rescaled by a unit phase so that its entry of
# largest magnitude is real and non-negative. The matching row of `Vᴴ` receives the
# conjugate phase, which leaves `U * S * Vᴴ` unchanged.
#
# Throws `ArgumentError` if `alg` is a `CUSOLVER_QRIteration` carrying keyword
# arguments, or if `alg` is not one of the supported algorithms.
function MatrixAlgebraKit.svd_compact!(A::CuMatrix, USVᴴ, alg::CUSOLVER_SVDAlgorithm)
    check_input(svd_compact!, A, USVᴴ)
    Uc, Sc, Vc = USVᴴ
    if alg isa CUSOLVER_QRIteration && !isempty(alg.kwargs)
        throw(ArgumentError("CUSOLVER_QRIteration does not accept any keyword arguments"))
    end
    # Every reachable branch below consumes the diagonal storage, so fetch it once.
    svals = Sc.diag
    if alg isa CUSOLVER_QRIteration
        YACUSOLVER.gesvd!(A, svals, Uc, Vc)
    elseif alg isa CUSOLVER_SVDPolar
        YACUSOLVER.Xgesvdp!(A, svals, Uc, Vc; alg.kwargs...)
    elseif alg isa CUSOLVER_Jacobi
        YACUSOLVER.gesvdj!(A, svals, Uc, Vc; alg.kwargs...)
    else
        throw(ArgumentError("Unsupported SVD algorithm"))
    end
    # TODO: make this controllable using a `gaugefix` keyword argument
    for k in axes(Uc, 2)
        ucol = view(Uc, :, k)
        vrow = view(Vc, k, :)
        phase = conj(sign(_argmaxabs(ucol)))
        ucol .*= phase
        vrow .*= conj(phase)
    end
    return USVᴴ
end
# Return the entry of `x` with the largest absolute value, or `zero(eltype(x))` when `x`
# is empty. On ties `_largest` keeps its first operand. With a non-sequential
# reduction order, which tied entry is returned is therefore not pinned down here.
# `reduce` is used rather than a manual loop, presumably so this also works on GPU
# views (it is applied to columns/rows of `CuMatrix` factors above).
_argmaxabs(x) = reduce(_largest, x; init=zero(eltype(x)))
# Binary reduction operator: pick whichever of `x`, `y` has the larger magnitude,
# keeping `x` unless `y` is strictly larger.
_largest(x, y) = abs(y) > abs(x) ? y : x
# Singular values only: fill `S` from the GPU matrix `A` using the cuSOLVER routine
# selected by `alg`. `A` is handed to an in-place backend wrapper and may be
# overwritten.
#
# The backends take `U`/`Vᴴ` output arguments. Here they receive 0×0 placeholders
# allocated like `A`, presumably signalling that no singular vectors should be
# computed (TODO: confirm in YACUSOLVER).
#
# Throws `ArgumentError` if `alg` is a `CUSOLVER_QRIteration` carrying keyword
# arguments, or if `alg` is not one of the supported algorithms.
function MatrixAlgebraKit.svd_vals!(A::CuMatrix, S, alg::CUSOLVER_SVDAlgorithm)
    check_input(svd_vals!, A, S)
    noU = similar(A, (0, 0))
    noVᴴ = similar(A, (0, 0))
    if alg isa CUSOLVER_QRIteration && !isempty(alg.kwargs)
        throw(ArgumentError("CUSOLVER_QRIteration does not accept any keyword arguments"))
    end
    if alg isa CUSOLVER_QRIteration
        YACUSOLVER.gesvd!(A, S, noU, noVᴴ)
    elseif alg isa CUSOLVER_SVDPolar
        YACUSOLVER.Xgesvdp!(A, S, noU, noVᴴ; alg.kwargs...)
    elseif alg isa CUSOLVER_Jacobi
        YACUSOLVER.gesvdj!(A, S, noU, noVᴴ; alg.kwargs...)
    else
        throw(ArgumentError("Unsupported SVD algorithm"))
    end
    return S
end