@@ -7,6 +7,8 @@ copy_input(::typeof(svd_compact), A) = copy_input(svd_full, A)
# `svd_vals` copies its input exactly like `svd_full`, and `svd_trunc` exactly like
# `svd_compact`; both methods simply forward to the corresponding `copy_input` method.
function copy_input(::typeof(svd_vals), A)
    return copy_input(svd_full, A)
end
function copy_input(::typeof(svd_trunc), A)
    return copy_input(svd_compact, A)
end
99
# For a `Diagonal` input, the working copy handed to `svd_full` is a plain `copy(A)`.
function copy_input(::typeof(svd_full), A::Diagonal)
    return copy(A)
end
11+
# TODO: many of these checks are repeated inside the LAPACK routines
1113function check_input (:: typeof (svd_full!), A:: AbstractMatrix , USVᴴ, :: AbstractAlgorithm )
1214 m, n = size (A)
@@ -42,6 +44,32 @@ function check_input(::typeof(svd_vals!), A::AbstractMatrix, S, ::AbstractAlgori
4244 return nothing
4345end
4446
# Validate the arguments of `svd_full!` for the diagonal algorithm.
#
# Checks, in order:
#   * `A` is square and `isdiag(A)` holds;
#   * `USVᴴ` unpacks into `U`, `S`, `Vᴴ`, with `U` and `Vᴴ` matrices and `S` a `Diagonal`;
#   * `U` is m×m, `S` is m×n and `Vᴴ` is n×n (via `@check_size`);
#   * `U` and `Vᴴ` pass `@check_scalar` against `A`, and `S` passes it with `real`.
#
# Returns `nothing`; a failed check raises from the corresponding `@assert` / macro.
function check_input(::typeof(svd_full!), A::AbstractMatrix, USVᴴ, ::DiagonalAlgorithm)
    m, n = size(A)
    @assert m == n && isdiag(A)

    U, S, Vᴴ = USVᴴ
    @assert U isa AbstractMatrix && S isa Diagonal && Vᴴ isa AbstractMatrix

    # Left factor
    @check_size(U, (m, m))
    @check_scalar(U, A)

    # Singular values
    @check_size(S, (m, n))
    @check_scalar(S, A, real)

    # Right factor
    @check_size(Vᴴ, (n, n))
    @check_scalar(Vᴴ, A)

    return nothing
end
# Validation for the diagonal `svd_compact!` is forwarded unchanged to the `svd_full!`
# checks for the same algorithm.
function check_input(
    ::typeof(svd_compact!), A::AbstractMatrix, USVᴴ, alg::DiagonalAlgorithm
)
    return check_input(svd_full!, A, USVᴴ, alg)
end
# Validate the arguments of `svd_vals!` for the diagonal algorithm.
#
# `A` must be square with `isdiag(A)`, and `S` must be a vector of length `m` that passes
# `@check_scalar` against `A` with `real`. Returns `nothing` when every check passes.
function check_input(::typeof(svd_vals!), A::AbstractMatrix, S, ::DiagonalAlgorithm)
    m, n = size(A)
    @assert m == n && isdiag(A)

    @assert S isa AbstractVector
    @check_size(S, (m,))
    @check_scalar(S, A, real)

    return nothing
end
72+
4573# Outputs
4674# -------
4775function initialize_output (:: typeof (svd_full!), A:: AbstractMatrix , :: AbstractAlgorithm )
@@ -66,6 +94,18 @@ function initialize_output(::typeof(svd_trunc!), A::AbstractMatrix, alg::Truncat
6694 return initialize_output (svd_compact!, A, alg. alg)
6795end
6896
# Allocate the outputs `(U, S, Vᴴ)` of `svd_full!` for a `Diagonal` input.
#
# `U` and `Vᴴ` come from `similar(A, TUV, size(A))`, where `TUV` is the type that
# `Base.promote_op` infers for `sign_safe` applied to `eltype(A)`. `S` comes from
# `similar(A, real(eltype(A)))`.
function initialize_output(::typeof(svd_full!), A::Diagonal, ::DiagonalAlgorithm)
    TA = eltype(A)
    TUV = Base.promote_op(sign_safe, TA)
    U = similar(A, TUV, size(A))
    S = similar(A, real(TA))
    Vᴴ = similar(A, TUV, size(A))
    return U, S, Vᴴ
end
# The diagonal `svd_compact!` uses exactly the outputs allocated for `svd_full!`.
function initialize_output(
    ::typeof(svd_compact!), A::Diagonal, alg::DiagonalAlgorithm
)
    return initialize_output(svd_full!, A, alg)
end
# Provide the output container for `svd_vals!` on a `Diagonal` input.
#
# For real element types this returns `diagview(A)` rather than new storage; otherwise it
# creates a length-`size(A, 1)` container with the real counterpart of `eltype(A)`.
# NOTE(review): `diagview(A)` presumably aliases the diagonal of `A`, so writing the
# singular values would overwrite `A` in the real case — confirm this is intended.
function initialize_output(::typeof(svd_vals!), A::Diagonal, ::DiagonalAlgorithm)
    T = eltype(A)
    if T <: Real
        return diagview(A)
    else
        return similar(A, real(T), size(A, 1))
    end
end
108+
69109function gaugefix! (:: typeof (svd_full!), U, S, Vᴴ, m:: Int , n:: Int )
70110 for j in 1 : max (m, n)
71111 if j <= min (m, n)
@@ -111,7 +151,6 @@ function gaugefix!(::typeof(svd_trunc!), U, S, Vᴴ, m::Int, n::Int)
111151 return (U, S, Vᴴ)
112152end
113153
114-
115154# Implementation
116155# --------------
117156function svd_full! (A:: AbstractMatrix , USVᴴ, alg:: LAPACK_SVDAlgorithm )
@@ -203,7 +242,39 @@ function svd_trunc!(A::AbstractMatrix, USVᴴ, alg::TruncatedAlgorithm)
203242 return truncate! (svd_trunc!, USVᴴ′, alg. trunc)
204243end
205244
206- # ## GPU logic
245+ # Diagonal logic
246+ # --------------
# Full SVD for the diagonal algorithm.
#
# The diagonal of `S` is filled with the absolute values of the diagonal of `A` and then
# sorted in descending order. `U` and `Vᴴ` are zeroed and receive one entry per sorted
# position `k`: with `src` the original diagonal index that lands at position `k`,
# `U[src, k] = sign_safe(Ad[src])` and `Vᴴ[k, src] = one(eltype(Vᴴ))`.
#
# Returns the tuple `(U, S, Vᴴ)`.
function svd_full!(A::AbstractMatrix, USVᴴ, alg::DiagonalAlgorithm)
    check_input(svd_full!, A, USVᴴ, alg)
    U, S, Vᴴ = USVᴴ
    Ad = diagview(A)
    Sd = diagview(S)

    # Singular values: magnitudes of the diagonal, largest first.
    Sd .= abs.(Ad)
    order = sortperm(Sd; rev=true)
    permute!(Sd, order)

    zero!(U)
    zero!(Vᴴ)
    unit = one(eltype(Vᴴ))
    # `order` indexes the diagonal, whose length matches the sizes verified by `check_input`.
    @inbounds for (k, src) in enumerate(order)
        U[src, k] = sign_safe(Ad[src])
        Vᴴ[k, src] = unit
    end
    return U, S, Vᴴ
end
# For the diagonal algorithm the compact SVD is computed by the full SVD routine.
svd_compact!(A::AbstractMatrix, USVᴴ, alg::DiagonalAlgorithm) = svd_full!(A, USVᴴ, alg)
# Singular values for the diagonal algorithm: write the absolute values of the diagonal
# of `A` into `S`, sort `S` in descending order in place, and return it.
function svd_vals!(A::AbstractMatrix, S, alg::DiagonalAlgorithm)
    check_input(svd_vals!, A, S, alg)
    S .= abs.(diagview(A))
    return sort!(S; rev=true)
end
275+
276+ # GPU logic
277+ # ---------
# Placed here to avoid code duplication, since most of this logic is shared between
# CUDA and AMDGPU.
209280# ##
@@ -213,12 +284,13 @@ const CUSOLVER_SVDAlgorithm = Union{CUSOLVER_QRIteration,
213284 CUSOLVER_Randomized}
# Union of the ROCSOLVER SVD algorithm types (QR iteration and Jacobi).
const ROCSOLVER_SVDAlgorithm = Union{ROCSOLVER_QRIteration,ROCSOLVER_Jacobi}
216- const GPU_SVDAlgorithm = Union{CUSOLVER_SVDAlgorithm, ROCSOLVER_SVDAlgorithm}
# Union of the CUSOLVER and ROCSOLVER SVD algorithm unions.
const GPU_SVDAlgorithm = Union{CUSOLVER_SVDAlgorithm,ROCSOLVER_SVDAlgorithm}

# Single-member unions: each currently wraps only its CUSOLVER algorithm type.
const GPU_SVDPolar = Union{CUSOLVER_SVDPolar}
const GPU_Randomized = Union{CUSOLVER_Randomized}
220291
221- function check_input (:: typeof (svd_trunc!), A:: AbstractMatrix , USVᴴ, alg:: CUSOLVER_Randomized )
292+ function check_input (:: typeof (svd_trunc!), A:: AbstractMatrix , USVᴴ,
293+ alg:: CUSOLVER_Randomized )
222294 m, n = size (A)
223295 minmn = min (m, n)
224296 U, S, Vᴴ = USVᴴ
@@ -232,7 +304,8 @@ function check_input(::typeof(svd_trunc!), A::AbstractMatrix, USVᴴ, alg::CUSOL
232304 return nothing
233305end
234306
235- function initialize_output (:: typeof (svd_trunc!), A:: AbstractMatrix , alg:: TruncatedAlgorithm{<:CUSOLVER_Randomized} )
307+ function initialize_output (:: typeof (svd_trunc!), A:: AbstractMatrix ,
308+ alg:: TruncatedAlgorithm{<:CUSOLVER_Randomized} )
236309 m, n = size (A)
237310 minmn = min (m, n)
238311 U = similar (A, (m, m))
@@ -241,10 +314,22 @@ function initialize_output(::typeof(svd_trunc!), A::AbstractMatrix, alg::Truncat
241314 return (U, S, Vᴴ)
242315end
243316
244- _gpu_gesvd! (A:: AbstractMatrix , S:: AbstractVector , U:: AbstractMatrix , Vᴴ:: AbstractMatrix ) = throw (MethodError (_gpu_gesvd!, (A, S, U, Vᴴ)))
245- _gpu_Xgesvdp! (A:: AbstractMatrix , S:: AbstractVector , U:: AbstractMatrix , Vᴴ:: AbstractMatrix ; kwargs... ) = throw (MethodError (_gpu_Xgesvdp!, (A, S, U, Vᴴ)))
246- _gpu_Xgesvdr! (A:: AbstractMatrix , S:: AbstractVector , U:: AbstractMatrix , Vᴴ:: AbstractMatrix ; kwargs... ) = throw (MethodError (_gpu_Xgesvdr!, (A, S, U, Vᴴ)))
247- _gpu_gesvdj! (A:: AbstractMatrix , S:: AbstractVector , U:: AbstractMatrix , Vᴴ:: AbstractMatrix ; kwargs... ) = throw (MethodError (_gpu_gesvdj!, (A, S, U, Vᴴ)))
# Fallback method of `_gpu_gesvd!`: always throws a `MethodError` carrying the function
# and the four positional arguments it was called with.
# NOTE(review): presumably backend-specific code adds the working methods — confirm.
function _gpu_gesvd!(
    A::AbstractMatrix, S::AbstractVector, U::AbstractMatrix, Vᴴ::AbstractMatrix
)
    args = (A, S, U, Vᴴ)
    throw(MethodError(_gpu_gesvd!, args))
end
# Fallback method of `_gpu_Xgesvdp!`: accepts (and ignores) any keyword arguments and
# always throws a `MethodError` carrying the function and its four positional arguments.
# NOTE(review): presumably backend-specific code adds the working methods — confirm.
function _gpu_Xgesvdp!(
    A::AbstractMatrix, S::AbstractVector, U::AbstractMatrix, Vᴴ::AbstractMatrix; kwargs...
)
    args = (A, S, U, Vᴴ)
    throw(MethodError(_gpu_Xgesvdp!, args))
end
# Fallback method of `_gpu_Xgesvdr!`: accepts (and ignores) any keyword arguments and
# always throws a `MethodError` carrying the function and its four positional arguments.
# NOTE(review): presumably backend-specific code adds the working methods — confirm.
function _gpu_Xgesvdr!(
    A::AbstractMatrix, S::AbstractVector, U::AbstractMatrix, Vᴴ::AbstractMatrix; kwargs...
)
    args = (A, S, U, Vᴴ)
    throw(MethodError(_gpu_Xgesvdr!, args))
end
# Fallback method of `_gpu_gesvdj!`: accepts (and ignores) any keyword arguments and
# always throws a `MethodError` carrying the function and its four positional arguments.
# NOTE(review): presumably backend-specific code adds the working methods — confirm.
function _gpu_gesvdj!(
    A::AbstractMatrix, S::AbstractVector, U::AbstractMatrix, Vᴴ::AbstractMatrix; kwargs...
)
    args = (A, S, U, Vᴴ)
    throw(MethodError(_gpu_gesvdj!, args))
end
248333# GPU SVD implementation
249334function MatrixAlgebraKit. svd_full! (A:: AbstractMatrix , USVᴴ, alg:: GPU_SVDAlgorithm )
250335 check_input (svd_full!, A, USVᴴ, alg)
@@ -298,7 +383,7 @@ function MatrixAlgebraKit.svd_compact!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAl
298383 throw (ArgumentError (" Unsupported SVD algorithm" ))
299384 end
300385 # TODO : make this controllable using a `gaugefix` keyword argument
301- gaugefix! (svd_compact!, U, S, Vᴴ, size (A)... )
386+ gaugefix! (svd_compact!, U, S, Vᴴ, size (A)... )
302387 return USVᴴ
303388end
# Fold `x` with `_largest`, seeding the reduction with the zero of `eltype(x)`.
function _argmaxabs(x)
    seed = zero(eltype(x))
    return reduce(_largest, x; init=seed)
end
0 commit comments