Skip to content

Commit b4b5176

Browse files
committed
Support gesvd via transpose and passing orthnull tests
1 parent 73ed875 commit b4b5176

5 files changed

Lines changed: 55 additions & 20 deletions

File tree

ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ function MatrixAlgebraKit.default_lq_algorithm(::Type{T}; kwargs...) where {T<:S
2222
return LQViaTransposedQR(qr_alg)
2323
end
2424
function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T<:StridedROCMatrix}
25-
return ROCSOLVER_Jacobi(; kwargs...)
25+
return ROCSOLVER_QRIteration(; kwargs...)
2626
end
2727
function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {T<:StridedROCMatrix}
2828
return ROCSOLVER_DivideAndConquer(; kwargs...)

ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ function MatrixAlgebraKit.default_lq_algorithm(::Type{T}; kwargs...) where {TT<:
2222
return LQViaTransposedQR(qr_alg)
2323
end
2424
function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {TT<:BlasFloat, T<:StridedCuMatrix{TT}}
25-
return CUSOLVER_Jacobi(; kwargs...)
25+
return CUSOLVER_QRIteration(; kwargs...)
2626
end
2727
function MatrixAlgebraKit.default_eig_algorithm(::Type{T}; kwargs...) where {TT<:BlasFloat, T<:StridedCuMatrix{TT}}
2828
return CUSOLVER_Simple(; kwargs...)

src/implementations/svd.jl

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,26 @@ _gpu_gesvd!(A::AbstractMatrix, S::AbstractVector, U::AbstractMatrix, Vᴴ::Abstr
245245
_gpu_Xgesvdp!(A::AbstractMatrix, S::AbstractVector, U::AbstractMatrix, Vᴴ::AbstractMatrix; kwargs...) = throw(MethodError(_gpu_Xgesvdp!, (A, S, U, Vᴴ)))
246246
_gpu_Xgesvdr!(A::AbstractMatrix, S::AbstractVector, U::AbstractMatrix, Vᴴ::AbstractMatrix; kwargs...) = throw(MethodError(_gpu_Xgesvdr!, (A, S, U, Vᴴ)))
247247
_gpu_gesvdj!(A::AbstractMatrix, S::AbstractVector, U::AbstractMatrix, Vᴴ::AbstractMatrix; kwargs...) = throw(MethodError(_gpu_gesvdj!, (A, S, U, Vᴴ)))
248+
function _gpu_gesvd_maybe_transpose!(A::AbstractMatrix, S::AbstractVector, U::AbstractMatrix, Vᴴ::AbstractMatrix)
249+
m, n = size(A)
250+
m n && return _gpu_gesvd!(A, S, U, Vᴴ)
251+
# both CUSOLVER and ROCSOLVER require m ≥ n for gesvd (QR_Iteration)
252+
# if this condition is not met, do the SVD via adjoint
253+
minmn = min(m, n)
254+
At = min(m, n) > 0 ? adjoint!(similar(A'), A)::AbstractMatrix : similar(A')
255+
Ut = similar(U')
256+
Vᴴt = similar(Vᴴ')
257+
if size(U) == (m, m)
258+
_gpu_gesvd!(At, view(S, 1:minmn, 1), Vᴴt, Ut)
259+
else
260+
_gpu_gesvd!(At, S, Vᴴt, Ut)
261+
end
262+
length(U) > 0 ? adjoint!(U, Ut) : one!(U)
263+
length(Vᴴ) > 0 ? adjoint!(Vᴴ, Vᴴt) : one!(Vᴴ)
264+
conj!(S)
265+
return U, S, Vᴴ
266+
end
267+
248268
# GPU SVD implementation
249269
function svd_full!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgorithm)
250270
check_input(svd_full!, A, USVᴴ, alg)
@@ -260,8 +280,8 @@ function svd_full!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgorithm)
260280
end
261281
if alg isa GPU_QRIteration
262282
isempty(alg.kwargs) ||
263-
throw(ArgumentError("GPU_QRIteration does not accept any keyword arguments"))
264-
_gpu_gesvd!(A, view(S, 1:minmn, 1), U, Vᴴ)
283+
@warn "GPU_QRIteration does not accept any keyword arguments"
284+
_gpu_gesvd_maybe_transpose!(A, view(S, 1:minmn, 1), U, Vᴴ)
265285
elseif alg isa GPU_SVDPolar
266286
_gpu_Xgesvdp!(A, view(S, 1:minmn, 1), U, Vᴴ; alg.kwargs...)
267287
elseif alg isa GPU_Jacobi
@@ -295,7 +315,7 @@ function svd_compact!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgorithm)
295315
if alg isa GPU_QRIteration
296316
isempty(alg.kwargs) ||
297317
@warn "GPU_QRIteration does not accept any keyword arguments"
298-
_gpu_gesvd!(A, S.diag, U, Vᴴ)
318+
_gpu_gesvd_maybe_transpose!(A, S.diag, U, Vᴴ)
299319
elseif alg isa GPU_SVDPolar
300320
_gpu_Xgesvdp!(A, S.diag, U, Vᴴ; alg.kwargs...)
301321
elseif alg isa GPU_Jacobi
@@ -315,8 +335,8 @@ function svd_vals!(A::AbstractMatrix, S, alg::GPU_SVDAlgorithm)
315335
U, Vᴴ = similar(A, (0, 0)), similar(A, (0, 0))
316336
if alg isa GPU_QRIteration
317337
isempty(alg.kwargs) ||
318-
throw(ArgumentError("GPU_QRIteration does not accept any keyword arguments"))
319-
_gpu_gesvd!(A, S, U, Vᴴ)
338+
@warn "GPU_QRIteration does not accept any keyword arguments"
339+
_gpu_gesvd_maybe_transpose!(A, S, U, Vᴴ)
320340
elseif alg isa GPU_SVDPolar
321341
_gpu_Xgesvdp!(A, S, U, Vᴴ; alg.kwargs...)
322342
elseif alg isa GPU_Jacobi

test/cuda/orthnull.jl

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ using MatrixAlgebraKit
22
using Test
33
using TestExtras
44
using StableRNGs
5-
using LinearAlgebra: LinearAlgebra, I, mul!
5+
using LinearAlgebra: LinearAlgebra, I, mul!, diagm, norm
66
using MatrixAlgebraKit: TruncationKeepAbove, TruncationKeepBelow
77
using MatrixAlgebraKit: GPU_SVDAlgorithm, check_input, copy_input, default_svd_algorithm,
88
initialize_output, AbstractAlgorithm
@@ -64,9 +64,11 @@ end
6464
@test N isa CuMatrix{T} && size(N) == (m, m - minmn)
6565
@test V * C A
6666
@test isisometry(V)
67-
@test LinearAlgebra.norm(A' * N) 0 atol = MatrixAlgebraKit.defaulttol(T)
67+
@test norm(A' * N) 0 atol = MatrixAlgebraKit.defaulttol(T)
6868
@test isisometry(N)
69-
@test V * V' + N * N' I atol = 100 * MatrixAlgebraKit.defaulttol(T)
69+
hV = collect(V)
70+
hN = collect(N)
71+
@test hV * hV' + hN * hN' I
7072

7173
M = LinearMap(A)
7274
VM, CM = @constinferred left_orth(M; kind=:svd)
@@ -94,9 +96,12 @@ end
9496
@test N isa CuMatrix{T} && size(N) == (m, m - minmn)
9597
@test V * C A
9698
@test isisometry(V)
97-
@test LinearAlgebra.norm(A' * N) 0 atol = MatrixAlgebraKit.defaulttol(T)
99+
@test norm(A' * N) 0 atol = MatrixAlgebraKit.defaulttol(T)
98100
@test isisometry(N)
99-
@test V * V' + N * N' I atol = MatrixAlgebraKit.defaulttol(T)
101+
#@test norm(V * V' + N * N' - CuArray(diagm(ones(T, m)))) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T)
102+
hV = collect(V)
103+
hN = collect(N)
104+
@test hV * hV' + hN * hN' I
100105
end
101106

102107
Ac = similar(A)
@@ -109,7 +114,9 @@ end
109114
@test isisometry(V2)
110115
@test LinearAlgebra.norm(A' * N2) 0 atol = MatrixAlgebraKit.defaulttol(T)
111116
@test isisometry(N2)
112-
@test V2 * V2' + N2 * N2' I atol = MatrixAlgebraKit.defaulttol(T)
117+
hV2 = collect(V2)
118+
hN2 = collect(N2)
119+
@test hV2 * hV2' + hN2 * hN2' I
113120

114121
atol = eps(real(T))
115122
#V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); trunc=(; atol=atol))
@@ -150,7 +157,9 @@ end
150157
@test N2 === N
151158
@test LinearAlgebra.norm(A' * N2) 0 atol = MatrixAlgebraKit.defaulttol(T)
152159
@test isisometry(N2)
153-
@test V2 * V2' + N2 * N2' I atol = MatrixAlgebraKit.defaulttol(T)
160+
hV2 = collect(V2)
161+
hN2 = collect(N2)
162+
@test hV2 * hV2' + hN2 * hN2' I
154163
end
155164

156165
# with kind and tol kwargs
@@ -210,7 +219,9 @@ end
210219
@test isisometry(Vᴴ; side=:right)
211220
@test LinearAlgebra.norm(A * adjoint(Nᴴ)) 0 atol = MatrixAlgebraKit.defaulttol(T)
212221
@test isisometry(Nᴴ; side=:right)
213-
@test Vᴴ' * Vᴴ + Nᴴ' * Nᴴ I atol = MatrixAlgebraKit.defaulttol(T)
222+
hVᴴ = collect(Vᴴ)
223+
hNᴴ = collect(Nᴴ)
224+
@test hVᴴ' * hVᴴ + hNᴴ' * hNᴴ I
214225

215226
M = LinearMap(A)
216227
CM, VMᴴ = @constinferred right_orth(M; kind=:svd)
@@ -226,7 +237,9 @@ end
226237
@test isisometry(Vᴴ2; side=:right)
227238
@test LinearAlgebra.norm(A * adjoint(Nᴴ2)) 0 atol = MatrixAlgebraKit.defaulttol(T)
228239
@test isisometry(Nᴴ; side=:right)
229-
@test Vᴴ2' * Vᴴ2 + Nᴴ2' * Nᴴ2 I atol = MatrixAlgebraKit.defaulttol(T)
240+
hVᴴ2 = collect(Vᴴ2)
241+
hNᴴ2 = collect(Nᴴ2)
242+
@test hVᴴ2' * hVᴴ2 + hNᴴ2' * hNᴴ2 I
230243

231244
# TODO truncate currently broken due to searchsortedlast
232245
atol = eps(real(T))
@@ -266,7 +279,9 @@ end
266279
@test Nᴴ2 === Nᴴ
267280
@test LinearAlgebra.norm(A * adjoint(Nᴴ2)) 0 atol = MatrixAlgebraKit.defaulttol(T)
268281
@test isisometry(Nᴴ2; side=:right)
269-
@test Vᴴ2' * Vᴴ2 + Nᴴ2' * Nᴴ2 I atol = 100 * MatrixAlgebraKit.defaulttol(T)
282+
hVᴴ2 = collect(Vᴴ2)
283+
hNᴴ2 = collect(Nᴴ2)
284+
@test hVᴴ2' * hVᴴ2 + hNᴴ2' * hNᴴ2 I
270285
end
271286

272287
if kind == :svd

test/cuda/svd.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ include(joinpath("..", "utilities.jl"))
1515
k = min(m, n)
1616
algs = (CUSOLVER_QRIteration(), CUSOLVER_SVDPolar(), CUSOLVER_Jacobi())
1717
@testset "algorithm $alg" for alg in algs
18-
n > m && alg isa CUSOLVER_QRIteration && continue # not supported
18+
#n > m && alg isa CUSOLVER_QRIteration && continue # not supported
1919
minmn = min(m, n)
2020
A = CuArray(randn(rng, T, m, n))
2121

@@ -51,7 +51,7 @@ end
5151
@testset "size ($m, $n)" for n in (37, m, 63)
5252
algs = (CUSOLVER_QRIteration(), CUSOLVER_SVDPolar(), CUSOLVER_Jacobi())
5353
@testset "algorithm $alg" for alg in algs
54-
n > m && alg isa CUSOLVER_QRIteration && continue # not supported
54+
#n > m && alg isa CUSOLVER_QRIteration && continue # not supported
5555
A = CuArray(randn(rng, T, m, n))
5656
U, S, Vᴴ = svd_full(A; alg)
5757
@test U isa CuMatrix{T} && size(U) == (m, m)
@@ -96,7 +96,7 @@ end
9696
p = min(m, n) - k - 1
9797
algs = (CUSOLVER_QRIteration(), CUSOLVER_SVDPolar(), CUSOLVER_Jacobi(), CUSOLVER_Randomized(; k=k, p=p, niters=100),)
9898
@testset "algorithm $alg" for alg in algs
99-
n > m && alg isa CUSOLVER_QRIteration && continue # not supported
99+
#n > m && alg isa CUSOLVER_QRIteration && continue # not supported
100100
hA = randn(rng, T, m, n)
101101
S₀ = svd_vals(hA)
102102
A = CuArray(hA)

0 commit comments

Comments
 (0)