1+ const CUSOLVER_SVDAlgorithm = Union{CUSOLVER_QRIteration,
2+ CUSOLVER_SVDPolar,
3+ CUSOLVER_Jacobi}
4+
5+ # CUSOLVER SVD implementation
6+ function MatrixAlgebraKit. svd_full! (A:: CuMatrix , USVᴴ, alg:: CUSOLVER_SVDAlgorithm )
7+ check_input (svd_full!, A, USVᴴ)
8+ U, S, Vᴴ = USVᴴ
9+ fill! (S, zero (eltype (S)))
10+ m, n = size (A)
11+ minmn = min (m, n)
12+ if alg isa CUSOLVER_QRIteration
13+ isempty (alg. kwargs) ||
14+ throw (ArgumentError (" LAPACK_QRIteration does not accept any keyword arguments" ))
15+ YACUSOLVER. gesvd! (A, view (S, 1 : minmn, 1 ), U, Vᴴ)
16+ elseif alg isa CUSOLVER_SVDPolar
17+ YACUSOLVER. Xgesvdp! (A, view (S, 1 : minmn, 1 ), U, Vᴴ; alg. kwargs... )
18+ # elseif alg isa LAPACK_Bisection
19+ # throw(ArgumentError("LAPACK_Bisection is not supported for full SVD"))
20+ # elseif alg isa LAPACK_Jacobi
21+ # throw(ArgumentError("LAPACK_Bisection is not supported for full SVD"))
22+ else
23+ throw (ArgumentError (" Unsupported SVD algorithm" ))
24+ end
25+ diagview (S) .= view (S, 1 : minmn, 1 )
26+ view (S, 2 : minmn, 1 ) .= zero (eltype (S))
27+ # TODO : make this controllable using a `gaugefix` keyword argument
28+ for j in 1 : max (m, n)
29+ if j <= minmn
30+ u = view (U, :, j)
31+ v = view (Vᴴ, j, :)
32+ s = conj (sign (_argmaxabs (u)))
33+ u .*= s
34+ v .*= conj (s)
35+ elseif j <= m
36+ u = view (U, :, j)
37+ s = conj (sign (_argmaxabs (u)))
38+ u .*= s
39+ else
40+ v = view (Vᴴ, j, :)
41+ s = conj (sign (_argmaxabs (v)))
42+ v .*= s
43+ end
44+ end
45+ return USVᴴ
46+ end
47+
48+ function MatrixAlgebraKit. svd_compact! (A:: CuMatrix , USVᴴ, alg:: CUSOLVER_SVDAlgorithm )
49+ check_input (svd_compact!, A, USVᴴ)
50+ U, S, Vᴴ = USVᴴ
51+ if alg isa CUSOLVER_QRIteration
52+ isempty (alg. kwargs) ||
53+ throw (ArgumentError (" CUSOLVER_QRIteration does not accept any keyword arguments" ))
54+ YACUSOLVER. gesvd! (A, S. diag, U, Vᴴ)
55+ elseif alg isa CUSOLVER_SVDPolar
56+ YACUSOLVER. Xgesvdp! (A, S. diag, U, Vᴴ; alg. kwargs... )
57+ # elseif alg isa LAPACK_DivideAndConquer
58+ # isempty(alg.kwargs) ||
59+ # throw(ArgumentError("LAPACK_DivideAndConquer does not accept any keyword arguments"))
60+ # YALAPACK.gesdd!(A, S.diag, U, Vᴴ)
61+ # elseif alg isa LAPACK_Bisection
62+ # YALAPACK.gesvdx!(A, S.diag, U, Vᴴ; alg.kwargs...)
63+ # elseif alg isa LAPACK_Jacobi
64+ # isempty(alg.kwargs) ||
65+ # throw(ArgumentError("LAPACK_Jacobi does not accept any keyword arguments"))
66+ # YALAPACK.gesvj!(A, S.diag, U, Vᴴ)
67+ else
68+ throw (ArgumentError (" Unsupported SVD algorithm" ))
69+ end
70+ # TODO : make this controllable using a `gaugefix` keyword argument
71+ for j in 1 : size (U, 2 )
72+ u = view (U, :, j)
73+ v = view (Vᴴ, j, :)
74+ s = conj (sign (_argmaxabs (u)))
75+ u .*= s
76+ v .*= conj (s)
77+ end
78+ return USVᴴ
79+ end
80+ _argmaxabs (x) = reduce (_largest, x; init= zero (eltype (x)))
81+ _largest (x, y) = abs (x) < abs (y) ? y : x
82+
83+ function MatrixAlgebraKit. svd_vals! (A:: CuMatrix , S, alg:: CUSOLVER_SVDAlgorithm )
84+ check_input (svd_vals!, A, S)
85+ U, Vᴴ = similar (A, (0 , 0 )), similar (A, (0 , 0 ))
86+ if alg isa CUSOLVER_QRIteration
87+ isempty (alg. kwargs) ||
88+ throw (ArgumentError (" CUSOLVER_QRIteration does not accept any keyword arguments" ))
89+ YACUSOLVER. gesvd! (A, S, U, Vᴴ)
90+ elseif alg isa CUSOLVER_SVDPolar
91+ YACUSOLVER. Xgesvdp! (A, S, U, Vᴴ; alg. kwargs... )
92+ # elseif alg isa LAPACK_DivideAndConquer
93+ # isempty(alg.kwargs) ||
94+ # throw(ArgumentError("LAPACK_DivideAndConquer does not accept any keyword arguments"))
95+ # YALAPACK.gesdd!(A, S, U, Vᴴ)
96+ # elseif alg isa LAPACK_Bisection
97+ # YALAPACK.gesvdx!(A, S, U, Vᴴ; alg.kwargs...)
98+ # elseif alg isa LAPACK_Jacobi
99+ # isempty(alg.kwargs) ||
100+ # throw(ArgumentError("LAPACK_Jacobi does not accept any keyword arguments"))
101+ # YALAPACK.gesvj!(A, S, U, Vᴴ)
102+ else
103+ throw (ArgumentError (" Unsupported SVD algorithm" ))
104+ end
105+ return S
106+ end
0 commit comments