Skip to content

Commit 1aa703d

Browse files
committed
incorporate changes for GPU and GLA
1 parent 1c7e84f commit 1aa703d

7 files changed

Lines changed: 152 additions & 181 deletions

File tree

ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,32 +6,47 @@ using MatrixAlgebraKit: one!, zero!, uppertriangular!, lowertriangular!
66
using MatrixAlgebraKit: diagview, sign_safe
77
using MatrixAlgebraKit: ROCSOLVER, LQViaTransposedQR, TruncationStrategy, NoTruncation, TruncationByValue, AbstractAlgorithm
88
using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eigh_algorithm
9-
import MatrixAlgebraKit: geqrf!, ungqr!, unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_gesvdj!
9+
import MatrixAlgebraKit: geqrf!, ungqr!, unmqr!, gesvd!, gesvdj!
1010
import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd!, _gpu_heev!, _gpu_heevx!, _sylvester, svd_rank
1111
using AMDGPU
1212
using LinearAlgebra
1313
using LinearAlgebra: BlasFloat
1414

1515
include("yarocsolver.jl")
1616

17-
MatrixAlgebraKit.default_householder_driver(::Type{A}) where {A <: StridedROCMatrix{<:BlasFloat}} = ROCSOLVER()
18-
function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T <: StridedROCMatrix}
19-
return ROCSOLVER_QRIteration(; kwargs...)
17+
# Default driver selection for strided AMD GPU vectors/matrices: all three driver
# queries resolve to `ROCSOLVER()`. Only the Householder method restricts the element
# type to `BlasFloat`; the other two accept any `StridedROCVecOrMat`.
function MatrixAlgebraKit.default_householder_driver(::Type{<:StridedROCVecOrMat{<:BlasFloat}})
    return ROCSOLVER()
end
function MatrixAlgebraKit.default_qr_iteration_driver(::Type{T}) where {T <: StridedROCVecOrMat}
    return ROCSOLVER()
end
function MatrixAlgebraKit.default_jacobi_driver(::Type{T}) where {T <: StridedROCVecOrMat}
    return ROCSOLVER()
end
20+
# Default SVD algorithm for strided AMD GPU arrays: a `QRIteration` built from the
# caller's keyword arguments.
MatrixAlgebraKit.default_svd_algorithm(::Type{<:StridedROCVecOrMat}; kwargs...) =
    QRIteration(; kwargs...)
21-
function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {T <: StridedROCMatrix}
23+
# Default Hermitian eigensolver for strided AMD GPU arrays: a
# `ROCSOLVER_DivideAndConquer` built from the caller's keyword arguments.
MatrixAlgebraKit.default_eigh_algorithm(::Type{<:StridedROCVecOrMat}; kwargs...) =
    ROCSOLVER_DivideAndConquer(; kwargs...)
2426

2527
# Generate one forwarding method per Householder primitive: a call tagged with the
# `ROCSOLVER` driver drops the tag and hands the remaining positional arguments to
# the same-named function in `YArocSOLVER`.
for driver_fn in (:geqrf!, :ungqr!, :unmqr!)
    @eval function $driver_fn(::ROCSOLVER, args...)
        return YArocSOLVER.$driver_fn(args...)
    end
end
2830

