diff --git a/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl b/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl
index 0f8c35139..7f51bde93 100644
--- a/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl
+++ b/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl
@@ -5,8 +5,9 @@ using MatrixAlgebraKit: @algdef, Algorithm, check_input
 using MatrixAlgebraKit: one!, zero!, uppertriangular!, lowertriangular!
 using MatrixAlgebraKit: diagview, sign_safe
 using MatrixAlgebraKit: LQViaTransposedQR
-using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm
-import MatrixAlgebraKit: _gpu_geqrf!, _gpu_ungqr!, _gpu_unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_gesvdj!
+using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eigh_algorithm
+import MatrixAlgebraKit: _gpu_geqrf!, _gpu_ungqr!, _gpu_unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_gesvdj!
+import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd!, _gpu_heev!, _gpu_heevx!
 using AMDGPU
 using LinearAlgebra
 using LinearAlgebra: BlasFloat
@@ -23,6 +24,9 @@ end
 function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T<:StridedROCMatrix}
     return ROCSOLVER_QRIteration(; kwargs...)
 end
+function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {T<:StridedROCMatrix}
+    return ROCSOLVER_DivideAndConquer(; kwargs...)  # default Hermitian eigensolver for ROCm matrices
+end
 
 _gpu_geqrf!(A::StridedROCMatrix) = YArocSOLVER.geqrf!(A)
 _gpu_ungqr!(A::StridedROCMatrix, τ::StridedROCVector) = YArocSOLVER.ungqr!(A, τ)
