Skip to content

Commit 3ced35f

Browse files
committed
centralize SVD via adjoint implementation
1 parent 1aa703d commit 3ced35f

3 files changed

Lines changed: 36 additions & 31 deletions

File tree

ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -31,23 +31,14 @@ end
3131
function gesvd!(::ROCSOLVER, A::StridedROCMatrix, S::StridedROCVector, U::StridedROCMatrix, Vᴴ::StridedROCMatrix; kwargs...)
3232
m, n = size(A)
3333
m >= n && return YArocSOLVER.gesvd!(A, S, U, Vᴴ)
34-
# ROCSOLVER requires m ≥ n; compute SVD via adjoint when m < n
35-
minmn = min(m, n)
36-
Aᴴ = minmn > 0 ? adjoint!(similar(A'), A)::AbstractMatrix : similar(A')
37-
Uᴴ = similar(U')
38-
V = similar(Vᴴ')
39-
if size(U) == (m, m)
40-
YArocSOLVER.gesvd!(Aᴴ, view(S, 1:minmn, 1), V, Uᴴ)
41-
else
42-
YArocSOLVER.gesvd!(Aᴴ, S, V, Uᴴ)
43-
end
44-
length(U) > 0 && adjoint!(U, Uᴴ)
45-
length(Vᴴ) > 0 && adjoint!(Vᴴ, V)
46-
return S, U, Vᴴ
34+
return MatrixAlgebraKit.svd_via_adjoint!(gesvd!, ROCSOLVER(), A, S, U, Vᴴ; kwargs...)
4735
end
4836

49-
gesvdj!(::ROCSOLVER, A::StridedROCMatrix, S::StridedROCVector, U::StridedROCMatrix, Vᴴ::StridedROCMatrix; kwargs...) =
50-
YArocSOLVER.gesvdj!(A, S, U, Vᴴ; kwargs...)
37+
function gesvdj!(::ROCSOLVER, A::StridedROCMatrix, S::StridedROCVector, U::StridedROCMatrix, Vᴴ::StridedROCMatrix; kwargs...)
38+
m, n = size(A)
39+
m >= n && return YArocSOLVER.gesvdj!(A, S, U, Vᴴ; kwargs...)
40+
return MatrixAlgebraKit.svd_via_adjoint!(gesvdj!, ROCSOLVER(), A, S, U, Vᴴ; kwargs...)
41+
end
5142
_gpu_heevj!(A::StridedROCMatrix, Dd::StridedROCVector, V::StridedROCMatrix; kwargs...) =
5243
YArocSOLVER.heevj!(A, Dd, V; kwargs...)
5344
_gpu_heevd!(A::StridedROCMatrix, Dd::StridedROCVector, V::StridedROCMatrix; kwargs...) =

ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -36,23 +36,14 @@ end
3636
function gesvd!(::CUSOLVER, A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix; kwargs...)
3737
m, n = size(A)
3838
m >= n && return YACUSOLVER.gesvd!(A, S, U, Vᴴ)
39-
# CUSOLVER requires m ≥ n; compute SVD via adjoint when m < n
40-
minmn = min(m, n)
41-
Aᴴ = minmn > 0 ? adjoint!(similar(A'), A)::AbstractMatrix : similar(A')
42-
Uᴴ = similar(U')
43-
V = similar(Vᴴ')
44-
if size(U) == (m, m)
45-
YACUSOLVER.gesvd!(Aᴴ, view(S, 1:minmn, 1), V, Uᴴ)
46-
else
47-
YACUSOLVER.gesvd!(Aᴴ, S, V, Uᴴ)
48-
end
49-
length(U) > 0 && adjoint!(U, Uᴴ)
50-
length(Vᴴ) > 0 && adjoint!(Vᴴ, V)
51-
return S, U, Vᴴ
39+
return MatrixAlgebraKit.svd_via_adjoint!(gesvd!, CUSOLVER(), A, S, U, Vᴴ; kwargs...)
5240
end
5341

54-
gesvdj!(::CUSOLVER, A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix; kwargs...) =
55-
YACUSOLVER.gesvdj!(A, S, U, Vᴴ; kwargs...)
42+
function gesvdj!(::CUSOLVER, A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix; kwargs...)
43+
m, n = size(A)
44+
m >= n && return YACUSOLVER.gesvdj!(A, S, U, Vᴴ; kwargs...)
45+
return MatrixAlgebraKit.svd_via_adjoint!(gesvdj!, CUSOLVER(), A, S, U, Vᴴ; kwargs...)
46+
end
5647

5748
gesvdp!(::CUSOLVER, A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix; kwargs...) =
5849
YACUSOLVER.gesvdp!(A, S, U, Vᴴ; kwargs...)

src/implementations/svd.jl

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,11 +117,34 @@ for f! in (:gesdd!, :gesvd!, :gesvdj!, :gesvdp!, :gesvdx!, :gesvdr!, :gesdvd!)
117117
@eval $f!(driver::Driver, args...) = throw(ArgumentError("$driver does not provide $f!"))
118118
end
119119

120+
"""
121+
svd_via_adjoint!(f!, driver, A, S, U, Vᴴ; kwargs...)
122+
123+
Compute the SVD of `A` (m × n, m < n) by computing the SVD of `adjoint(A)` using
124+
the provided function `f!(driver, A, S, U, Vᴴ; kwargs...)`. Use this as a building
125+
block for drivers whose SVD routines require m ≥ n.
126+
"""
127+
function svd_via_adjoint!(f!::F, driver::Driver, A, S, U, Vᴴ; kwargs...) where {F}
128+
Aᴴ = adjoint!(similar(A'), A)
129+
Uᴴ = similar(U')
130+
V = similar(Vᴴ')
131+
f!(driver, Aᴴ, S, V, Uᴴ; kwargs...)
132+
length(U) > 0 && adjoint!(U, Uᴴ)
133+
length(Vᴴ) > 0 && adjoint!(Vᴴ, V)
134+
return S, U, Vᴴ
135+
end
136+
120137
# LAPACK
121-
for f! in (:gesdd!, :gesvd!, :gesvdj!, :gesvdx!, :gesdvd!)
138+
for f! in (:gesdd!, :gesvd!, :gesvdx!, :gesdvd!)
122139
@eval $f!(::LAPACK, args...; kwargs...) = YALAPACK.$f!(args...; kwargs...)
123140
end
124141

142+
function gesvdj!(::LAPACK, A, S, U, Vᴴ; kwargs...)
143+
m, n = size(A)
144+
m >= n && return YALAPACK.gesvdj!(A, S, U, Vᴴ)
145+
return svd_via_adjoint!(gesvdj!, LAPACK(), A, S, U, Vᴴ; kwargs...)
146+
end
147+
125148
for (f, f_lapack!, Alg) in (
126149
(:safe_divide_and_conquer, :gesdvd!, :SafeDivideAndConquer),
127150
(:divide_and_conquer, :gesdd!, :DivideAndConquer),

0 commit comments

Comments
 (0)