29-
_gpu_gesvd!(A::StridedROCMatrix, S::StridedROCVector, U::StridedROCMatrix, Vᴴ::StridedROCMatrix) =
30-
YArocSOLVER.gesvd!(A, S, U, Vᴴ)
31-
# not yet supported
32-
# _gpu_Xgesvdp!(A::StridedROCMatrix, S::StridedROCVector, U::StridedROCMatrix, Vᴴ::StridedROCMatrix; kwargs...) =
33-
# YArocSOLVER.Xgesvdp!(A, S, U, Vᴴ; kwargs...)
34-
_gpu_gesvdj!(A::StridedROCMatrix, S::StridedROCVector, U::StridedROCMatrix, Vᴴ::StridedROCMatrix; kwargs...) =
31+
# SVD entry point for the `ROCSOLVER` driver on strided AMD GPU arrays.
#
# `A` is the input matrix; `S`, `U` and `Vᴴ` are caller-provided output buffers.
# `kwargs` are accepted for interface compatibility but are not used here.
#
# For m ≥ n the call is forwarded unchanged and the result of `YArocSOLVER.gesvd!` is
# returned as-is. For m < n the decomposition is computed on the adjoint of `A` and the
# factors are copied back, returning `(S, U, Vᴴ)`.
function gesvd!(
        ::ROCSOLVER, A::StridedROCMatrix, S::StridedROCVector,
        U::StridedROCMatrix, Vᴴ::StridedROCMatrix; kwargs...
    )
    nrows, ncols = size(A)
    if nrows >= ncols
        return YArocSOLVER.gesvd!(A, S, U, Vᴴ)
    end

    # ROCSOLVER requires m ≥ n; compute SVD via adjoint when m < n.
    # Since A = U S Vᴴ implies Aᴴ = V S Uᴴ, the two factor buffers swap roles below.
    k = min(nrows, ncols)
    Aᴴ = similar(A')
    if k > 0
        # an empty matrix has nothing to copy, so the adjoint is only filled otherwise
        Aᴴ = adjoint!(Aᴴ, A)::AbstractMatrix
    end
    Uᴴ = similar(U')
    V = similar(Vᴴ')

    # NOTE(review): when `U` is m×m only the leading `k` entries of `S` are handed to
    # the solver — presumably `S` may be longer than `k` in that case; confirm.
    Sbuf = size(U) == (nrows, nrows) ? view(S, 1:k, 1) : S
    YArocSOLVER.gesvd!(Aᴴ, Sbuf, V, Uᴴ)

    # copy the factors back into the caller's buffers, skipping empty ones
    if !isempty(U)
        adjoint!(U, Uᴴ)
    end
    if !isempty(Vᴴ)
        adjoint!(Vᴴ, V)
    end
    return S, U, Vᴴ
end
48+
49+
# Jacobi SVD for the `ROCSOLVER` driver: drops the driver tag and forwards the
# matrices and all keyword arguments to `YArocSOLVER.gesvdj!`.
function gesvdj!(
        ::ROCSOLVER, A::StridedROCMatrix, S::StridedROCVector,
        U::StridedROCMatrix, Vᴴ::StridedROCMatrix; kwargs...
    )
    return YArocSOLVER.gesvdj!(A, S, U, Vᴴ; kwargs...)
end
3651
# Forwards to `YArocSOLVER.heevj!` with the same arguments and keyword arguments.
function _gpu_heevj!(A::StridedROCMatrix, Dd::StridedROCVector, V::StridedROCMatrix; kwargs...)
    return YArocSOLVER.heevj!(A, Dd, V; kwargs...)
end

ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ using MatrixAlgebraKit: diagview, sign_safe
77
using MatrixAlgebraKit: CUSOLVER, LQViaTransposedQR, TruncationByValue, AbstractAlgorithm
88
using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eig_algorithm, default_eigh_algorithm
99
import MatrixAlgebraKit: geqrf!, ungqr!, unmqr!, gesvd!, gesvdp!, gesvdr!, gesvdj!, _gpu_geev!
10-
import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd!, _sylvester, svd_rank
10+
import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd!, _gpu_Xgesvdr!, _sylvester, svd_rank
1111
using CUDA, CUDA.CUBLAS
1212
using CUDA: i32
1313
using LinearAlgebra
@@ -16,8 +16,11 @@ using LinearAlgebra: BlasFloat
1616
include("yacusolver.jl")
1717

1818
# Default driver selection for strided CUDA vectors/matrices with a `BlasFloat`
# element type: every driver query below resolves to `CUSOLVER()`.
function MatrixAlgebraKit.default_householder_driver(::Type{<:StridedCuVecOrMat{<:BlasFloat}})
    return CUSOLVER()
end
function MatrixAlgebraKit.default_qr_iteration_driver(::Type{T}) where {T <: StridedCuVecOrMat{<:BlasFloat}}
    return CUSOLVER()
end
function MatrixAlgebraKit.default_jacobi_driver(::Type{T}) where {T <: StridedCuVecOrMat{<:BlasFloat}}
    return CUSOLVER()
end
function MatrixAlgebraKit.default_svd_polar_driver(::Type{T}) where {T <: StridedCuVecOrMat{<:BlasFloat}}
    return CUSOLVER()
end
1922
function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {TT <: BlasFloat, T <: StridedCuVecOrMat{TT}}
20-
return CUSOLVER_QRIteration(; kwargs...)
23+
return QRIteration(; kwargs...)
2124
end
2225
function MatrixAlgebraKit.default_eig_algorithm(::Type{T}; kwargs...) where {TT <: BlasFloat, T <: StridedCuVecOrMat{TT}}
2326
return CUSOLVER_Simple(; kwargs...)
@@ -30,6 +33,33 @@ for f in (:geqrf!, :ungqr!, :unmqr!)
3033
@eval $f(::CUSOLVER, args...) = YACUSOLVER.$f(args...)
3134
end
3235

36+
# SVD entry point for the `CUSOLVER` driver on strided CUDA arrays.
#
# `A` is the input matrix; `S`, `U` and `Vᴴ` are caller-provided output buffers.
# `kwargs` are accepted for interface compatibility but are not used here.
#
# For m ≥ n the call is forwarded unchanged and the result of `YACUSOLVER.gesvd!` is
# returned as-is. For m < n the decomposition is computed on the adjoint of `A` and the
# factors are copied back, returning `(S, U, Vᴴ)`.
function gesvd!(
        ::CUSOLVER, A::StridedCuMatrix, S::StridedCuVector,
        U::StridedCuMatrix, Vᴴ::StridedCuMatrix; kwargs...
    )
    nrows, ncols = size(A)
    if nrows >= ncols
        return YACUSOLVER.gesvd!(A, S, U, Vᴴ)
    end

    # CUSOLVER requires m ≥ n; compute SVD via adjoint when m < n.
    # Since A = U S Vᴴ implies Aᴴ = V S Uᴴ, the two factor buffers swap roles below.
    k = min(nrows, ncols)
    Aᴴ = similar(A')
    if k > 0
        # an empty matrix has nothing to copy, so the adjoint is only filled otherwise
        Aᴴ = adjoint!(Aᴴ, A)::AbstractMatrix
    end
    Uᴴ = similar(U')
    V = similar(Vᴴ')

    # NOTE(review): when `U` is m×m only the leading `k` entries of `S` are handed to
    # the solver — presumably `S` may be longer than `k` in that case; confirm.
    Sbuf = size(U) == (nrows, nrows) ? view(S, 1:k, 1) : S
    YACUSOLVER.gesvd!(Aᴴ, Sbuf, V, Uᴴ)

    # copy the factors back into the caller's buffers, skipping empty ones
    if !isempty(U)
        adjoint!(U, Uᴴ)
    end
    if !isempty(Vᴴ)
        adjoint!(Vᴴ, V)
    end
    return S, U, Vᴴ
end
53+
54+
# Jacobi SVD for the `CUSOLVER` driver: drops the driver tag and forwards the
# matrices and all keyword arguments to `YACUSOLVER.gesvdj!`.
function gesvdj!(
        ::CUSOLVER, A::StridedCuMatrix, S::StridedCuVector,
        U::StridedCuMatrix, Vᴴ::StridedCuMatrix; kwargs...
    )
    return YACUSOLVER.gesvdj!(A, S, U, Vᴴ; kwargs...)
end
56+
57+
# `gesvdp!` for the `CUSOLVER` driver: drops the driver tag and forwards the matrices
# and all keyword arguments to `YACUSOLVER.gesvdp!`.
function gesvdp!(
        ::CUSOLVER, A::StridedCuMatrix, S::StridedCuVector,
        U::StridedCuMatrix, Vᴴ::StridedCuMatrix; kwargs...
    )
    return YACUSOLVER.gesvdp!(A, S, U, Vᴴ; kwargs...)
end
59+
60+
# Forwards to `YACUSOLVER.gesvdr!` with the same arguments and keyword arguments.
function _gpu_Xgesvdr!(
        A::StridedCuMatrix, S::StridedCuVector,
        U::StridedCuMatrix, Vᴴ::StridedCuMatrix; kwargs...
    )
    return YACUSOLVER.gesvdr!(A, S, U, Vᴴ; kwargs...)
end
62+
3363
# Forwards to `YACUSOLVER.Xgeev!` with the same positional arguments.
function _gpu_geev!(A::StridedCuMatrix, D::StridedCuVector, V::StridedCuMatrix)
    return YACUSOLVER.Xgeev!(A, D, V)
end
3565

ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl

Lines changed: 20 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -2,42 +2,32 @@ module MatrixAlgebraKitGenericLinearAlgebraExt
22

33
using MatrixAlgebraKit
44
using MatrixAlgebraKit: sign_safe, check_input, diagview, gaugefix!, one!, zero!, default_fixgauge
5+
using MatrixAlgebraKit: GLA
6+
import MatrixAlgebraKit: gesvd!
57
using GenericLinearAlgebra: svd!, svdvals!, eigen!, eigvals!, Hermitian, qr!
68
using LinearAlgebra: I, Diagonal, lmul!
79

8-
function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T <: StridedMatrix{<:Union{BigFloat, Complex{BigFloat}}}}
9-
return GLA_QRIteration()
10-
end
11-
12-
for f! in (:svd_compact!, :svd_full!)
13-
@eval MatrixAlgebraKit.initialize_output(::typeof($f!), A::AbstractMatrix, ::GLA_QRIteration) = (nothing, nothing, nothing)
14-
end
15-
MatrixAlgebraKit.initialize_output(::typeof(svd_vals!), A::AbstractMatrix, ::GLA_QRIteration) = nothing
10+
# Strided matrices of `BigFloat` / `Complex{BigFloat}` use the `GLA` driver for QR iteration.
function MatrixAlgebraKit.default_qr_iteration_driver(::Type{T}) where {T <: StridedMatrix{<:Union{BigFloat, Complex{BigFloat}}}}
    return GLA()
end
1611

17-
function MatrixAlgebraKit.svd_compact!(A::AbstractMatrix, USVᴴ, alg::GLA_QRIteration)
18-
F = svd!(A)
19-
U, S, Vᴴ = F.U, Diagonal(F.S), F.Vt
20-
21-
do_gauge_fix = get(alg.kwargs, :fixgauge, default_fixgauge())::Bool
22-
do_gauge_fix && gaugefix!(svd_compact!, U, Vᴴ)
23-
24-
return U, S, Vᴴ
25-
end
26-
27-
function MatrixAlgebraKit.svd_full!(A::AbstractMatrix, USVᴴ, alg::GLA_QRIteration)
28-
F = svd!(A; full = true)
29-
U, Vᴴ = F.U, F.Vt
30-
S = MatrixAlgebraKit.zero!(similar(F.S, (size(U, 2), size(Vᴴ, 1))))
31-
diagview(S) .= F.S
32-
33-
do_gauge_fix = get(alg.kwargs, :fixgauge, default_fixgauge())::Bool
34-
do_gauge_fix && gaugefix!(svd_full!, U, Vᴴ)
35-
36-
return U, S, Vᴴ
12+
# Default SVD algorithm for strided `BigFloat` / `Complex{BigFloat}` matrices: a
# `QRIteration` built from the caller's keyword arguments.
MatrixAlgebraKit.default_svd_algorithm(::Type{<:StridedMatrix{<:Union{BigFloat, Complex{BigFloat}}}}; kwargs...) =
    QRIteration(; kwargs...)
3815

39-
function MatrixAlgebraKit.svd_vals!(A::AbstractMatrix, S, ::GLA_QRIteration)
40-
return svdvals!(A)
16+
# SVD entry point for the `GLA` driver, using the `svd!` / `svdvals!` imported from
# GenericLinearAlgebra.
#
# `A` is the input matrix (passed to the in-place solvers); `S`, `U` and `Vᴴ` are
# caller-provided output buffers and are returned as `(S, U, Vᴴ)`. An empty `U` or `Vᴴ`
# means that factor is not requested. `kwargs` are accepted but not used here.
function gesvd!(
        ::GLA, A::AbstractMatrix, S::AbstractVector,
        U::AbstractMatrix, Vᴴ::AbstractMatrix; kwargs...
    )
    wantU = !isempty(U)
    wantVᴴ = !isempty(Vᴴ)

    # neither factor requested: only the singular values are computed
    if !(wantU || wantVᴴ)
        copyto!(S, svdvals!(A))
        return S, U, Vᴴ
    end

    # full SVD if U has more columns or Vᴴ has more rows than the compact min(m, n)
    k = min(size(A)...)
    needfull = (wantU && size(U, 2) > k) || (wantVᴴ && size(Vᴴ, 1) > k)
    F = svd!(A; full = needfull)

    # copy each result only into a non-empty output buffer
    isempty(S) || copyto!(S, F.S)
    wantU && copyto!(U, F.U)
    wantVᴴ && copyto!(Vᴴ, F.Vt)
    return S, U, Vᴴ
end
4232

4333
function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {T <: StridedMatrix{<:Union{BigFloat, Complex{BigFloat}}}}

src/MatrixAlgebraKit.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ export left_orth, right_orth, left_null, right_null
3232
export left_orth!, right_orth!, left_null!, right_null!
3333

3434
export Householder, Native_HouseholderQR, Native_HouseholderLQ
35+
export DivideAndConquer, SafeDivideAndConquer, QRIteration, Bisection, Jacobi, SVDPolar
3536
export LAPACK_HouseholderQR, LAPACK_HouseholderLQ, LAPACK_Simple, LAPACK_Expert,
3637
LAPACK_QRIteration, LAPACK_Bisection, LAPACK_MultipleRelativelyRobustRepresentations,
3738
LAPACK_DivideAndConquer, LAPACK_Jacobi, LAPACK_SafeDivideAndConquer

0 commit comments

Comments
 (0)