Skip to content

Commit 77ef1e6

Browse files
committed
Add Diagonal svd implementation and tests
1 parent e35c55e commit 77ef1e6

3 files changed

Lines changed: 139 additions & 19 deletions

File tree

src/implementations/svd.jl

Lines changed: 95 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ copy_input(::typeof(svd_compact), A) = copy_input(svd_full, A)
77
# Computing only the singular values copies the input the same way `svd_full` does.
function copy_input(::typeof(svd_vals), A)
    return copy_input(svd_full, A)
end
88
# The truncated SVD copies its input the same way `svd_compact` does.
function copy_input(::typeof(svd_trunc), A)
    return copy_input(svd_compact, A)
end
99

10+
# A `Diagonal` input is duplicated with a plain `copy`.
function copy_input(::typeof(svd_full), A::Diagonal)
    return copy(A)
end
11+
1012
# TODO: many of these checks are happening again in the LAPACK routines
1113
function check_input(::typeof(svd_full!), A::AbstractMatrix, USVᴴ, ::AbstractAlgorithm)
1214
m, n = size(A)
@@ -42,6 +44,32 @@ function check_input(::typeof(svd_vals!), A::AbstractMatrix, S, ::AbstractAlgori
4244
return nothing
4345
end
4446

47+
# Validate the arguments of `svd_full!` for the diagonal algorithm.
#
# `A` must be square and diagonal; `USVᴴ` must unpack into `(U, S, Vᴴ)` where `U` and
# `Vᴴ` are matrices and `S` is a `Diagonal`, all of the same square size as `A`.
# Returns `nothing`; a violated requirement raises from the failing check.
function check_input(::typeof(svd_full!), A::AbstractMatrix, USVᴴ, ::DiagonalAlgorithm)
48+
m, n = size(A)
49+
# the method accepts any AbstractMatrix, so diagonality is checked at run time
@assert m == n && isdiag(A)
50+
U, S, Vᴴ = USVᴴ
51+
@assert U isa AbstractMatrix && S isa Diagonal && Vᴴ isa AbstractMatrix
52+
@check_size(U, (m, m))
53+
# presumably compares the scalar (element) type of U against that of A — confirm
# against the definition of `@check_scalar`
@check_scalar(U, A)
54+
@check_size(S, (m, n))
55+
# the extra `real` argument presumably asks for the real counterpart of A's
# scalar type — TODO confirm
@check_scalar(S, A, real)
56+
@check_size(Vᴴ, (n, n))
57+
@check_scalar(Vᴴ, A)
58+
return nothing
59+
end
60+
# With the diagonal algorithm the compact SVD is validated by the `svd_full!` checks.
function check_input(
    ::typeof(svd_compact!), A::AbstractMatrix, USVᴴ, alg::DiagonalAlgorithm
)
    return check_input(svd_full!, A, USVᴴ, alg)
end
64+
# Validate the arguments of `svd_vals!` for the diagonal algorithm.
#
# `A` must be square and diagonal, and `S` must be a vector with one entry per
# diagonal element of `A`. Returns `nothing`.
function check_input(::typeof(svd_vals!), A::AbstractMatrix, S, ::DiagonalAlgorithm)
65+
m, n = size(A)
66+
# the method accepts any AbstractMatrix, so diagonality is checked at run time
@assert m == n && isdiag(A)
67+
@assert S isa AbstractVector
68+
@check_size(S, (m,))
69+
# the extra `real` argument presumably asks for the real counterpart of A's
# scalar type — TODO confirm against `@check_scalar`
@check_scalar(S, A, real)
70+
return nothing
71+
end
72+
4573
# Outputs
4674
# -------
4775
function initialize_output(::typeof(svd_full!), A::AbstractMatrix, ::AbstractAlgorithm)
@@ -66,6 +94,18 @@ function initialize_output(::typeof(svd_trunc!), A::AbstractMatrix, alg::Truncat
6694
return initialize_output(svd_compact!, A, alg.alg)
6795
end
6896

97+
# Allocate `(U, S, Vᴴ)` for the full SVD of a `Diagonal` matrix.
#
# `U` and `Vᴴ` have the size of `A` with the element type `sign_safe` produces for
# entries of `A`; `S` is `similar(A)` with the real counterpart of `A`'s element type.
function initialize_output(::typeof(svd_full!), A::Diagonal, ::DiagonalAlgorithm)
    elT = eltype(A)
    factorT = Base.promote_op(sign_safe, elT)
    dims = size(A)
    U = similar(A, factorT, dims)
    S = similar(A, real(elT))
    Vᴴ = similar(A, factorT, dims)
    return U, S, Vᴴ
end
102+
# The compact SVD of a `Diagonal` matrix uses the same output buffers as the full SVD.
function initialize_output(
    ::typeof(svd_compact!), A::Diagonal, alg::DiagonalAlgorithm
)
    USVᴴ = initialize_output(svd_full!, A, alg)
    return USVᴴ
end
105+
# Pick the output vector for the singular values of a `Diagonal` matrix.
#
# For a real element type the result of `diagview(A)` is handed back directly
# (presumably aliasing `A`'s own storage — confirm against `diagview`); otherwise a
# fresh vector of the real counterpart type is allocated.
function initialize_output(::typeof(svd_vals!), A::Diagonal, ::DiagonalAlgorithm)
    T = eltype(A)
    if T <: Real
        return diagview(A)
    end
    return similar(A, real(T), size(A, 1))
end
108+
69109
function gaugefix!(::typeof(svd_full!), U, S, Vᴴ, m::Int, n::Int)
70110
for j in 1:max(m, n)
71111
if j <= min(m, n)
@@ -111,7 +151,6 @@ function gaugefix!(::typeof(svd_trunc!), U, S, Vᴴ, m::Int, n::Int)
111151
return (U, S, Vᴴ)
112152
end
113153

114-
115154
# Implementation
116155
# --------------
117156
function svd_full!(A::AbstractMatrix, USVᴴ, alg::LAPACK_SVDAlgorithm)
@@ -203,7 +242,39 @@ function svd_trunc!(A::AbstractMatrix, USVᴴ, alg::TruncatedAlgorithm)
203242
return truncate!(svd_trunc!, USVᴴ′, alg.trunc)
204243
end
205244

206-
### GPU logic
245+
# Diagonal logic
246+
# --------------
247+
# Full SVD of a diagonal matrix, written into the preallocated `USVᴴ = (U, S, Vᴴ)`.
#
# The singular values are the absolute values of the diagonal of `A`, sorted in
# decreasing order. `U` carries the `sign_safe` of each diagonal entry and `Vᴴ` a
# unit entry, both placed so that `U * S * Vᴴ` undoes the sorting permutation.
# Returns `(U, S, Vᴴ)`.
function svd_full!(A::AbstractMatrix, USVᴴ, alg::DiagonalAlgorithm)
    check_input(svd_full!, A, USVᴴ, alg)
    U, S, Vᴴ = USVᴴ
    diagA = diagview(A)
    diagS = diagview(S)

    # moduli of the diagonal entries, then reorder them largest-first
    diagS .= abs.(diagA)
    order = sortperm(diagS; rev=true)
    permute!(diagS, order)

    zero!(U)
    zero!(Vᴴ)
    unit = one(eltype(Vᴴ))
    # `order[col]` is the original position of the col-th largest singular value
    @inbounds for col in eachindex(order)
        row = order[col]
        U[row, col] = sign_safe(diagA[row])
        Vᴴ[col, row] = unit
    end
    return U, S, Vᴴ
end
265+
# The diagonal algorithm only accepts square input (see the `svd_full!` checks), so
# the compact SVD is computed by forwarding to `svd_full!`.
function svd_compact!(A::AbstractMatrix, USVᴴ, alg::DiagonalAlgorithm)
    result = svd_full!(A, USVᴴ, alg)
    return result
end
268+
# Singular values of a diagonal matrix: the absolute values of its diagonal entries,
# written into `S` and sorted in decreasing order. Returns `S`.
function svd_vals!(A::AbstractMatrix, S, alg::DiagonalAlgorithm)
    check_input(svd_vals!, A, S, alg)
    diagA = diagview(A)
    @. S = abs(diagA)
    return sort!(S; rev=true)
end
275+
276+
# GPU logic
277+
# ---------
207278
# placed here to avoid code duplication since much of the logic is replicable across
208279
# CUDA and AMDGPU
209280
###
@@ -213,12 +284,13 @@ const CUSOLVER_SVDAlgorithm = Union{CUSOLVER_QRIteration,
213284
CUSOLVER_Randomized}
214285
const ROCSOLVER_SVDAlgorithm = Union{ROCSOLVER_QRIteration,
215286
ROCSOLVER_Jacobi}
216-
const GPU_SVDAlgorithm = Union{CUSOLVER_SVDAlgorithm, ROCSOLVER_SVDAlgorithm}
287+
# Any GPU SVD algorithm: the union of the CUSOLVER and ROCSOLVER algorithm unions.
const GPU_SVDAlgorithm = Union{CUSOLVER_SVDAlgorithm,ROCSOLVER_SVDAlgorithm}
217288

218289
# Polar-decomposition based GPU SVD algorithms (currently only the CUSOLVER one).
const GPU_SVDPolar = Union{CUSOLVER_SVDPolar}
219290
# Randomized GPU SVD algorithms (currently only the CUSOLVER one).
const GPU_Randomized = Union{CUSOLVER_Randomized}
220291

221-
function check_input(::typeof(svd_trunc!), A::AbstractMatrix, USVᴴ, alg::CUSOLVER_Randomized)
292+
function check_input(::typeof(svd_trunc!), A::AbstractMatrix, USVᴴ,
293+
alg::CUSOLVER_Randomized)
222294
m, n = size(A)
223295
minmn = min(m, n)
224296
U, S, Vᴴ = USVᴴ
@@ -232,7 +304,8 @@ function check_input(::typeof(svd_trunc!), A::AbstractMatrix, USVᴴ, alg::CUSOL
232304
return nothing
233305
end
234306

235-
function initialize_output(::typeof(svd_trunc!), A::AbstractMatrix, alg::TruncatedAlgorithm{<:CUSOLVER_Randomized})
307+
function initialize_output(::typeof(svd_trunc!), A::AbstractMatrix,
308+
alg::TruncatedAlgorithm{<:CUSOLVER_Randomized})
236309
m, n = size(A)
237310
minmn = min(m, n)
238311
U = similar(A, (m, m))
@@ -241,10 +314,22 @@ function initialize_output(::typeof(svd_trunc!), A::AbstractMatrix, alg::Truncat
241314
return (U, S, Vᴴ)
242315
end
243316

244-
_gpu_gesvd!(A::AbstractMatrix, S::AbstractVector, U::AbstractMatrix, Vᴴ::AbstractMatrix) = throw(MethodError(_gpu_gesvd!, (A, S, U, Vᴴ)))
245-
_gpu_Xgesvdp!(A::AbstractMatrix, S::AbstractVector, U::AbstractMatrix, Vᴴ::AbstractMatrix; kwargs...) = throw(MethodError(_gpu_Xgesvdp!, (A, S, U, Vᴴ)))
246-
_gpu_Xgesvdr!(A::AbstractMatrix, S::AbstractVector, U::AbstractMatrix, Vᴴ::AbstractMatrix; kwargs...) = throw(MethodError(_gpu_Xgesvdr!, (A, S, U, Vᴴ)))
247-
_gpu_gesvdj!(A::AbstractMatrix, S::AbstractVector, U::AbstractMatrix, Vᴴ::AbstractMatrix; kwargs...) = throw(MethodError(_gpu_gesvdj!, (A, S, U, Vᴴ)))
317+
# Generic fallback: always throws a `MethodError` carrying the positional arguments.
function _gpu_gesvd!(
    A::AbstractMatrix, S::AbstractVector, U::AbstractMatrix, Vᴴ::AbstractMatrix
)
    args = (A, S, U, Vᴴ)
    throw(MethodError(_gpu_gesvd!, args))
end
321+
# Generic fallback: always throws a `MethodError` carrying the positional arguments.
# Keyword arguments are accepted but not included in the error.
function _gpu_Xgesvdp!(
    A::AbstractMatrix, S::AbstractVector, U::AbstractMatrix, Vᴴ::AbstractMatrix;
    kwargs...
)
    args = (A, S, U, Vᴴ)
    throw(MethodError(_gpu_Xgesvdp!, args))
end
325+
# Generic fallback: always throws a `MethodError` carrying the positional arguments.
# Keyword arguments are accepted but not included in the error.
function _gpu_Xgesvdr!(
    A::AbstractMatrix, S::AbstractVector, U::AbstractMatrix, Vᴴ::AbstractMatrix;
    kwargs...
)
    args = (A, S, U, Vᴴ)
    throw(MethodError(_gpu_Xgesvdr!, args))
end
329+
# Generic fallback: always throws a `MethodError` carrying the positional arguments.
# Keyword arguments are accepted but not included in the error.
function _gpu_gesvdj!(
    A::AbstractMatrix, S::AbstractVector, U::AbstractMatrix, Vᴴ::AbstractMatrix;
    kwargs...
)
    args = (A, S, U, Vᴴ)
    throw(MethodError(_gpu_gesvdj!, args))
end
248333
# GPU SVD implementation
249334
function MatrixAlgebraKit.svd_full!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgorithm)
250335
check_input(svd_full!, A, USVᴴ, alg)
@@ -298,7 +383,7 @@ function MatrixAlgebraKit.svd_compact!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAl
298383
throw(ArgumentError("Unsupported SVD algorithm"))
299384
end
300385
# TODO: make this controllable using a `gaugefix` keyword argument
301-
gaugefix!(svd_compact!, U, S, Vᴴ, size(A)...)
386+
gaugefix!(svd_compact!, U, S, Vᴴ, size(A)...)
302387
return USVᴴ
303388
end
304389
# Fold `_largest` over `x`, starting from the zero of `x`'s element type.
function _argmaxabs(x)
    start = zero(eltype(x))
    return reduce(_largest, x; init=start)
end

src/interface/svd.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,9 @@ end
9797
# BLAS-compatible matrix types default to LAPACK divide-and-conquer; keyword
# arguments are forwarded to the algorithm constructor.
function default_svd_algorithm(::Type{<:YALAPACK.BlasMat}; kwargs...)
    return LAPACK_DivideAndConquer(; kwargs...)
end
100+
# `Diagonal` matrix types default to the dedicated diagonal algorithm; keyword
# arguments are forwarded to the algorithm constructor.
function default_svd_algorithm(::Type{<:Diagonal}; kwargs...)
    return DiagonalAlgorithm(; kwargs...)
end
100103

101104
for f in (:svd_full!, :svd_compact!, :svd_vals!)
102105
@eval function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A}

test/svd.jl

Lines changed: 41 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@ using StableRNGs
55
using LinearAlgebra: LinearAlgebra, Diagonal, I, isposdef
66
using MatrixAlgebraKit: TruncatedAlgorithm, TruncationKeepAbove, diagview, isisometry
77

8-
@testset "svd_compact! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64)
8+
# Element types that the testsets in this file iterate over.
const BLASFloats = (Float32, Float64, ComplexF32, ComplexF64)
9+
10+
@testset "svd_compact! for T = $T" for T in BLASFloats
911
rng = StableRNG(123)
1012
m = 54
1113
@testset "size ($m, $n)" for n in (37, m, 63, 0)
@@ -54,7 +56,7 @@ using MatrixAlgebraKit: TruncatedAlgorithm, TruncationKeepAbove, diagview, isiso
5456
end
5557
end
5658

57-
@testset "svd_full! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64)
59+
@testset "svd_full! for T = $T" for T in BLASFloats
5860
rng = StableRNG(123)
5961
m = 54
6062
@testset "size ($m, $n)" for n in (37, m, 63, 0)
@@ -88,7 +90,7 @@ end
8890
end
8991
end
9092

91-
@testset "svd_trunc! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64)
93+
@testset "svd_trunc! for T = $T" for T in BLASFloats
9294
rng = StableRNG(123)
9395
m = 54
9496
if LinearAlgebra.LAPACK.version() < v"3.12.0"
@@ -122,9 +124,7 @@ end
122124
end
123125
end
124126

125-
@testset "svd_trunc! mix maxrank and tol for T = $T" for T in
126-
(Float32, Float64, ComplexF32,
127-
ComplexF64)
127+
@testset "svd_trunc! mix maxrank and tol for T = $T" for T in BLASFloats
128128
rng = StableRNG(123)
129129
if LinearAlgebra.LAPACK.version() < v"3.12.0"
130130
algs = (LAPACK_DivideAndConquer(), LAPACK_QRIteration(), LAPACK_Bisection())
@@ -152,9 +152,7 @@ end
152152
end
153153
end
154154

155-
@testset "svd_trunc! specify truncation algorithm T = $T" for T in
156-
(Float32, Float64, ComplexF32,
157-
ComplexF64)
155+
@testset "svd_trunc! specify truncation algorithm T = $T" for T in BLASFloats
158156
rng = StableRNG(123)
159157
m = 4
160158
U = qr_compact(randn(rng, T, m, m))[1]
@@ -166,3 +164,37 @@ end
166164
@test diagview(S2) ≈ diagview(S)[1:2] rtol = sqrt(eps(real(T)))
167165
@test_throws ArgumentError svd_trunc(A; alg, trunc=(; maxrank=2))
168166
end
167+
168+
# Exercise the `Diagonal`-specialised SVD routines (compact, full, values-only and
# truncated) on a non-empty and an empty matrix for every tested element type.
@testset "svd for Diagonal{$T}" for T in BLASFloats
    rng = StableRNG(123)
    for m in (54, 0)
        # draw from the seeded generator so every run sees the same matrix
        Ad = randn(rng, T, m)
        A = Diagonal(Ad)

        U, S, Vᴴ = @constinferred svd_compact(A)
        @test U isa AbstractMatrix{T} && size(U) == size(A)
        @test Vᴴ isa AbstractMatrix{T} && size(Vᴴ) == size(A)
        @test S isa Diagonal{real(T)} && size(S) == size(A)
        @test isunitary(U)
        @test isunitary(Vᴴ)
        @test all(≥(0), diagview(S))
        @test A ≈ U * S * Vᴴ

        U, S, Vᴴ = @constinferred svd_full(A)
        @test U isa AbstractMatrix{T} && size(U) == size(A)
        @test Vᴴ isa AbstractMatrix{T} && size(Vᴴ) == size(A)
        @test S isa Diagonal{real(T)} && size(S) == size(A)
        @test isunitary(U)
        @test isunitary(Vᴴ)
        @test all(≥(0), diagview(S))
        @test A ≈ U * S * Vᴴ

        S2 = @constinferred svd_vals(A)
        @test S2 isa AbstractVector{real(T)} && length(S2) == m
        @test S2 ≈ diagview(S)

        # keep at most two singular values; the empty matrix keeps none
        alg = TruncatedAlgorithm(DiagonalAlgorithm(), truncrank(2))
        U3, S3, Vᴴ3 = @constinferred svd_trunc(A; alg)
        @test diagview(S3) ≈ S2[1:min(m, 2)]
    end
end

0 commit comments

Comments
 (0)