@@ -7,6 +7,8 @@ copy_input(::typeof(svd_compact), A) = copy_input(svd_full, A)
77copy_input (:: typeof (svd_vals), A) = copy_input (svd_full, A)
88copy_input (:: typeof (svd_trunc), A) = copy_input (svd_compact, A)
99
10+ copy_input (:: typeof (svd_full), A:: Diagonal ) = copy (A)
11+
1012# TODO : many of these checks are happening again in the LAPACK routines
1113function check_input (:: typeof (svd_full!), A:: AbstractMatrix , USVᴴ, :: AbstractAlgorithm )
1214 m, n = size (A)
@@ -42,6 +44,32 @@ function check_input(::typeof(svd_vals!), A::AbstractMatrix, S, ::AbstractAlgori
4244 return nothing
4345end
4446
47+ function check_input (:: typeof (svd_full!), A:: AbstractMatrix , USVᴴ, :: DiagonalAlgorithm )
48+ m, n = size (A)
49+ @assert m == n && isdiag (A)
50+ U, S, Vᴴ = USVᴴ
51+ @assert U isa AbstractMatrix && S isa Diagonal && Vᴴ isa AbstractMatrix
52+ @check_size (U, (m, m))
53+ @check_scalar (U, A)
54+ @check_size (S, (m, n))
55+ @check_scalar (S, A, real)
56+ @check_size (Vᴴ, (n, n))
57+ @check_scalar (Vᴴ, A)
58+ return nothing
59+ end
60+ function check_input (:: typeof (svd_compact!), A:: AbstractMatrix , USVᴴ,
61+ alg:: DiagonalAlgorithm )
62+ return check_input (svd_full!, A, USVᴴ, alg)
63+ end
64+ function check_input (:: typeof (svd_vals!), A:: AbstractMatrix , S, :: DiagonalAlgorithm )
65+ m, n = size (A)
66+ @assert m == n && isdiag (A)
67+ @assert S isa AbstractVector
68+ @check_size (S, (m,))
69+ @check_scalar (S, A, real)
70+ return nothing
71+ end
72+
4573# Outputs
4674# -------
4775function initialize_output (:: typeof (svd_full!), A:: AbstractMatrix , :: AbstractAlgorithm )
@@ -66,6 +94,18 @@ function initialize_output(::typeof(svd_trunc!), A::AbstractMatrix, alg::Truncat
6694 return initialize_output (svd_compact!, A, alg. alg)
6795end
6896
97+ function initialize_output (:: typeof (svd_full!), A:: Diagonal , :: DiagonalAlgorithm )
98+ TA = eltype (A)
99+ TUV = Base. promote_op (sign_safe, TA)
100+ return similar (A, TUV, size (A)), similar (A, real (TA)), similar (A, TUV, size (A))
101+ end
102+ function initialize_output (:: typeof (svd_compact!), A:: Diagonal , alg:: DiagonalAlgorithm )
103+ return initialize_output (svd_full!, A, alg)
104+ end
105+ function initialize_output (:: typeof (svd_vals!), A:: Diagonal , :: DiagonalAlgorithm )
106+ return eltype (A) <: Real ? diagview (A) : similar (A, real (eltype (A)), size (A, 1 ))
107+ end
108+
69109function gaugefix! (:: typeof (svd_full!), U, S, Vᴴ, m:: Int , n:: Int )
70110 for j in 1 : max (m, n)
71111 if j <= min (m, n)
@@ -111,7 +151,6 @@ function gaugefix!(::typeof(svd_trunc!), U, S, Vᴴ, m::Int, n::Int)
111151 return (U, S, Vᴴ)
112152end
113153
114-
115154# Implementation
116155# --------------
117156function svd_full! (A:: AbstractMatrix , USVᴴ, alg:: LAPACK_SVDAlgorithm )
@@ -203,7 +242,39 @@ function svd_trunc!(A::AbstractMatrix, USVᴴ, alg::TruncatedAlgorithm)
203242 return truncate! (svd_trunc!, USVᴴ′, alg. trunc)
204243end
205244
206- # ## GPU logic
245+ # Diagonal logic
246+ # --------------
247+ function svd_full! (A:: AbstractMatrix , USVᴴ, alg:: DiagonalAlgorithm )
248+ check_input (svd_full!, A, USVᴴ, alg)
249+ Ad = diagview (A)
250+ U, S, Vᴴ = USVᴴ
251+ Sd = diagview (S)
252+ Sd .= abs .(Ad)
253+ p = sortperm (Sd; rev= true )
254+ permute! (Sd, p)
255+ T = eltype (Vᴴ)
256+ zero! (U)
257+ zero! (Vᴴ)
258+ @inbounds for (i, pi ) in enumerate (p)
259+ s = Ad[pi ]
260+ U[pi , i] = sign_safe (s)
261+ Vᴴ[i, pi ] = one (T)
262+ end
263+ return U, S, Vᴴ
264+ end
265+ function svd_compact! (A:: AbstractMatrix , USVᴴ, alg:: DiagonalAlgorithm )
266+ return svd_full! (A, USVᴴ, alg)
267+ end
268+ function svd_vals! (A:: AbstractMatrix , S, alg:: DiagonalAlgorithm )
269+ check_input (svd_vals!, A, S, alg)
270+ Ad = diagview (A)
271+ S .= abs .(Ad)
272+ sort! (S; rev= true )
273+ return S
274+ end
275+
276+ # GPU logic
277+ # ---------
207278# placed here to avoid code duplication since much of the logic is replicable across
208279# CUDA and AMDGPU
209280# ##
0 commit comments