|
| 1 | +module MatrixAlgebraKitAMDGPUExt |
| 2 | + |
| 3 | +using MatrixAlgebraKit |
| 4 | +using MatrixAlgebraKit: @algdef, Algorithm, check_input |
| 5 | +using MatrixAlgebraKit: one!, zero!, uppertriangular!, lowertriangular! |
| 6 | +using MatrixAlgebraKit: diagview, sign_safe |
| 7 | +using MatrixAlgebraKit: LQViaTransposedQR |
| 8 | +using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm |
| 9 | +import MatrixAlgebraKit: _gpu_geqrf!, _gpu_ungqr!, _gpu_unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_gesvdj! |
| 10 | +using AMDGPU |
| 11 | +using LinearAlgebra |
| 12 | +using LinearAlgebra: BlasFloat |
| 13 | + |
| 14 | +include("yarocsolver.jl") |
| 15 | + |
| 16 | +function MatrixAlgebraKit.default_qr_algorithm(::Type{T}; kwargs...) where {T<:StridedROCMatrix} |
| 17 | + return ROCSOLVER_HouseholderQR(; kwargs...) |
| 18 | +end |
| 19 | +function MatrixAlgebraKit.default_lq_algorithm(::Type{T}; kwargs...) where {T<:StridedROCMatrix} |
| 20 | + qr_alg = ROCSOLVER_HouseholderQR(; kwargs...) |
| 21 | + return LQViaTransposedQR(qr_alg) |
| 22 | +end |
| 23 | +function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T<:StridedROCMatrix} |
| 24 | + return ROCSOLVER_QRIteration(; kwargs...) |
| 25 | +end |
| 26 | + |
| 27 | +_gpu_geqrf!(A::StridedROCMatrix) = YArocSOLVER.geqrf!(A) |
| 28 | +_gpu_ungqr!(A::StridedROCMatrix, τ::StridedROCVector) = YArocSOLVER.ungqr!(A, τ) |
| 29 | +_gpu_unmqr!(side::AbstractChar, trans::AbstractChar, A::StridedROCMatrix, τ::StridedROCVector, C::StridedROCVecOrMat) = YArocSOLVER.unmqr!(side, trans, A, τ, C) |
| 30 | +_gpu_gesvd!(A::StridedROCMatrix, S::StridedROCVector, U::StridedROCMatrix, Vᴴ::StridedROCMatrix) = YArocSOLVER.gesvd!(A, S, U, Vᴴ) |
| 31 | +# not yet supported |
| 32 | +#_gpu_Xgesvdp!(A::StridedROCMatrix, S::StridedROCVector, U::StridedROCMatrix, Vᴴ::StridedROCMatrix; kwargs...) = YArocSOLVER.Xgesvdp!(A, S, U, Vᴴ; kwargs...) |
| 33 | +_gpu_gesvdj!(A::StridedROCMatrix, S::StridedROCVector, U::StridedROCMatrix, Vᴴ::StridedROCMatrix; kwargs...) = YArocSOLVER.gesvdj!(A, S, U, Vᴴ; kwargs...) |
| 34 | + |
| 35 | +end |
0 commit comments