Skip to content

Commit 1355804

Browse files
authored
GPU improvements for SVD (#80)
Pull in SVD-specific changes from ksh/tk
1 parent ba9867b commit 1355804

4 files changed

Lines changed: 91 additions & 17 deletions

File tree

src/implementations/svd.jl

Lines changed: 33 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -347,25 +347,46 @@ function _gpu_gesvdj!(
347347
)
348348
throw(MethodError(_gpu_gesvdj!, (A, S, U, Vᴴ)))
349349
end
350+
function _gpu_gesvd_maybe_transpose!(A::AbstractMatrix, S::AbstractVector, U::AbstractMatrix, Vᴴ::AbstractMatrix)
351+
m, n = size(A)
352+
m ≥ n && return _gpu_gesvd!(A, S, U, Vᴴ)
353+
# both CUSOLVER and ROCSOLVER require m ≥ n for gesvd (QR_Iteration)
354+
# if this condition is not met, do the SVD via adjoint
355+
minmn = min(m, n)
356+
Aᴴ = min(m, n) > 0 ? adjoint!(similar(A'), A)::AbstractMatrix : similar(A')
357+
Uᴴ = similar(U')
358+
V = similar(Vᴴ')
359+
if size(U) == (m, m)
360+
_gpu_gesvd!(Aᴴ, view(S, 1:minmn, 1), V, Uᴴ)
361+
else
362+
_gpu_gesvd!(Aᴴ, S, V, Uᴴ)
363+
end
364+
length(U) > 0 && adjoint!(U, Uᴴ)
365+
length(Vᴴ) > 0 && adjoint!(Vᴴ, V)
366+
return U, S, Vᴴ
367+
end
368+
350369
# GPU SVD implementation
351-
function MatrixAlgebraKit.svd_full!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgorithm)
370+
function svd_full!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgorithm)
352371
check_input(svd_full!, A, USVᴴ, alg)
353372
U, S, Vᴴ = USVᴴ
354373
fill!(S, zero(eltype(S)))
355374
m, n = size(A)
356375
minmn = min(m, n)
376+
if minmn == 0
377+
one!(U)
378+
zero!(S)
379+
one!(Vᴴ)
380+
return USVᴴ
381+
end
357382
if alg isa GPU_QRIteration
358383
isempty(alg.kwargs) ||
359-
throw(ArgumentError("GPU_QRIteration does not accept any keyword arguments"))
360-
_gpu_gesvd!(A, view(S, 1:minmn, 1), U, Vᴴ)
384+
@warn "GPU_QRIteration does not accept any keyword arguments"
385+
_gpu_gesvd_maybe_transpose!(A, view(S, 1:minmn, 1), U, Vᴴ)
361386
elseif alg isa GPU_SVDPolar
362387
_gpu_Xgesvdp!(A, view(S, 1:minmn, 1), U, Vᴴ; alg.kwargs...)
363388
elseif alg isa GPU_Jacobi
364389
_gpu_gesvdj!(A, view(S, 1:minmn, 1), U, Vᴴ; alg.kwargs...)
365-
# elseif alg isa LAPACK_Bisection
366-
# throw(ArgumentError("LAPACK_Bisection is not supported for full SVD"))
367-
# elseif alg isa LAPACK_Jacobi
368-
# throw(ArgumentError("LAPACK_Bisection is not supported for full SVD"))
369390
else
370391
throw(ArgumentError("Unsupported SVD algorithm"))
371392
end
@@ -390,13 +411,13 @@ function svd_trunc!(A::AbstractMatrix, USVᴴ, alg::TruncatedAlgorithm{<:GPU_Ran
390411
return USVᴴtrunc..., ϵ
391412
end
392413

393-
function MatrixAlgebraKit.svd_compact!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgorithm)
414+
function svd_compact!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgorithm)
394415
check_input(svd_compact!, A, USVᴴ, alg)
395416
U, S, Vᴴ = USVᴴ
396417
if alg isa GPU_QRIteration
397418
isempty(alg.kwargs) ||
398-
throw(ArgumentError("GPU_QRIteration does not accept any keyword arguments"))
399-
_gpu_gesvd!(A, S.diag, U, Vᴴ)
419+
@warn "GPU_QRIteration does not accept any keyword arguments"
420+
_gpu_gesvd_maybe_transpose!(A, S.diag, U, Vᴴ)
400421
elseif alg isa GPU_SVDPolar
401422
_gpu_Xgesvdp!(A, S.diag, U, Vᴴ; alg.kwargs...)
402423
elseif alg isa GPU_Jacobi
@@ -416,8 +437,8 @@ function MatrixAlgebraKit.svd_vals!(A::AbstractMatrix, S, alg::GPU_SVDAlgorithm)
416437
U, Vᴴ = similar(A, (0, 0)), similar(A, (0, 0))
417438
if alg isa GPU_QRIteration
418439
isempty(alg.kwargs) ||
419-
throw(ArgumentError("GPU_QRIteration does not accept any keyword arguments"))
420-
_gpu_gesvd!(A, S, U, Vᴴ)
440+
@warn "GPU_QRIteration does not accept any keyword arguments"
441+
_gpu_gesvd_maybe_transpose!(A, S, U, Vᴴ)
421442
elseif alg isa GPU_SVDPolar
422443
_gpu_Xgesvdp!(A, S, U, Vᴴ; alg.kwargs...)
423444
elseif alg isa GPU_Jacobi

test/amd/svd.jl

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ include(joinpath("..", "utilities.jl"))
1515
k = min(m, n)
1616
algs = (ROCSOLVER_QRIteration(), ROCSOLVER_Jacobi())
1717
@testset "algorithm $alg" for alg in algs
18-
n > m && alg isa ROCSOLVER_QRIteration && continue # not supported
1918
minmn = min(m, n)
2019
A = ROCArray(randn(rng, T, m, n))
2120

@@ -41,6 +40,9 @@ include(joinpath("..", "utilities.jl"))
4140
Sd = svd_vals(A, alg)
4241
@test ROCArray(diagview(S)) ≈ Sd
4342
# ROCArray is necessary because norm of ROCArray view with non-unit step is broken
43+
if alg isa ROCSOLVER_QRIteration
44+
@test_warn "GPU_QRIteration does not accept any keyword arguments" svd_compact!(copy!(Ac, A), (U, S, Vᴴ), ROCSOLVER_QRIteration(; bad = "bad"))
45+
end
4446
end
4547
end
4648
end
@@ -51,7 +53,6 @@ end
5153
@testset "size ($m, $n)" for n in (37, m, 63)
5254
algs = (ROCSOLVER_QRIteration(), ROCSOLVER_Jacobi())
5355
@testset "algorithm $alg" for alg in algs
54-
n > m && alg isa ROCSOLVER_QRIteration && continue # not supported
5556
A = ROCArray(randn(rng, T, m, n))
5657
U, S, Vᴴ = svd_full(A; alg)
5758
@test U isa ROCMatrix{T} && size(U) == (m, m)
@@ -81,6 +82,26 @@ end
8182
@test Sc === Sc2
8283
@test ROCArray(diagview(S)) ≈ Sc
8384
# ROCArray is necessary because norm of ROCArray view with non-unit step is broken
85+
if alg isa ROCSOLVER_QRIteration
86+
@test_warn "GPU_QRIteration does not accept any keyword arguments" svd_full!(copy!(Ac, A), (U, S, Vᴴ), ROCSOLVER_QRIteration(; bad = "bad"))
87+
@test_warn "GPU_QRIteration does not accept any keyword arguments" svd_vals!(copy!(Ac, A), Sc, ROCSOLVER_QRIteration(; bad = "bad"))
88+
end
89+
end
90+
end
91+
@testset "size (0, 0)" begin
92+
algs = (ROCSOLVER_QRIteration(), ROCSOLVER_Jacobi())
93+
@testset "algorithm $alg" for alg in algs
94+
A = ROCArray(randn(rng, T, 0, 0))
95+
U, S, Vᴴ = svd_full(A; alg)
96+
@test U isa ROCMatrix{T} && size(U) == (0, 0)
97+
@test S isa ROCMatrix{real(T)} && size(S) == (0, 0)
98+
@test Vᴴ isa ROCMatrix{T} && size(Vᴴ) == (0, 0)
99+
@test U * S * Vᴴ ≈ A
100+
@test isapproxone(U' * U)
101+
@test isapproxone(U * U')
102+
@test isapproxone(Vᴴ * Vᴴ')
103+
@test isapproxone(Vᴴ' * Vᴴ)
104+
@test all(isposdef, diagview(S))
84105
end
85106
end
86107
end

test/cuda/svd.jl

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ include(joinpath("..", "utilities.jl"))
1515
k = min(m, n)
1616
algs = (CUSOLVER_QRIteration(), CUSOLVER_SVDPolar(), CUSOLVER_Jacobi())
1717
@testset "algorithm $alg" for alg in algs
18-
n > m && alg isa CUSOLVER_QRIteration && continue # not supported
1918
minmn = min(m, n)
2019
A = CuArray(randn(rng, T, m, n))
2120

@@ -41,6 +40,9 @@ include(joinpath("..", "utilities.jl"))
4140
Sd = svd_vals(A, alg)
4241
@test CuArray(diagview(S)) ≈ Sd
4342
# CuArray is necessary because norm of CuArray view with non-unit step is broken
43+
if alg isa CUSOLVER_QRIteration
44+
@test_warn "GPU_QRIteration does not accept any keyword arguments" svd_compact!(copy!(Ac, A), (U, S, Vᴴ), CUSOLVER_QRIteration(; bad = "bad"))
45+
end
4446
end
4547
end
4648
end
@@ -51,7 +53,6 @@ end
5153
@testset "size ($m, $n)" for n in (37, m, 63)
5254
algs = (CUSOLVER_QRIteration(), CUSOLVER_SVDPolar(), CUSOLVER_Jacobi())
5355
@testset "algorithm $alg" for alg in algs
54-
n > m && alg isa CUSOLVER_QRIteration && continue # not supported
5556
A = CuArray(randn(rng, T, m, n))
5657
U, S, Vᴴ = svd_full(A; alg)
5758
@test U isa CuMatrix{T} && size(U) == (m, m)
@@ -82,8 +83,26 @@ end
8283
@test Sc === Sc2
8384
@test CuArray(diagview(S)) ≈ Sc
8485
# CuArray is necessary because norm of CuArray view with non-unit step is broken
86+
if alg isa CUSOLVER_QRIteration
87+
@test_warn "GPU_QRIteration does not accept any keyword arguments" svd_full!(copy!(Ac, A), (U, S, Vᴴ), CUSOLVER_QRIteration(; bad = "bad"))
88+
@test_warn "GPU_QRIteration does not accept any keyword arguments" svd_vals!(copy!(Ac, A), Sc, CUSOLVER_QRIteration(; bad = "bad"))
89+
end
8590
end
91+
end
92+
@testset "size (0, 0)" begin
93+
algs = (CUSOLVER_QRIteration(), CUSOLVER_SVDPolar(), CUSOLVER_Jacobi())
8694
@testset "algorithm $alg" for alg in algs
95+
A = CuArray(randn(rng, T, 0, 0))
96+
U, S, Vᴴ = svd_full(A; alg)
97+
@test U isa CuMatrix{T} && size(U) == (0, 0)
98+
@test S isa CuMatrix{real(T)} && size(S) == (0, 0)
99+
@test Vᴴ isa CuMatrix{T} && size(Vᴴ) == (0, 0)
100+
@test U * S * Vᴴ ≈ A
101+
@test isapproxone(U' * U)
102+
@test isapproxone(U * U')
103+
@test isapproxone(Vᴴ * Vᴴ')
104+
@test isapproxone(Vᴴ' * Vᴴ)
105+
@test all(isposdef, diagview(S))
87106
end
88107
end
89108
end
@@ -96,7 +115,6 @@ end
96115
p = min(m, n) - k - 1
97116
algs = (CUSOLVER_QRIteration(), CUSOLVER_SVDPolar(), CUSOLVER_Jacobi(), CUSOLVER_Randomized(; k = k, p = p, niters = 100))
98117
@testset "algorithm $alg" for alg in algs
99-
n > m && alg isa CUSOLVER_QRIteration && continue # not supported
100118
hA = randn(rng, T, m, n)
101119
S₀ = svd_vals(hA)
102120
A = CuArray(hA)

test/svd.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,20 @@ end
9292
@test diagview(S) ≈ Sc
9393
end
9494
end
95+
@testset "size (0, 0)" begin
96+
@testset "algorithm $alg" for alg in
97+
(LAPACK_DivideAndConquer(), LAPACK_QRIteration())
98+
A = randn(rng, T, 0, 0)
99+
U, S, Vᴴ = svd_full(A; alg)
100+
@test U isa Matrix{T} && size(U) == (0, 0)
101+
@test S isa Matrix{real(T)} && size(S) == (0, 0)
102+
@test Vᴴ isa Matrix{T} && size(Vᴴ) == (0, 0)
103+
@test U * S * Vᴴ ≈ A
104+
@test isunitary(U)
105+
@test isunitary(Vᴴ)
106+
@test all(isposdef, diagview(S))
107+
end
108+
end
95109
end
96110

97111
@testset "svd_trunc! for T = $T" for T in BLASFloats

0 commit comments

Comments
 (0)