@@ -32,4 +36,8 @@ _gpu_gesvd!(A::StridedROCMatrix, S::StridedROCVector, U::StridedROCMatrix, Vᴴ:
 #_gpu_Xgesvdp!(A::StridedROCMatrix, S::StridedROCVector, U::StridedROCMatrix, Vᴴ::StridedROCMatrix; kwargs...) = YArocSOLVER.Xgesvdp!(A, S, U, Vᴴ; kwargs...)
 _gpu_gesvdj!(A::StridedROCMatrix, S::StridedROCVector, U::StridedROCMatrix, Vᴴ::StridedROCMatrix; kwargs...) = YArocSOLVER.gesvdj!(A, S, U, Vᴴ; kwargs...)
 
+_gpu_heevj!(A::StridedROCMatrix, Dd::StridedROCVector, V::StridedROCMatrix; kwargs...) = YArocSOLVER.heevj!(A, Dd, V; kwargs...)
+_gpu_heevd!(A::StridedROCMatrix, Dd::StridedROCVector, V::StridedROCMatrix; kwargs...) = YArocSOLVER.heevd!(A, Dd, V; kwargs...)
+_gpu_heev!(A::StridedROCMatrix, Dd::StridedROCVector, V::StridedROCMatrix; kwargs...) = YArocSOLVER.heev!(A, Dd, V; kwargs...)
+_gpu_heevx!(A::StridedROCMatrix, Dd::StridedROCVector, V::StridedROCMatrix; kwargs...) = YArocSOLVER.heevx!(A, Dd, V; kwargs...)
 end
diff --git a/ext/MatrixAlgebraKitAMDGPUExt/yarocsolver.jl b/ext/MatrixAlgebraKitAMDGPUExt/yarocsolver.jl
index 029bc0181..3338308b1 100644
--- a/ext/MatrixAlgebraKitAMDGPUExt/yarocsolver.jl
+++ b/ext/MatrixAlgebraKitAMDGPUExt/yarocsolver.jl
@@ -1,7 +1,7 @@
 module YArocSOLVER
 
 using LinearAlgebra
-using LinearAlgebra: BlasInt, BlasFloat, checksquare, chkstride1, require_one_based_indexing
+using LinearAlgebra: BlasInt, BlasReal, BlasFloat, checksquare, chkstride1, require_one_based_indexing
 using LinearAlgebra.LAPACK: chkargsok, chklapackerror, chktrans, chkside, chkdiag, chkuplo
 
 using AMDGPU
@@ -475,42 +475,146 @@ end
 #     return X, info
 # end
 
-# for (jname, bname, fname, elty, relty) in
-#     ((:syevd!, :rocsolverDnSsyevd_bufferSize, :rocsolverDnSsyevd, :Float32, :Float32),
-#      (:syevd!, :rocsolverDnDsyevd_bufferSize, :rocsolverDnDsyevd, :Float64, :Float64),
-#      (:heevd!, :rocsolverDnCheevd_bufferSize, :rocsolverDnCheevd, :ComplexF32, :Float32),
-#      (:heevd!, :rocsolverDnZheevd_bufferSize, :rocsolverDnZheevd, :ComplexF64, :Float64))
-#     @eval begin
-#         function $jname(jobz::Char,
-#                         uplo::Char,
-#                         A::StridedROCMatrix{$elty})
-#             chkuplo(uplo)
-#             n = checksquare(A)
-#             lda = max(1, stride(A, 2))
-#             W = CuArray{$relty}(undef, n)
-#             dh = rocBLAS.handle()
+for (heevd, heev, heevx, heevj, elty, relty) in  # driver symbols per element type; relty is the eltype of W
+    ((:(rocSOLVER.rocsolver_ssyevd), :(rocSOLVER.rocsolver_ssyev), :(rocSOLVER.rocsolver_ssyevx), :(rocSOLVER.rocsolver_ssyevj), :Float32, :Float32),
+     (:(rocSOLVER.rocsolver_dsyevd), :(rocSOLVER.rocsolver_dsyev), :(rocSOLVER.rocsolver_dsyevx), :(rocSOLVER.rocsolver_dsyevj), :Float64, :Float64),
+     (:(rocSOLVER.rocsolver_cheevd), :(rocSOLVER.rocsolver_cheev), :(rocSOLVER.rocsolver_cheevx), :(rocSOLVER.rocsolver_cheevj), :ComplexF32, :Float32),
+     (:(rocSOLVER.rocsolver_zheevd), :(rocSOLVER.rocsolver_zheev), :(rocSOLVER.rocsolver_zheevx), :(rocSOLVER.rocsolver_zheevj), :ComplexF64, :Float64))
+    @eval begin
+        function heevd!(A::StridedROCMatrix{$elty},  # wrapper around the ?syevd/?heevd driver; fills W and (optionally) V
+                        W::StridedROCVector{$relty},
+                        V::StridedROCMatrix{$elty};
+                        uplo::Char='U')
+            chkuplo(uplo)
+            n = checksquare(A)
+            lda = max(1, stride(A, 2))
+            length(W) == n || throw(DimensionMismatch("size mismatch between A and W"))
+            if length(V) == 0  # an empty V requests eigenvalues only
+                jobz = rocSOLVER.rocblas_evect_none
+            else
+                size(V) == (n, n) || throw(DimensionMismatch("size mismatch between A and V"))
+                jobz = rocSOLVER.rocblas_evect_original
+            end
+            dh = rocBLAS.handle()
+            work = ROCVector{$relty}(undef, n)  # length-n scratch vector passed to the driver
+            dev_info = ROCVector{Cint}(undef, 1)
+            roc_uplo = convert(rocSOLVER.rocblas_fill, uplo)
+            $heevd(dh, jobz, roc_uplo, n, A, lda, W, work, dev_info)
 
-#             function bufferSize()
-#                 out = Ref{Cint}(0)
-#                 $bname(dh, jobz, uplo, n, A, lda, W, out)
-#                 return out[] * sizeof($elty)
-#             end
+            info = @allowscalar dev_info[1]  # scalar read of the device-side status code
+            chkargsok(BlasInt(info))  # NOTE(review): presumably rejects only negative info — confirm
 
-#             with_workspace(dh.workspace_gpu, bufferSize) do buffer
-#                 return $fname(dh, jobz, uplo, n, A, lda, W,
-#                               buffer, sizeof(buffer) ÷ sizeof($elty), dh.info)
-#             end
+            if jobz == rocSOLVER.rocblas_evect_original && V !== A  # assumes the driver left the eigenvectors in A — confirm
+                copy!(V, A)
+            end
+            return W, V
+        end
+        function heev!(A::StridedROCMatrix{$elty},  # wrapper around the ?syev/?heev driver; same contract as heevd!
+                       W::StridedROCVector{$relty},
+                       V::StridedROCMatrix{$elty};
+                       uplo::Char='U')
+            chkuplo(uplo)
+            n = checksquare(A)
+            lda = max(1, stride(A, 2))
+            length(W) == n || throw(DimensionMismatch("size mismatch between A and W"))
+            if length(V) == 0  # an empty V requests eigenvalues only
+                jobz = rocSOLVER.rocblas_evect_none
+            else
+                size(V) == (n, n) || throw(DimensionMismatch("size mismatch between A and V"))
+                jobz = rocSOLVER.rocblas_evect_original
+            end
+            dh = rocBLAS.handle()
+            work = ROCVector{$relty}(undef, n)  # length-n scratch vector passed to the driver
+            dev_info = ROCVector{Cint}(undef, 1)
+            roc_uplo = convert(rocSOLVER.rocblas_fill, uplo)
+            $heev(dh, jobz, roc_uplo, n, A, lda, W, work, dev_info)
 
-#             info = @allowscalar dh.info[1]
-#             chkargsok(BlasInt(info))
+            info = @allowscalar dev_info[1]
+            chkargsok(BlasInt(info))
 
-#             if jobz == 'N'
-#                 return W
-#             elseif jobz == 'V'
-#                 return W, A
-#             end
-#         end
-#     end
-# end
+            if jobz == rocSOLVER.rocblas_evect_original && V !== A  # assumes the driver left the eigenvectors in A — confirm
+                copy!(V, A)
+            end
+            return W, V
+        end
+        function heevx!(A::StridedROCMatrix{$elty},  # wrapper around the ?syevx/?heevx driver (index/value range selection)
+                        W::StridedROCVector{$relty},
+                        V::StridedROCMatrix{$elty};
+                        uplo::Char='U',
+                        kwargs...)  # keys read below: irange, vl, vu; anything else is ignored
+            chkuplo(uplo)
+            n = checksquare(A)
+            lda = max(1, stride(A, 2))
+            length(W) == n || throw(DimensionMismatch("size mismatch between A and W"))
+            if haskey(kwargs, :irange)  # irange takes precedence over vl/vu
+                il = first(kwargs[:irange])
+                iu = last(kwargs[:irange])
+                vl = vu = zero($relty)
+                range = rocSOLVER.rocblas_erange_index
+            elseif haskey(kwargs, :vl) || haskey(kwargs, :vu)
+                vl = convert($relty, get(kwargs, :vl, -Inf))
+                vu = convert($relty, get(kwargs, :vu, +Inf))
+                il = iu = 0
+                range = rocSOLVER.rocblas_erange_value
+            else
+                il = iu = 0
+                vl = vu = zero($relty)
+                range = rocSOLVER.rocblas_erange_all
+            end
+            if length(V) == 0  # an empty V requests eigenvalues only
+                jobz = rocSOLVER.rocblas_evect_none
+            else
+                size(V) == (n, n) || throw(DimensionMismatch("size mismatch between A and V"))
+                jobz = rocSOLVER.rocblas_evect_original
+            end
+            dh = rocBLAS.handle()
+            abstol = -one($relty)  # NOTE(review): negative abstol presumably selects the driver default — confirm
+            nev = ROCVector{Cint}(undef, 1)
+            ldv = max(1, stride(V, 2))
+            ifail = ROCVector{Cint}(undef, n)
+            dev_info = ROCVector{Cint}(undef, 1)
+            roc_uplo = convert(rocSOLVER.rocblas_fill, uplo)
+            $heevx(dh, jobz, range, roc_uplo, n, A, lda, vl, vu, il, iu, abstol, nev, W, V, ldv, ifail, dev_info)
+
+            info = @allowscalar dev_info[1]
+            chkargsok(BlasInt(info))
+            m = @allowscalar nev[1]  # count written by the driver; presumably the number of eigenvalues found
+            return W, V, m  # unlike the sibling wrappers, V is handed to the driver directly (no copy from A)
+        end
+        function heevj!(A::StridedROCMatrix{$elty},  # wrapper around the ?syevj/?heevj driver
+                        W::StridedROCVector{$relty},
+                        V::StridedROCMatrix{$elty};
+                        uplo::Char='U',
+                        tol::$relty=eps($relty),
+                        max_sweeps::Int=100,
+                        sort::Char='N')  # the default selects rocblas_esort_none
+            chkuplo(uplo)
+            n = checksquare(A)
+            lda = max(1, stride(A, 2))
+            length(W) == n || throw(DimensionMismatch("size mismatch between A and W"))
+            if length(V) == 0  # an empty V requests eigenvalues only
+                jobz = rocSOLVER.rocblas_evect_none
+            else
+                size(V) == (n, n) || throw(DimensionMismatch("size mismatch between A and V"))
+                jobz = rocSOLVER.rocblas_evect_original
+            end
+            dh = rocBLAS.handle()
+            dev_info = ROCVector{Cint}(undef, 1)
+            residual = ROCVector{$relty}(undef, 1)
+            n_sweeps = ROCVector{Cint}(undef, 1)
+            roc_uplo = convert(rocSOLVER.rocblas_fill, uplo)
+            roc_sort = sort == 'N' ? rocSOLVER.rocblas_esort_none : rocSOLVER.rocblas_esort_ascending  # any value other than 'N' requests ascending order
+            $heevj(dh, roc_sort, jobz, roc_uplo, n, A, lda, tol, residual, max_sweeps, n_sweeps, W, dev_info)
+
+            info = @allowscalar dev_info[1]
+            chkargsok(BlasInt(info))
+
+            if jobz == rocSOLVER.rocblas_evect_original && V !== A  # assumes the driver left the eigenvectors in A — confirm
+                copy!(V, A)
+            end
+            return W, V
+        end
+    end
+end
 
 end
diff --git a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl
index c42e935b3..1fafb269d 100644
--- a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl
+++ b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl
@@ -5,8 +5,9 @@ using MatrixAlgebraKit: @algdef, Algorithm, check_input
 using MatrixAlgebraKit: one!, zero!, uppertriangular!, lowertriangular!
 using MatrixAlgebraKit: diagview, sign_safe
 using MatrixAlgebraKit: LQViaTransposedQR
-using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eig_algorithm
+using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eig_algorithm, default_eigh_algorithm
 import MatrixAlgebraKit: _gpu_geqrf!, _gpu_ungqr!, _gpu_unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_Xgesvdr!, _gpu_gesvdj!, _gpu_geev!
+import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd!
 using CUDA
 using LinearAlgebra
 using LinearAlgebra: BlasFloat
@@ -26,6 +27,9 @@ end
 function MatrixAlgebraKit.default_eig_algorithm(::Type{T}; kwargs...) where {T<:StridedCuMatrix}
     return CUSOLVER_Simple(; kwargs...)
 end
+function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {T<:StridedCuMatrix}
+    return CUSOLVER_DivideAndConquer(; kwargs...)  # default Hermitian eigensolver for CUDA matrices
+end
 
 _gpu_geev!(A::StridedCuMatrix, D::StridedCuVector, V::StridedCuMatrix) = YACUSOLVER.Xgeev!(A, D, V)
 
@@ -37,4 +41,7 @@ _gpu_Xgesvdp!(A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::
 _gpu_Xgesvdr!(A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix; kwargs...) = YACUSOLVER.Xgesvdr!(A, S, U, Vᴴ; kwargs...)
 _gpu_gesvdj!(A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix; kwargs...) = YACUSOLVER.gesvdj!(A, S, U, Vᴴ; kwargs...)
 
+_gpu_heevj!(A::StridedCuMatrix, Dd::StridedCuVector, V::StridedCuMatrix; kwargs...) = YACUSOLVER.heevj!(A, Dd, V; kwargs...)
+_gpu_heevd!(A::StridedCuMatrix, Dd::StridedCuVector, V::StridedCuMatrix; kwargs...) = YACUSOLVER.heevd!(A, Dd, V; kwargs...)
+
 end
diff --git a/ext/MatrixAlgebraKitCUDAExt/yacusolver.jl b/ext/MatrixAlgebraKitCUDAExt/yacusolver.jl
index 427eedb1a..a1a52620e 100644
--- a/ext/MatrixAlgebraKitCUDAExt/yacusolver.jl
+++ b/ext/MatrixAlgebraKitCUDAExt/yacusolver.jl
@@ -1,7 +1,7 @@
 module YACUSOLVER
 
 using LinearAlgebra
-using LinearAlgebra: BlasInt, BlasFloat, checksquare, chkstride1, require_one_based_indexing
+using LinearAlgebra: BlasInt, BlasFloat, BlasReal, checksquare, chkstride1, require_one_based_indexing
 using LinearAlgebra.LAPACK: chkargsok, chklapackerror, chktrans, chkside, chkdiag, chkuplo
 
 using CUDA
@@ -679,43 +679,105 @@ end
 #     return X, info
 # end
 
-# for (jname, bname, fname, elty, relty) in
-#     ((:syevd!, :cusolverDnSsyevd_bufferSize, :cusolverDnSsyevd, :Float32, :Float32),
-#      (:syevd!, :cusolverDnDsyevd_bufferSize, :cusolverDnDsyevd, :Float64, :Float64),
-#      (:heevd!, :cusolverDnCheevd_bufferSize, :cusolverDnCheevd, :ComplexF32, :Float32),
-#      (:heevd!, :cusolverDnZheevd_bufferSize, :cusolverDnZheevd, :ComplexF64, :Float64))
-#     @eval begin
-#         function $jname(jobz::Char,
-#                         uplo::Char,
-#                         A::StridedCuMatrix{$elty})
-#             chkuplo(uplo)
-#             n = checksquare(A)
-#             lda = max(1, stride(A, 2))
-#             W = CuArray{$relty}(undef, n)
-#             dh = dense_handle()
+for (bname, fname, elty, relty) in ((:(CUSOLVER.cusolverDnSsyevj_bufferSize), :(CUSOLVER.cusolverDnSsyevj), :Float32, :Float32),
+                                    (:(CUSOLVER.cusolverDnDsyevj_bufferSize), :(CUSOLVER.cusolverDnDsyevj), :Float64, :Float64),
+                                    (:(CUSOLVER.cusolverDnCheevj_bufferSize), :(CUSOLVER.cusolverDnCheevj), :ComplexF32, :Float32),
+                                    (:(CUSOLVER.cusolverDnZheevj_bufferSize), :(CUSOLVER.cusolverDnZheevj), :ComplexF64, :Float64))
+    @eval begin
+        # Hermitian/symmetric eigensolver through cuSOLVER's syevj/heevj (Jacobi) driver.
+        # Eigenvalues go to `W`; a non-empty `V` (must be n×n) requests eigenvectors, which
+        # are copied from `A` into `V` after the call. `tol`/`max_sweeps` configure the driver.
+        function heevj!(A::StridedCuMatrix{$elty},
+                        W::StridedCuVector{$relty},
+                        V::StridedCuMatrix{$elty};
+                        uplo::Char='U',
+                        tol::$relty=eps($relty),
+                        max_sweeps::Int=100
+                       )
+            chkuplo(uplo)
+            n = checksquare(A)
+            lda = max(1, stride(A, 2))
+            dh = CUSOLVER.dense_handle()
+            length(W) == n || throw(DimensionMismatch("size mismatch between A and W"))
+            if length(V) == 0
+                jobz = 'N'
+            else
+                size(V) == (n, n) || throw(DimensionMismatch("size mismatch between A and V"))
+                jobz = 'V'
+            end
+            # The syevj parameter object is created per call, so it must be destroyed per
+            # call as well -- also when the buffer query or the solve throws.
+            params = Ref{CUSOLVER.syevjInfo_t}(C_NULL)
+            CUSOLVER.cusolverDnCreateSyevjInfo(params)
+            try
+                CUSOLVER.cusolverDnXsyevjSetTolerance(params[], tol)
+                CUSOLVER.cusolverDnXsyevjSetMaxSweeps(params[], max_sweeps)
+                function bufferSize()
+                    out = Ref{Cint}(0)
+                    $bname(dh, jobz, uplo, n, A, lda, W, out, params[])
+                    return out[] * sizeof($elty)
+                end
+                CUDA.with_workspace(dh.workspace_gpu, bufferSize) do buffer
+                    $fname(dh, jobz, uplo, n, A, lda, W, buffer,
+                           sizeof(buffer) ÷ sizeof($elty), dh.info, params[])
+                end
+            finally
+                CUSOLVER.cusolverDnDestroySyevjInfo(params[])
+            end
 
-#             function bufferSize()
-#                 out = Ref{Cint}(0)
-#                 $bname(dh, jobz, uplo, n, A, lda, W, out)
-#                 return out[] * sizeof($elty)
-#             end
+            info = @allowscalar dh.info[1]
+            chkargsok(BlasInt(info))
 
-#             with_workspace(dh.workspace_gpu, bufferSize) do buffer
-#                 return $fname(dh, jobz, uplo, n, A, lda, W,
-#                               buffer, sizeof(buffer) ÷ sizeof($elty), dh.info)
-#             end
+            if jobz == 'V' && V !== A
+                copy!(V, A)
+            end
+            return W, V
+        end
+    end
+end
 
-#             info = @allowscalar dh.info[1]
-#             chkargsok(BlasInt(info))
+# Hermitian/symmetric eigensolver through cuSOLVER's `cusolverDnXsyevd` entry point, which
+# receives the element types `T`/`Tr` as arguments. Eigenvalues go to `W`; a non-empty `V`
+# (must be n×n) requests eigenvectors, which are copied from `A` into `V` after the call.
+function heevd!(A::StridedCuMatrix{T},
+                W::StridedCuVector{Tr},
+                V::StridedCuMatrix{T};
+                uplo::Char='U') where {T<:BlasFloat, Tr<:BlasReal}
+    chkuplo(uplo)
+    n = checksquare(A)
+    lda = max(1, stride(A, 2))
+    dh = CUSOLVER.dense_handle()
+    length(W) == n || throw(DimensionMismatch("size mismatch between A and W"))
+    if length(V) == 0
+        jobz = 'N'
+    else
+        size(V) == (n, n) || throw(DimensionMismatch("size mismatch between A and V"))
+        jobz = 'V'
+    end
 
-#             if jobz == 'N'
-#                 return W
-#             elseif jobz == 'V'
-#                 return W, A
-#             end
-#         end
-#     end
-# end
+    params = CUSOLVER.CuSolverParameters()
+    function bufferSize()
+        out_cpu = Ref{Csize_t}(0)
+        out_gpu = Ref{Csize_t}(0)
+        CUSOLVER.cusolverDnXsyevd_bufferSize(dh, params, jobz, uplo, n, T, A, lda, Tr, W, T, out_gpu, out_cpu)
+        return out_gpu[], out_cpu[]
+    end
+
+    CUSOLVER.with_workspaces(dh.workspace_gpu, dh.workspace_cpu,
+                             bufferSize()...) do buffer_gpu, buffer_cpu
+        return CUSOLVER.cusolverDnXsyevd(dh, params, jobz, uplo, n, T, A, lda, Tr, W,
+                                         T, buffer_gpu, sizeof(buffer_gpu), buffer_cpu,
+                                         sizeof(buffer_cpu), dh.info)
+    end
+
+    info = @allowscalar dh.info[1]
+    chkargsok(BlasInt(info))
+
+    if jobz == 'V' && V !== A
+        copy!(V, A)
+    end
+    return W, V
+end
 
 # device code is unreachable by coverage right now
 # COV_EXCL_START
diff --git a/src/MatrixAlgebraKit.jl b/src/MatrixAlgebraKit.jl
index 035caf11e..6e28cf969 100644
--- a/src/MatrixAlgebraKit.jl
+++ b/src/MatrixAlgebraKit.jl
@@ -34,8 +34,8 @@ export LAPACK_HouseholderQR, LAPACK_HouseholderLQ,
        LAPACK_DivideAndConquer, LAPACK_Jacobi,
        LQViaTransposedQR,
        CUSOLVER_Simple,
-       CUSOLVER_HouseholderQR, CUSOLVER_QRIteration, CUSOLVER_SVDPolar, CUSOLVER_Jacobi, CUSOLVER_Randomized,
-       ROCSOLVER_HouseholderQR, ROCSOLVER_QRIteration, ROCSOLVER_Jacobi
+       CUSOLVER_HouseholderQR, CUSOLVER_QRIteration, CUSOLVER_SVDPolar, CUSOLVER_Jacobi, CUSOLVER_Randomized, CUSOLVER_DivideAndConquer,
+       ROCSOLVER_HouseholderQR, ROCSOLVER_QRIteration, ROCSOLVER_Jacobi, ROCSOLVER_DivideAndConquer, ROCSOLVER_Bisection
 export truncrank, trunctol, truncabove, TruncationKeepSorted, TruncationKeepFiltered
 
 VERSION >= v"1.11.0-DEV.469" &&
diff --git a/src/implementations/eigh.jl b/src/implementations/eigh.jl
index 1e6a47f40..9b261d05f 100644
--- a/src/implementations/eigh.jl
+++ b/src/implementations/eigh.jl
@@ -61,11 +61,7 @@ function eigh_full!(A::AbstractMatrix, DV, alg::LAPACK_EighAlgorithm)
         YALAPACK.heevx!(A, Dd, V; alg.kwargs...)
     end
     # TODO: make this controllable using a `gaugefix` keyword argument
-    for j in 1:size(V, 2)
-        v = view(V, :, j)
-        s = conj(sign(argmax(abs, v)))
-        v .*= s
-    end
+    V = gaugefix!(V)
     return D, V
 end
 
@@ -88,3 +84,50 @@ function eigh_trunc!(A::AbstractMatrix, DV, alg::TruncatedAlgorithm)
     D, V = eigh_full!(A, DV, alg.alg)
     return truncate!(eigh_trunc!, (D, V), alg.trunc)
 end
+
+# Fallbacks for the GPU driver hooks; the backend extensions add methods for their array types.
+_gpu_heevj!(A::AbstractMatrix, Dd::AbstractVector, V::AbstractMatrix; kwargs...) = throw(MethodError(_gpu_heevj!, (A, Dd, V)))
+_gpu_heevd!(A::AbstractMatrix, Dd::AbstractVector, V::AbstractMatrix; kwargs...) = throw(MethodError(_gpu_heevd!, (A, Dd, V)))
+_gpu_heev!(A::AbstractMatrix, Dd::AbstractVector, V::AbstractMatrix; kwargs...) = throw(MethodError(_gpu_heev!, (A, Dd, V)))
+_gpu_heevx!(A::AbstractMatrix, Dd::AbstractVector, V::AbstractMatrix; kwargs...) = throw(MethodError(_gpu_heevx!, (A, Dd, V)))
+
+# GPU implementation: pick the `_gpu_*` driver hook that matches the algorithm type, let it
+# fill `D.diag` and `V`, then apply `gaugefix!` to the eigenvectors.
+function eigh_full!(A::AbstractMatrix, DV, alg::GPU_EighAlgorithm)
+    check_input(eigh_full!, A, DV, alg)
+    D, V = DV
+    Dd = D.diag
+    if alg isa GPU_Jacobi
+        _gpu_heevj!(A, Dd, V; alg.kwargs...)
+    elseif alg isa GPU_DivideAndConquer
+        _gpu_heevd!(A, Dd, V; alg.kwargs...)
+    elseif alg isa GPU_QRIteration # plain QR-iteration driver (heev)
+        _gpu_heev!(A, Dd, V; alg.kwargs...)
+    elseif alg isa GPU_Bisection # bisection driver (heevx); its extra return value is ignored here
+        _gpu_heevx!(A, Dd, V; alg.kwargs...)
+    else
+        throw(ArgumentError("Unsupported eigh algorithm"))
+    end
+    # TODO: make this controllable using a `gaugefix` keyword argument
+    V = gaugefix!(V)
+    return D, V
+end
+
+# Eigenvalues only: same dispatch as `eigh_full!`, but with a zero-column `V` so the driver
+# wrappers skip the eigenvectors; `D` is handed to the hooks as the eigenvalue vector.
+function eigh_vals!(A::AbstractMatrix, D, alg::GPU_EighAlgorithm)
+    check_input(eigh_vals!, A, D, alg)
+    V = similar(A, (size(A, 1), 0))
+    if alg isa GPU_Jacobi
+        _gpu_heevj!(A, D, V; alg.kwargs...)
+    elseif alg isa GPU_DivideAndConquer
+        _gpu_heevd!(A, D, V; alg.kwargs...)
+    elseif alg isa GPU_QRIteration
+        _gpu_heev!(A, D, V; alg.kwargs...)
+    elseif alg isa GPU_Bisection
+        _gpu_heevx!(A, D, V; alg.kwargs...)
+    else
+        throw(ArgumentError("Unsupported eigh algorithm"))
+    end
+    return D
+end
diff --git a/src/implementations/svd.jl b/src/implementations/svd.jl
index 83e5f2206..15d9137e7 100644
--- a/src/implementations/svd.jl
+++ b/src/implementations/svd.jl
@@ -215,9 +215,7 @@ const ROCSOLVER_SVDAlgorithm = Union{ROCSOLVER_QRIteration, ROCSOLVER_Jacobi}
 
 const GPU_SVDAlgorithm = Union{CUSOLVER_SVDAlgorithm, ROCSOLVER_SVDAlgorithm}
 
-const GPU_QRIteration = Union{CUSOLVER_QRIteration, ROCSOLVER_QRIteration}
 const GPU_SVDPolar = Union{CUSOLVER_SVDPolar}
-const GPU_Jacobi = Union{CUSOLVER_Jacobi, ROCSOLVER_Jacobi}
 const GPU_Randomized = Union{CUSOLVER_Randomized}
 
 function check_input(::typeof(svd_trunc!), A::AbstractMatrix, USVᴴ, alg::CUSOLVER_Randomized)
diff --git a/src/interface/decompositions.jl b/src/interface/decompositions.jl
index 1310944e0..722f90111 100644
--- a/src/interface/decompositions.jl
+++ b/src/interface/decompositions.jl
@@ -173,6 +173,16 @@ eigenvalue decomposition of a matrix.
 @algdef CUSOLVER_Simple
 
 const CUSOLVER_EigAlgorithm = Union{CUSOLVER_Simple}
+
+"""
+    CUSOLVER_DivideAndConquer()
+
+Algorithm type to denote the CUSOLVER driver for computing the eigenvalue decomposition of a
+Hermitian matrix, or the singular value decomposition of a general matrix using the
+Divide and Conquer algorithm.
+"""
+@algdef CUSOLVER_DivideAndConquer
+
 # =========================
 # ROCSOLVER ALGORITHMS
 # =========================
@@ -202,5 +212,33 @@ a general matrix using the Jacobi algorithm.
 """
 @algdef ROCSOLVER_Jacobi
 
+"""
+    ROCSOLVER_Bisection()
+
+Algorithm type to denote the ROCSOLVER driver for computing the eigenvalue decomposition of a
+Hermitian matrix using the Bisection algorithm. This type is not a member of
+`ROCSOLVER_SVDAlgorithm`, so it does not select a singular value decomposition driver.
+"""
+@algdef ROCSOLVER_Bisection
+
+"""
+    ROCSOLVER_DivideAndConquer()
+
+Algorithm type to denote the ROCSOLVER driver for computing the eigenvalue decomposition of a
+Hermitian matrix using the Divide and Conquer algorithm. This type is not a member of
+`ROCSOLVER_SVDAlgorithm`, so it does not select a singular value decomposition driver.
+"""
+@algdef ROCSOLVER_DivideAndConquer
+
+
 const GPU_Simple = Union{CUSOLVER_Simple}
 const GPU_EigAlgorithm = Union{GPU_Simple}
+const GPU_QRIteration = Union{CUSOLVER_QRIteration, ROCSOLVER_QRIteration}
+const GPU_Jacobi = Union{CUSOLVER_Jacobi, ROCSOLVER_Jacobi}
+const GPU_DivideAndConquer = Union{CUSOLVER_DivideAndConquer, ROCSOLVER_DivideAndConquer}
+const GPU_Bisection = Union{ROCSOLVER_Bisection}  # ROCSOLVER only; no CUSOLVER member in this union
+const GPU_EighAlgorithm = Union{GPU_QRIteration,
+                                GPU_Jacobi,
+                                GPU_DivideAndConquer,
+                                GPU_Bisection}
+
diff --git a/src/yalapack.jl b/src/yalapack.jl
index ec37b6745..3d7b0369a 100644
--- a/src/yalapack.jl
+++ b/src/yalapack.jl
@@ -921,8 +921,8 @@ for (heev, heevx, heevr, heevd, hegvd, elty, relty) in
             end
             chkuplofinite(A, uplo)
             if haskey(kwargs, :irange)
-                il = first(irange)
-                iu = last(irange)
+                il = first(kwargs[:irange])
+                iu = last(kwargs[:irange])
                 vl = vu = zero($relty)
                 range = 'I'
             elseif haskey(kwargs, :vl) || haskey(kwargs, :vu)
diff --git a/test/amd/eigh.jl b/test/amd/eigh.jl
new file mode 100644
index 000000000..44be84952
--- /dev/null
+++ b/test/amd/eigh.jl
@@ -0,0 +1,80 @@
+using MatrixAlgebraKit
+using Test
+using TestExtras
+using StableRNGs
+using LinearAlgebra: LinearAlgebra, Diagonal, I
+using MatrixAlgebraKit: TruncatedAlgorithm, diagview
+using AMDGPU
+
+@testset "eigh_full! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64)
+    rng = StableRNG(123)
+    m = 54
+    for alg in (ROCSOLVER_DivideAndConquer(),
+                ROCSOLVER_Jacobi(),
+                ROCSOLVER_Bisection(),
+                ROCSOLVER_QRIteration(),
+                )
+        A = ROCArray(randn(rng, T, m, m))
+        A = (A + A') / 2
+
+        D, V = @constinferred eigh_full(A; alg)
+        @test A * V ≈ V * D
+        @test isunitary(V)
+        @test all(isreal, D)
+
+        D2, V2 = eigh_full!(copy(A), (D, V), alg)
+        @test D2 === D
+        @test V2 === V
+
+        D3 = @constinferred eigh_vals(A, alg)
+        @test parent(D) ≈ D3
+    end
+end
+
+#=@testset "eigh_trunc! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64)
+    rng = StableRNG(123)
+    m = 54
+    for alg in (ROCSOLVER_QRIteration(),
+                ROCSOLVER_DivideAndConquer(),
+                )
+        A = ROCArray(randn(rng, T, m, m))
+        A = A * A'
+        A = (A + A') / 2
+        Ac = similar(A)
+        D₀ = reverse(eigh_vals(A))
+        r = m - 2
+        s = 1 + sqrt(eps(real(T)))
+
+        D1, V1 = @constinferred eigh_trunc(A; alg, trunc=truncrank(r))
+        @test length(diagview(D1)) == r
+        @test isisometry(V1)
+        @test A * V1 ≈ V1 * D1
+        @test LinearAlgebra.opnorm(A - V1 * D1 * V1') ≈ D₀[r + 1]
+
+        trunc = trunctol(s * D₀[r + 1])
+        D2, V2 = @constinferred eigh_trunc(A; alg, trunc)
+        @test length(diagview(D2)) == r
+        @test isisometry(V2)
+        @test A * V2 ≈ V2 * D2
+
+        # test for same subspace
+        @test V1 * (V1' * V2) ≈ V2
+        @test V2 * (V2' * V1) ≈ V1
+    end
+end
+
+@testset "eigh_trunc! specify truncation algorithm T = $T" for T in
+                                                               (Float32, Float64,
+                                                                ComplexF32,
+                                                                ComplexF64)
+    rng = StableRNG(123)
+    m = 4
+    V = qr_compact(ROCArray(randn(rng, T, m, m)))[1]
+    D = Diagonal([0.9, 0.3, 0.1, 0.01])
+    A = V * D * V'
+    A = (A + A') / 2
+    alg = TruncatedAlgorithm(ROCSOLVER_QRIteration(), truncrank(2))
+    D2, V2 = @constinferred eigh_trunc(A; alg)
+    @test diagview(D2) ≈ diagview(D)[1:2] rtol = sqrt(eps(real(T)))
+    @test_throws ArgumentError eigh_trunc(A; alg, trunc=(; maxrank=2))
+end=#
diff --git a/test/cuda/eigh.jl b/test/cuda/eigh.jl
new file mode 100644
index 000000000..c15bbb12e
--- /dev/null
+++ b/test/cuda/eigh.jl
@@ -0,0 +1,78 @@
+using MatrixAlgebraKit
+using Test
+using TestExtras
+using StableRNGs
+using LinearAlgebra: LinearAlgebra, Diagonal, I
+using MatrixAlgebraKit: TruncatedAlgorithm, diagview
+using CUDA
+
+@testset "eigh_full! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64)
+    rng = StableRNG(123)
+    m = 54
+    for alg in (CUSOLVER_DivideAndConquer(),
+                CUSOLVER_Jacobi(),
+                )
+        A = CuArray(randn(rng, T, m, m))
+        A = (A + A') / 2
+
+        D, V = @constinferred eigh_full(A; alg)
+        @test A * V ≈ V * D
+        @test isunitary(V)
+        @test all(isreal, D)
+
+        D2, V2 = eigh_full!(copy(A), (D, V), alg)
+        @test D2 === D
+        @test V2 === V
+
+        D3 = @constinferred eigh_vals(A, alg)
+        @test parent(D) ≈ D3
+    end
+end
+
+#=@testset "eigh_trunc! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64)
+    rng = StableRNG(123)
+    m = 54
+    for alg in (CUSOLVER_QRIteration(),
+                CUSOLVER_DivideAndConquer(),
+                )
+        A = CuArray(randn(rng, T, m, m))
+        A = A * A'
+        A = (A + A') / 2
+        Ac = similar(A)
+        D₀ = reverse(eigh_vals(A))
+        r = m - 2
+        s = 1 + sqrt(eps(real(T)))
+
+        D1, V1 = @constinferred eigh_trunc(A; alg, trunc=truncrank(r))
+        @test length(diagview(D1)) == r
+        @test isisometry(V1)
+        @test A * V1 ≈ V1 * D1
+        @test LinearAlgebra.opnorm(A - V1 * D1 * V1') ≈ D₀[r + 1]
+
+        trunc = trunctol(s * D₀[r + 1])
+        D2, V2 = @constinferred eigh_trunc(A; alg, trunc)
+        @test length(diagview(D2)) == r
+        @test isisometry(V2)
+        @test A * V2 ≈ V2 * D2
+
+        # test for same subspace
+        @test V1 * (V1' * V2) ≈ V2
+        @test V2 * (V2' * V1) ≈ V1
+    end
+end
+
+@testset "eigh_trunc! specify truncation algorithm T = $T" for T in
+                                                               (Float32, Float64,
+                                                                ComplexF32,
+                                                                ComplexF64)
+    rng = StableRNG(123)
+    m = 4
+    V = qr_compact(CuArray(randn(rng, T, m, m)))[1]
+    D = Diagonal([0.9, 0.3, 0.1, 0.01])
+    A = V * D * V'
+    A = (A + A') / 2
+    alg = TruncatedAlgorithm(CUSOLVER_QRIteration(), truncrank(2))
+    D2, V2 = @constinferred eigh_trunc(A; alg)
+    @test diagview(D2) ≈ diagview(D)[1:2] rtol = sqrt(eps(real(T)))
+    @test_throws ArgumentError eigh_trunc(A; alg, trunc=(; maxrank=2))
+end=#
diff --git a/test/runtests.jl b/test/runtests.jl
index 9e6ed13ec..d85e8a32c 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -66,6 +66,9 @@ if CUDA.functional()
     @safetestset "CUDA General Eigenvalue Decomposition" begin
         include("cuda/eig.jl")
     end
+    @safetestset "CUDA Hermitian Eigenvalue Decomposition" begin
+        include("cuda/eigh.jl")
+    end
 end
 
 using AMDGPU
@@ -79,4 +82,7 @@ if AMDGPU.functional()
     @safetestset "AMDGPU SVD" begin
         include("amd/svd.jl")
     end
+    @safetestset "AMDGPU Hermitian Eigenvalue Decomposition" begin
+        include("amd/eigh.jl")
+    end
 end