@@ -6,32 +6,47 @@ using MatrixAlgebraKit: one!, zero!, uppertriangular!, lowertriangular!
66using MatrixAlgebraKit: diagview, sign_safe
77using MatrixAlgebraKit: ROCSOLVER, LQViaTransposedQR, TruncationStrategy, NoTruncation, TruncationByValue, AbstractAlgorithm
88using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eigh_algorithm
9- import MatrixAlgebraKit: geqrf!, ungqr!, unmqr!, _gpu_gesvd !, _gpu_Xgesvdp!, _gpu_gesvdj !
9+ import MatrixAlgebraKit: geqrf!, ungqr!, unmqr!, gesvd !, gesvdj !
1010import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd!, _gpu_heev!, _gpu_heevx!, _sylvester, svd_rank
1111using AMDGPU
1212using LinearAlgebra
1313using LinearAlgebra: BlasFloat
1414
1515include (" yarocsolver.jl" )
1616
17- MatrixAlgebraKit. default_householder_driver (:: Type{A} ) where {A <: StridedROCMatrix{<:BlasFloat} } = ROCSOLVER ()
18- function MatrixAlgebraKit. default_svd_algorithm (:: Type{T} ; kwargs... ) where {T <: StridedROCMatrix }
19- return ROCSOLVER_QRIteration (; kwargs... )
17+ MatrixAlgebraKit. default_householder_driver (:: Type{A} ) where {A <: StridedROCVecOrMat{<:BlasFloat} } = ROCSOLVER ()
18+ MatrixAlgebraKit. default_qr_iteration_driver (:: Type{<:StridedROCVecOrMat} ) = ROCSOLVER ()
19+ MatrixAlgebraKit. default_jacobi_driver (:: Type{<:StridedROCVecOrMat} ) = ROCSOLVER ()
20+ function MatrixAlgebraKit. default_svd_algorithm (:: Type{T} ; kwargs... ) where {T <: StridedROCVecOrMat }
21+ return QRIteration (; kwargs... )
2022end
21- function MatrixAlgebraKit. default_eigh_algorithm (:: Type{T} ; kwargs... ) where {T <: StridedROCMatrix }
23+ function MatrixAlgebraKit. default_eigh_algorithm (:: Type{T} ; kwargs... ) where {T <: StridedROCVecOrMat }
2224 return ROCSOLVER_DivideAndConquer (; kwargs... )
2325end
2426
2527for f in (:geqrf! , :ungqr! , :unmqr! )
2628 @eval $ f (:: ROCSOLVER , args... ) = YArocSOLVER.$ f (args... )
2729end
2830
29- _gpu_gesvd! (A:: StridedROCMatrix , S:: StridedROCVector , U:: StridedROCMatrix , Vᴴ:: StridedROCMatrix ) =
30- YArocSOLVER. gesvd! (A, S, U, Vᴴ)
31- # not yet supported
32- # _gpu_Xgesvdp!(A::StridedROCMatrix, S::StridedROCVector, U::StridedROCMatrix, Vᴴ::StridedROCMatrix; kwargs...) =
33- # YArocSOLVER.Xgesvdp!(A, S, U, Vᴴ; kwargs...)
34- _gpu_gesvdj! (A:: StridedROCMatrix , S:: StridedROCVector , U:: StridedROCMatrix , Vᴴ:: StridedROCMatrix ; kwargs... ) =
31+ function gesvd! (:: ROCSOLVER , A:: StridedROCMatrix , S:: StridedROCVector , U:: StridedROCMatrix , Vᴴ:: StridedROCMatrix ; kwargs... )
32+ m, n = size (A)
33+ m >= n && return YArocSOLVER. gesvd! (A, S, U, Vᴴ)
34+ # ROCSOLVER requires m ≥ n; compute SVD via adjoint when m < n
35+ minmn = min (m, n)
36+ Aᴴ = minmn > 0 ? adjoint! (similar (A' ), A):: AbstractMatrix : similar (A' )
37+ Uᴴ = similar (U' )
38+ V = similar (Vᴴ' )
39+ if size (U) == (m, m)
40+ YArocSOLVER. gesvd! (Aᴴ, view (S, 1 : minmn, 1 ), V, Uᴴ)
41+ else
42+ YArocSOLVER. gesvd! (Aᴴ, S, V, Uᴴ)
43+ end
44+ length (U) > 0 && adjoint! (U, Uᴴ)
45+ length (Vᴴ) > 0 && adjoint! (Vᴴ, V)
46+ return S, U, Vᴴ
47+ end
48+
49+ gesvdj! (:: ROCSOLVER , A:: StridedROCMatrix , S:: StridedROCVector , U:: StridedROCMatrix , Vᴴ:: StridedROCMatrix ; kwargs... ) =
3550 YArocSOLVER. gesvdj! (A, S, U, Vᴴ; kwargs... )
3651_gpu_heevj! (A:: StridedROCMatrix , Dd:: StridedROCVector , V:: StridedROCMatrix ; kwargs... ) =
3752 YArocSOLVER. heevj! (A, Dd, V; kwargs... )
0 commit comments