1+ const CUSOLVER_SVDAlgorithm = Union{CUSOLVER_QRIteration,
2+ CUSOLVER_SVDPolar,
3+ CUSOLVER_Jacobi}
4+
5+ # CUSOLVER SVD implementation
6+ function MatrixAlgebraKit. svd_full! (A:: CuMatrix , USVᴴ, alg:: CUSOLVER_SVDAlgorithm )
7+ check_input (svd_full!, A, USVᴴ)
8+ U, S, Vᴴ = USVᴴ
9+ fill! (S, zero (eltype (S)))
10+ m, n = size (A)
11+ minmn = min (m, n)
12+ if alg isa CUSOLVER_QRIteration
13+ isempty (alg. kwargs) ||
14+ throw (ArgumentError (" LAPACK_QRIteration does not accept any keyword arguments" ))
15+ YACUSOLVER. gesvd! (A, view (S, 1 : minmn, 1 ), U, Vᴴ)
16+ elseif alg isa CUSOLVER_SVDPolar
17+ YACUSOLVER. Xgesvdp! (A, view (S, 1 : minmn, 1 ), U, Vᴴ; alg. kwargs... )
18+ elseif alg isa CUSOLVER_Jacobi
19+ YACUSOLVER. gesvdj! (A, view (S, 1 : minmn, 1 ), U, Vᴴ; alg. kwargs... )
20+ # elseif alg isa LAPACK_Bisection
21+ # throw(ArgumentError("LAPACK_Bisection is not supported for full SVD"))
22+ # elseif alg isa LAPACK_Jacobi
23+ # throw(ArgumentError("LAPACK_Bisection is not supported for full SVD"))
24+ else
25+ throw (ArgumentError (" Unsupported SVD algorithm" ))
26+ end
27+ diagview (S) .= view (S, 1 : minmn, 1 )
28+ view (S, 2 : minmn, 1 ) .= zero (eltype (S))
29+ # TODO : make this controllable using a `gaugefix` keyword argument
30+ for j in 1 : max (m, n)
31+ if j <= minmn
32+ u = view (U, :, j)
33+ v = view (Vᴴ, j, :)
34+ s = conj (sign (_argmaxabs (u)))
35+ u .*= s
36+ v .*= conj (s)
37+ elseif j <= m
38+ u = view (U, :, j)
39+ s = conj (sign (_argmaxabs (u)))
40+ u .*= s
41+ else
42+ v = view (Vᴴ, j, :)
43+ s = conj (sign (_argmaxabs (v)))
44+ v .*= s
45+ end
46+ end
47+ return USVᴴ
48+ end
49+
50+ function MatrixAlgebraKit. svd_compact! (A:: CuMatrix , USVᴴ, alg:: CUSOLVER_SVDAlgorithm )
51+ check_input (svd_compact!, A, USVᴴ)
52+ U, S, Vᴴ = USVᴴ
53+ if alg isa CUSOLVER_QRIteration
54+ isempty (alg. kwargs) ||
55+ throw (ArgumentError (" CUSOLVER_QRIteration does not accept any keyword arguments" ))
56+ YACUSOLVER. gesvd! (A, S. diag, U, Vᴴ)
57+ elseif alg isa CUSOLVER_SVDPolar
58+ YACUSOLVER. Xgesvdp! (A, S. diag, U, Vᴴ; alg. kwargs... )
59+ elseif alg isa CUSOLVER_Jacobi
60+ YACUSOLVER. gesvdj! (A, S. diag, U, Vᴴ; alg. kwargs... )
61+ # elseif alg isa LAPACK_DivideAndConquer
62+ # isempty(alg.kwargs) ||
63+ # throw(ArgumentError("LAPACK_DivideAndConquer does not accept any keyword arguments"))
64+ # YALAPACK.gesdd!(A, S.diag, U, Vᴴ)
65+ # elseif alg isa LAPACK_Bisection
66+ # YALAPACK.gesvdx!(A, S.diag, U, Vᴴ; alg.kwargs...)
67+ else
68+ throw (ArgumentError (" Unsupported SVD algorithm" ))
69+ end
70+ # TODO : make this controllable using a `gaugefix` keyword argument
71+ for j in 1 : size (U, 2 )
72+ u = view (U, :, j)
73+ v = view (Vᴴ, j, :)
74+ s = conj (sign (_argmaxabs (u)))
75+ u .*= s
76+ v .*= conj (s)
77+ end
78+ return USVᴴ
79+ end
80+ _argmaxabs (x) = reduce (_largest, x; init= zero (eltype (x)))
81+ _largest (x, y) = abs (x) < abs (y) ? y : x
82+
83+ function MatrixAlgebraKit. svd_vals! (A:: CuMatrix , S, alg:: CUSOLVER_SVDAlgorithm )
84+ check_input (svd_vals!, A, S)
85+ U, Vᴴ = similar (A, (0 , 0 )), similar (A, (0 , 0 ))
86+ if alg isa CUSOLVER_QRIteration
87+ isempty (alg. kwargs) ||
88+ throw (ArgumentError (" CUSOLVER_QRIteration does not accept any keyword arguments" ))
89+ YACUSOLVER. gesvd! (A, S, U, Vᴴ)
90+ elseif alg isa CUSOLVER_SVDPolar
91+ YACUSOLVER. Xgesvdp! (A, S, U, Vᴴ; alg. kwargs... )
92+ elseif alg isa CUSOLVER_Jacobi
93+ YACUSOLVER. gesvdj! (A, S, U, Vᴴ; alg. kwargs... )
94+ # elseif alg isa LAPACK_DivideAndConquer
95+ # isempty(alg.kwargs) ||
96+ # throw(ArgumentError("LAPACK_DivideAndConquer does not accept any keyword arguments"))
97+ # YALAPACK.gesdd!(A, S, U, Vᴴ)
98+ # elseif alg isa LAPACK_Bisection
99+ # YALAPACK.gesvdx!(A, S, U, Vᴴ; alg.kwargs...)
100+ # elseif alg isa LAPACK_Jacobi
101+ # isempty(alg.kwargs) ||
102+ # throw(ArgumentError("LAPACK_Jacobi does not accept any keyword arguments"))
103+ # YALAPACK.gesvj!(A, S, U, Vᴴ)
104+ else
105+ throw (ArgumentError (" Unsupported SVD algorithm" ))
106+ end
107+ return S
108+ end
0 commit comments