Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@ using MatrixAlgebraKit: @algdef, Algorithm, check_input
using MatrixAlgebraKit: one!, zero!, uppertriangular!, lowertriangular!
using MatrixAlgebraKit: diagview, sign_safe
using MatrixAlgebraKit: LQViaTransposedQR
using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm
import MatrixAlgebraKit: _gpu_geqrf!, _gpu_ungqr!, _gpu_unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_gesvdj!
using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eigh_algorithm
import MatrixAlgebraKit: _gpu_geqrf!, _gpu_ungqr!, _gpu_unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_gesvdj!
import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd!, _gpu_heev!, _gpu_heevx!
using AMDGPU
using LinearAlgebra
using LinearAlgebra: BlasFloat
Expand All @@ -23,6 +24,9 @@ end
function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T<:StridedROCMatrix}
return ROCSOLVER_QRIteration(; kwargs...)
end
function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {T<:StridedROCMatrix}
return ROCSOLVER_DivideAndConquer(; kwargs...)
end

_gpu_geqrf!(A::StridedROCMatrix) = YArocSOLVER.geqrf!(A)
_gpu_ungqr!(A::StridedROCMatrix, τ::StridedROCVector) = YArocSOLVER.ungqr!(A, τ)
Expand All @@ -32,4 +36,8 @@ _gpu_gesvd!(A::StridedROCMatrix, S::StridedROCVector, U::StridedROCMatrix, Vᴴ:
#_gpu_Xgesvdp!(A::StridedROCMatrix, S::StridedROCVector, U::StridedROCMatrix, Vᴴ::StridedROCMatrix; kwargs...) = YArocSOLVER.Xgesvdp!(A, S, U, Vᴴ; kwargs...)
_gpu_gesvdj!(A::StridedROCMatrix, S::StridedROCVector, U::StridedROCMatrix, Vᴴ::StridedROCMatrix; kwargs...) = YArocSOLVER.gesvdj!(A, S, U, Vᴴ; kwargs...)

_gpu_heevj!(A::StridedROCMatrix, Dd::StridedROCVector, V::StridedROCMatrix; kwargs...) = YArocSOLVER.heevj!(A, Dd, V; kwargs...)
_gpu_heevd!(A::StridedROCMatrix, Dd::StridedROCVector, V::StridedROCMatrix; kwargs...) = YArocSOLVER.heevd!(A, Dd, V; kwargs...)
_gpu_heev!(A::StridedROCMatrix, Dd::StridedROCVector, V::StridedROCMatrix; kwargs...) = YArocSOLVER.heev!(A, Dd, V; kwargs...)
_gpu_heevx!(A::StridedROCMatrix, Dd::StridedROCVector, V::StridedROCMatrix; kwargs...) = YArocSOLVER.heevx!(A, Dd, V; kwargs...)
end
172 changes: 138 additions & 34 deletions ext/MatrixAlgebraKitAMDGPUExt/yarocsolver.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module YArocSOLVER

using LinearAlgebra
using LinearAlgebra: BlasInt, BlasFloat, checksquare, chkstride1, require_one_based_indexing
using LinearAlgebra: BlasInt, BlasReal, BlasFloat, checksquare, chkstride1, require_one_based_indexing
using LinearAlgebra.LAPACK: chkargsok, chklapackerror, chktrans, chkside, chkdiag, chkuplo

using AMDGPU
Expand Down Expand Up @@ -475,42 +475,146 @@ end
# return X, info
# end

# for (jname, bname, fname, elty, relty) in
# ((:syevd!, :rocsolverDnSsyevd_bufferSize, :rocsolverDnSsyevd, :Float32, :Float32),
# (:syevd!, :rocsolverDnDsyevd_bufferSize, :rocsolverDnDsyevd, :Float64, :Float64),
# (:heevd!, :rocsolverDnCheevd_bufferSize, :rocsolverDnCheevd, :ComplexF32, :Float32),
# (:heevd!, :rocsolverDnZheevd_bufferSize, :rocsolverDnZheevd, :ComplexF64, :Float64))
# @eval begin
# function $jname(jobz::Char,
# uplo::Char,
# A::StridedROCMatrix{$elty})
# chkuplo(uplo)
# n = checksquare(A)
# lda = max(1, stride(A, 2))
# W = CuArray{$relty}(undef, n)
# dh = rocBLAS.handle()
for (heevd, heev, heevx, heevj, elty, relty) in
((:(rocSOLVER.rocsolver_ssyevd), :(rocSOLVER.rocsolver_ssyev), :(rocSOLVER.rocsolver_ssyevx), :(rocSOLVER.rocsolver_ssyevj), :Float32, :Float32),
(:(rocSOLVER.rocsolver_dsyevd), :(rocSOLVER.rocsolver_dsyev), :(rocSOLVER.rocsolver_dsyevx), :(rocSOLVER.rocsolver_dsyevj), :Float64, :Float64),
(:(rocSOLVER.rocsolver_cheevd), :(rocSOLVER.rocsolver_cheev), :(rocSOLVER.rocsolver_cheevx), :(rocSOLVER.rocsolver_cheevj), :ComplexF32, :Float32),
(:(rocSOLVER.rocsolver_zheevd), :(rocSOLVER.rocsolver_zheev), :(rocSOLVER.rocsolver_zheevx), :(rocSOLVER.rocsolver_zheevj), :ComplexF64, :Float64))
@eval begin
function heevd!(A::StridedROCMatrix{$elty},
W::StridedROCVector{$relty},
V::StridedROCMatrix{$elty};
uplo::Char='U')
chkuplo(uplo)
n = checksquare(A)
lda = max(1, stride(A, 2))
length(W) == n || throw(DimensionMismatch("size mismatch between A and W"))
if length(V) == 0
jobz = rocSOLVER.rocblas_evect_none
else
size(V) == (n, n) || throw(DimensionMismatch("size mismatch between A and V"))
jobz = rocSOLVER.rocblas_evect_original
end
dh = rocBLAS.handle()
work = ROCVector{$relty}(undef, n)
dev_info = ROCVector{Cint}(undef, 1)
roc_uplo = convert(rocSOLVER.rocblas_fill, uplo)
$heevd(dh, jobz, roc_uplo, n, A, lda, W, work, dev_info)

# function bufferSize()
# out = Ref{Cint}(0)
# $bname(dh, jobz, uplo, n, A, lda, W, out)
# return out[] * sizeof($elty)
# end
info = @allowscalar dev_info[1]
chkargsok(BlasInt(info))

# with_workspace(dh.workspace_gpu, bufferSize) do buffer
# return $fname(dh, jobz, uplo, n, A, lda, W,
# buffer, sizeof(buffer) ÷ sizeof($elty), dh.info)
# end
if jobz == rocSOLVER.rocblas_evect_original && V !== A
copy!(V, A)
end
return W, V
end
function heev!(A::StridedROCMatrix{$elty},
W::StridedROCVector{$relty},
V::StridedROCMatrix{$elty};
uplo::Char='U')
chkuplo(uplo)
n = checksquare(A)
lda = max(1, stride(A, 2))
length(W) == n || throw(DimensionMismatch("size mismatch between A and W"))
if length(V) == 0
jobz = rocSOLVER.rocblas_evect_none
else
size(V) == (n, n) || throw(DimensionMismatch("size mismatch between A and V"))
jobz = rocSOLVER.rocblas_evect_original
end
dh = rocBLAS.handle()
work = ROCVector{$relty}(undef, n)
dev_info = ROCVector{Cint}(undef, 1)
roc_uplo = convert(rocSOLVER.rocblas_fill, uplo)
$heev(dh, jobz, roc_uplo, n, A, lda, W, work, dev_info)

# info = @allowscalar dh.info[1]
# chkargsok(BlasInt(info))
info = @allowscalar dev_info[1]
chkargsok(BlasInt(info))

# if jobz == 'N'
# return W
# elseif jobz == 'V'
# return W, A
# end
# end
# end
# end
if jobz == rocSOLVER.rocblas_evect_original && V !== A
copy!(V, A)
end
return W, V
end
function heevx!(A::StridedROCMatrix{$elty},
W::StridedROCVector{$relty},
V::StridedROCMatrix{$elty};
uplo::Char='U',
kwargs...)
chkuplo(uplo)
n = checksquare(A)
lda = max(1, stride(A, 2))
length(W) == n || throw(DimensionMismatch("size mismatch between A and W"))
if haskey(kwargs, :irange)
il = first(kwargs[:irange])
iu = last(kwargs[:irange])
vl = vu = zero($relty)
range = rocSOLVER.rocblas_erange_index
elseif haskey(kwargs, :vl) || haskey(kwargs, :vu)
vl = convert($relty, get(kwargs, :vl, -Inf))
vu = convert($relty, get(kwargs, :vu, +Inf))
il = iu = 0
range = rocSOLVER.rocblas_erange_value
else
il = iu = 0
vl = vu = zero($relty)
range = rocSOLVER.rocblas_erange_all
end
if length(V) == 0
jobz = rocSOLVER.rocblas_evect_none
else
size(V) == (n, n) || throw(DimensionMismatch("size mismatch between A and V"))
jobz = rocSOLVER.rocblas_evect_original
end
dh = rocBLAS.handle()
abstol = -one($relty)
nev = ROCVector{Cint}(undef, 1)
ldv = max(1, stride(V, 2))
ifail = ROCVector{Cint}(undef, n)
dev_info = ROCVector{Cint}(undef, 1)
roc_uplo = convert(rocSOLVER.rocblas_fill, uplo)
$heevx(dh, jobz, range, roc_uplo, n, A, lda, vl, vu, il, iu, abstol, nev, W, V, ldv, ifail, dev_info)

info = @allowscalar dev_info[1]
chkargsok(BlasInt(info))
m = @allowscalar nev[1]
return W, V, m
end
function heevj!(A::StridedROCMatrix{$elty},
W::StridedROCVector{$relty},
V::StridedROCMatrix{$elty};
uplo::Char='U',
tol::$relty=eps($relty),
max_sweeps::Int=100,
sort::Char='N')
chkuplo(uplo)
n = checksquare(A)
lda = max(1, stride(A, 2))
length(W) == n || throw(DimensionMismatch("size mismatch between A and W"))
if length(V) == 0
jobz = rocSOLVER.rocblas_evect_none
else
size(V) == (n, n) || throw(DimensionMismatch("size mismatch between A and V"))
jobz = rocSOLVER.rocblas_evect_original
end
dh = rocBLAS.handle()
dev_info = ROCVector{Cint}(undef, 1)
residual = ROCVector{$relty}(undef, 1)
n_sweeps = ROCVector{Cint}(undef, 1)
roc_uplo = convert(rocSOLVER.rocblas_fill, uplo)
roc_sort = sort == 'N' ? rocSOLVER.rocblas_esort_none : rocSOLVER.rocblas_esort_ascending
$heevj(dh, roc_sort, jobz, roc_uplo, n, A, lda, tol, residual, max_sweeps, n_sweeps, W, dev_info)

info = @allowscalar dev_info[1]
chkargsok(BlasInt(info))

if jobz == rocSOLVER.rocblas_evect_original && V !== A
copy!(V, A)
end
return W, V
end
end
end

end
9 changes: 8 additions & 1 deletion ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@ using MatrixAlgebraKit: @algdef, Algorithm, check_input
using MatrixAlgebraKit: one!, zero!, uppertriangular!, lowertriangular!
using MatrixAlgebraKit: diagview, sign_safe
using MatrixAlgebraKit: LQViaTransposedQR
using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eig_algorithm
using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eig_algorithm, default_eigh_algorithm
import MatrixAlgebraKit: _gpu_geqrf!, _gpu_ungqr!, _gpu_unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_Xgesvdr!, _gpu_gesvdj!, _gpu_geev!
import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd!
using CUDA
using LinearAlgebra
using LinearAlgebra: BlasFloat
Expand All @@ -26,6 +27,9 @@ end
function MatrixAlgebraKit.default_eig_algorithm(::Type{T}; kwargs...) where {T<:StridedCuMatrix}
return CUSOLVER_Simple(; kwargs...)
end
function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {T<:StridedCuMatrix}
return CUSOLVER_DivideAndConquer(; kwargs...)
end


_gpu_geev!(A::StridedCuMatrix, D::StridedCuVector, V::StridedCuMatrix) = YACUSOLVER.Xgeev!(A, D, V)
Expand All @@ -37,4 +41,7 @@ _gpu_Xgesvdp!(A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::
_gpu_Xgesvdr!(A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix; kwargs...) = YACUSOLVER.Xgesvdr!(A, S, U, Vᴴ; kwargs...)
_gpu_gesvdj!(A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix; kwargs...) = YACUSOLVER.gesvdj!(A, S, U, Vᴴ; kwargs...)

_gpu_heevj!(A::StridedCuMatrix, Dd::StridedCuVector, V::StridedCuMatrix; kwargs...) = YACUSOLVER.heevj!(A, Dd, V; kwargs...)
_gpu_heevd!(A::StridedCuMatrix, Dd::StridedCuVector, V::StridedCuMatrix; kwargs...) = YACUSOLVER.heevd!(A, Dd, V; kwargs...)

end
118 changes: 84 additions & 34 deletions ext/MatrixAlgebraKitCUDAExt/yacusolver.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module YACUSOLVER

using LinearAlgebra
using LinearAlgebra: BlasInt, BlasFloat, checksquare, chkstride1, require_one_based_indexing
using LinearAlgebra: BlasInt, BlasFloat, BlasReal, checksquare, chkstride1, require_one_based_indexing
using LinearAlgebra.LAPACK: chkargsok, chklapackerror, chktrans, chkside, chkdiag, chkuplo

using CUDA
Expand Down Expand Up @@ -679,43 +679,93 @@ end
# return X, info
# end

# for (jname, bname, fname, elty, relty) in
# ((:syevd!, :cusolverDnSsyevd_bufferSize, :cusolverDnSsyevd, :Float32, :Float32),
# (:syevd!, :cusolverDnDsyevd_bufferSize, :cusolverDnDsyevd, :Float64, :Float64),
# (:heevd!, :cusolverDnCheevd_bufferSize, :cusolverDnCheevd, :ComplexF32, :Float32),
# (:heevd!, :cusolverDnZheevd_bufferSize, :cusolverDnZheevd, :ComplexF64, :Float64))
# @eval begin
# function $jname(jobz::Char,
# uplo::Char,
# A::StridedCuMatrix{$elty})
# chkuplo(uplo)
# n = checksquare(A)
# lda = max(1, stride(A, 2))
# W = CuArray{$relty}(undef, n)
# dh = dense_handle()
for (bname, fname, elty, relty) in ((:(CUSOLVER.cusolverDnSsyevj_bufferSize), :(CUSOLVER.cusolverDnSsyevj), :Float32, :Float32),
(:(CUSOLVER.cusolverDnDsyevj_bufferSize), :(CUSOLVER.cusolverDnDsyevj), :Float64, :Float64),
(:(CUSOLVER.cusolverDnCheevj_bufferSize), :(CUSOLVER.cusolverDnCheevj), :ComplexF32, :Float32),
(:(CUSOLVER.cusolverDnZheevj_bufferSize), :(CUSOLVER.cusolverDnZheevj), :ComplexF64, :Float64))
@eval begin
function heevj!(A::StridedCuMatrix{$elty},
W::StridedCuVector{$relty},
V::StridedCuMatrix{$elty};
uplo::Char='U',
tol::$relty=eps($relty),
max_sweeps::Int=100
)
chkuplo(uplo)
n = checksquare(A)
lda = max(1, stride(A, 2))
dh = CUSOLVER.dense_handle()
length(W) == n || throw(DimensionMismatch("size mismatch between A and W"))
if length(V) == 0
jobz = 'N'
else
size(V) == (n, n) || throw(DimensionMismatch("size mismatch between A and V"))
jobz = 'V'
end
params = Ref{CUSOLVER.syevjInfo_t}(C_NULL)
CUSOLVER.cusolverDnCreateSyevjInfo(params)
CUSOLVER.cusolverDnXsyevjSetTolerance(params[], tol)
CUSOLVER.cusolverDnXsyevjSetMaxSweeps(params[], max_sweeps)
function bufferSize()
out = Ref{Cint}(0)
$bname(dh, jobz, uplo, n, A, lda, W, out, params[])
return out[] * sizeof($elty)
end
CUDA.with_workspace(dh.workspace_gpu, bufferSize) do buffer
$fname(dh, jobz, uplo, n, A, lda, W, buffer,
sizeof(buffer) ÷ sizeof($elty), dh.info, params[])
end

# function bufferSize()
# out = Ref{Cint}(0)
# $bname(dh, jobz, uplo, n, A, lda, W, out)
# return out[] * sizeof($elty)
# end
info = @allowscalar dh.info[1]
chkargsok(BlasInt(info))

# with_workspace(dh.workspace_gpu, bufferSize) do buffer
# return $fname(dh, jobz, uplo, n, A, lda, W,
# buffer, sizeof(buffer) ÷ sizeof($elty), dh.info)
# end
if jobz == 'V' && V !== A
copy!(V, A)
end
return W, V
end
end
end

# info = @allowscalar dh.info[1]
# chkargsok(BlasInt(info))
function heevd!(A::StridedCuMatrix{T},
W::StridedCuVector{Tr},
V::StridedCuMatrix{T};
uplo::Char='U') where {T<:BlasFloat, Tr<:BlasReal}
chkuplo(uplo)
n = checksquare(A)
lda = max(1, stride(A, 2))
dh = CUSOLVER.dense_handle()
length(W) == n || throw(DimensionMismatch("size mismatch between A and W"))
if length(V) == 0
jobz = 'N'
else
size(V) == (n, n) || throw(DimensionMismatch("size mismatch between A and V"))
jobz = 'V'
end

# if jobz == 'N'
# return W
# elseif jobz == 'V'
# return W, A
# end
# end
# end
# end
params = CUSOLVER.CuSolverParameters()
function bufferSize()
out_cpu = Ref{Csize_t}(0)
out_gpu = Ref{Csize_t}(0)
CUSOLVER.cusolverDnXsyevd_bufferSize(dh, params, jobz, uplo, n, T, A, lda, Tr, W, T, out_gpu, out_cpu)
return out_gpu[], out_cpu[]
end

CUSOLVER.with_workspaces(dh.workspace_gpu, dh.workspace_cpu,
bufferSize()...) do buffer_gpu, buffer_cpu
return CUSOLVER.cusolverDnXsyevd(dh, params, jobz, uplo, n, T, A, lda, Tr, W,
T, buffer_gpu, sizeof(buffer_gpu), buffer_cpu,
sizeof(buffer_cpu), dh.info)
end

info = @allowscalar dh.info[1]
chkargsok(BlasInt(info))

if jobz == 'V' && V !== A
copy!(V, A)
end
return W, V
end

# device code is unreachable by coverage right now
# COV_EXCL_START
Expand Down
4 changes: 2 additions & 2 deletions src/MatrixAlgebraKit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ export LAPACK_HouseholderQR, LAPACK_HouseholderLQ,
LAPACK_DivideAndConquer, LAPACK_Jacobi,
LQViaTransposedQR,
CUSOLVER_Simple,
CUSOLVER_HouseholderQR, CUSOLVER_QRIteration, CUSOLVER_SVDPolar, CUSOLVER_Jacobi, CUSOLVER_Randomized,
ROCSOLVER_HouseholderQR, ROCSOLVER_QRIteration, ROCSOLVER_Jacobi
CUSOLVER_HouseholderQR, CUSOLVER_QRIteration, CUSOLVER_SVDPolar, CUSOLVER_Jacobi, CUSOLVER_Randomized, CUSOLVER_DivideAndConquer,
ROCSOLVER_HouseholderQR, ROCSOLVER_QRIteration, ROCSOLVER_Jacobi, ROCSOLVER_DivideAndConquer, ROCSOLVER_Bisection
export truncrank, trunctol, truncabove, TruncationKeepSorted, TruncationKeepFiltered

VERSION >= v"1.11.0-DEV.469" &&
Expand Down
Loading
Loading