Skip to content

Commit 49cdee0

Browse files
committed
Support eigh for CUDA
1 parent 36ef228 commit 49cdee0

8 files changed

Lines changed: 255 additions & 43 deletions

File tree

ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@ using MatrixAlgebraKit: @algdef, Algorithm, check_input
55
using MatrixAlgebraKit: one!, zero!, uppertriangular!, lowertriangular!
66
using MatrixAlgebraKit: diagview, sign_safe
77
using MatrixAlgebraKit: LQViaTransposedQR
8-
using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eig_algorithm
8+
using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eig_algorithm, default_eigh_algorithm
99
import MatrixAlgebraKit: _gpu_geqrf!, _gpu_ungqr!, _gpu_unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_Xgesvdr!, _gpu_gesvdj!, _gpu_geev!
10+
import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd!
1011
using CUDA
1112
using LinearAlgebra
1213
using LinearAlgebra: BlasFloat
@@ -26,6 +27,9 @@ end
2627
function MatrixAlgebraKit.default_eig_algorithm(::Type{T}; kwargs...) where {T<:StridedCuMatrix}
2728
return CUSOLVER_Simple(; kwargs...)
2829
end
30+
function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {T<:StridedCuMatrix}
31+
return CUSOLVER_DivideAndConquer(; kwargs...)
32+
end
2933

3034

3135
_gpu_geev!(A::StridedCuMatrix, D::StridedCuVector, V::StridedCuMatrix) = YACUSOLVER.Xgeev!(A, D, V)
@@ -37,4 +41,7 @@ _gpu_Xgesvdp!(A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::
3741
_gpu_Xgesvdr!(A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix; kwargs...) = YACUSOLVER.Xgesvdr!(A, S, U, Vᴴ; kwargs...)
3842
_gpu_gesvdj!(A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix; kwargs...) = YACUSOLVER.gesvdj!(A, S, U, Vᴴ; kwargs...)
3943

44+
_gpu_heevj!(A::StridedCuMatrix, Dd::StridedCuVector, V::StridedCuMatrix; kwargs...) = YACUSOLVER.heevj!(A, Dd, V; kwargs...)
45+
_gpu_heevd!(A::StridedCuMatrix, Dd::StridedCuVector, V::StridedCuMatrix; kwargs...) = YACUSOLVER.heevd!(A, Dd, V; kwargs...)
46+
4047
end

ext/MatrixAlgebraKitCUDAExt/yacusolver.jl

Lines changed: 84 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
module YACUSOLVER
22

33
using LinearAlgebra
4-
using LinearAlgebra: BlasInt, BlasFloat, checksquare, chkstride1, require_one_based_indexing
4+
using LinearAlgebra: BlasInt, BlasFloat, BlasReal, checksquare, chkstride1, require_one_based_indexing
55
using LinearAlgebra.LAPACK: chkargsok, chklapackerror, chktrans, chkside, chkdiag, chkuplo
66

77
using CUDA
@@ -679,43 +679,93 @@ end
679679
# return X, info
680680
# end
681681

682-
# for (jname, bname, fname, elty, relty) in
683-
# ((:syevd!, :cusolverDnSsyevd_bufferSize, :cusolverDnSsyevd, :Float32, :Float32),
684-
# (:syevd!, :cusolverDnDsyevd_bufferSize, :cusolverDnDsyevd, :Float64, :Float64),
685-
# (:heevd!, :cusolverDnCheevd_bufferSize, :cusolverDnCheevd, :ComplexF32, :Float32),
686-
# (:heevd!, :cusolverDnZheevd_bufferSize, :cusolverDnZheevd, :ComplexF64, :Float64))
687-
# @eval begin
688-
# function $jname(jobz::Char,
689-
# uplo::Char,
690-
# A::StridedCuMatrix{$elty})
691-
# chkuplo(uplo)
692-
# n = checksquare(A)
693-
# lda = max(1, stride(A, 2))
694-
# W = CuArray{$relty}(undef, n)
695-
# dh = dense_handle()
682+
for (bname, fname, elty, relty) in ((:(CUSOLVER.cusolverDnSsyevj_bufferSize), :(CUSOLVER.cusolverDnSsyevj), :Float32, :Float32),
683+
(:(CUSOLVER.cusolverDnDsyevj_bufferSize), :(CUSOLVER.cusolverDnDsyevj), :Float64, :Float64),
684+
(:(CUSOLVER.cusolverDnCheevj_bufferSize), :(CUSOLVER.cusolverDnCheevj), :ComplexF32, :Float32),
685+
(:(CUSOLVER.cusolverDnZheevj_bufferSize), :(CUSOLVER.cusolverDnZheevj), :ComplexF64, :Float64))
686+
@eval begin
687+
function heevj!(A::StridedCuMatrix{$elty},
688+
W::StridedCuVector{$relty},
689+
V::StridedCuMatrix{$elty};
690+
uplo::Char='U',
691+
tol::$relty=eps($relty),
692+
max_sweeps::Int=100
693+
)
694+
chkuplo(uplo)
695+
n = checksquare(A)
696+
lda = max(1, stride(A, 2))
697+
dh = CUSOLVER.dense_handle()
698+
length(W) == n || throw(DimensionMismatch("size mismatch between A and W"))
699+
if length(V) == 0
700+
jobz = 'N'
701+
else
702+
size(V) == (n, n) || throw(DimensionMismatch("size mismatch between A and V"))
703+
jobz = 'V'
704+
end
705+
params = Ref{CUSOLVER.syevjInfo_t}(C_NULL)
706+
CUSOLVER.cusolverDnCreateSyevjInfo(params)
707+
CUSOLVER.cusolverDnXsyevjSetTolerance(params[], tol)
708+
CUSOLVER.cusolverDnXsyevjSetMaxSweeps(params[], max_sweeps)
709+
function bufferSize()
710+
out = Ref{Cint}(0)
711+
$bname(dh, jobz, uplo, n, A, lda, W, out, params[])
712+
return out[] * sizeof($elty)
713+
end
714+
CUDA.with_workspace(dh.workspace_gpu, bufferSize) do buffer
715+
$fname(dh, jobz, uplo, n, A, lda, W, buffer,
716+
sizeof(buffer) ÷ sizeof($elty), dh.info, params[])
717+
end
696718

697-
# function bufferSize()
698-
# out = Ref{Cint}(0)
699-
# $bname(dh, jobz, uplo, n, A, lda, W, out)
700-
# return out[] * sizeof($elty)
701-
# end
719+
info = @allowscalar dh.info[1]
720+
chkargsok(BlasInt(info))
702721

703-
# with_workspace(dh.workspace_gpu, bufferSize) do buffer
704-
# return $fname(dh, jobz, uplo, n, A, lda, W,
705-
# buffer, sizeof(buffer) ÷ sizeof($elty), dh.info)
706-
# end
722+
if jobz == 'V' && V !== A
723+
copy!(V, A)
724+
end
725+
return W, V
726+
end
727+
end
728+
end
707729

708-
# info = @allowscalar dh.info[1]
709-
# chkargsok(BlasInt(info))
730+
function heevd!(A::StridedCuMatrix{T},
731+
W::StridedCuVector{Tr},
732+
V::StridedCuMatrix{T};
733+
uplo::Char='U') where {T<:BlasFloat, Tr<:BlasReal}
734+
chkuplo(uplo)
735+
n = checksquare(A)
736+
lda = max(1, stride(A, 2))
737+
dh = CUSOLVER.dense_handle()
738+
length(W) == n || throw(DimensionMismatch("size mismatch between A and W"))
739+
if length(V) == 0
740+
jobz = 'N'
741+
else
742+
size(V) == (n, n) || throw(DimensionMismatch("size mismatch between A and V"))
743+
jobz = 'V'
744+
end
710745

711-
# if jobz == 'N'
712-
# return W
713-
# elseif jobz == 'V'
714-
# return W, A
715-
# end
716-
# end
717-
# end
718-
# end
746+
params = CUSOLVER.CuSolverParameters()
747+
function bufferSize()
748+
out_cpu = Ref{Csize_t}(0)
749+
out_gpu = Ref{Csize_t}(0)
750+
CUSOLVER.cusolverDnXsyevd_bufferSize(dh, params, jobz, uplo, n, T, A, lda, Tr, W, T, out_gpu, out_cpu)
751+
return out_gpu[], out_cpu[]
752+
end
753+
754+
CUSOLVER.with_workspaces(dh.workspace_gpu, dh.workspace_cpu,
755+
bufferSize()...) do buffer_gpu, buffer_cpu
756+
return CUSOLVER.cusolverDnXsyevd(dh, params, jobz, uplo, n, T, A, lda, Tr, W,
757+
T, buffer_gpu, sizeof(buffer_gpu), buffer_cpu,
758+
sizeof(buffer_cpu), dh.info)
759+
end
760+
761+
info = @allowscalar dh.info[1]
762+
chkargsok(BlasInt(info))
763+
764+
if jobz == 'V' && V !== A
765+
copy!(V, A)
766+
end
767+
return W, V
768+
end
719769

720770
# device code is unreachable by coverage right now
721771
# COV_EXCL_START

src/MatrixAlgebraKit.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ export LAPACK_HouseholderQR, LAPACK_HouseholderLQ,
3434
LAPACK_DivideAndConquer, LAPACK_Jacobi,
3535
LQViaTransposedQR,
3636
CUSOLVER_Simple,
37-
CUSOLVER_HouseholderQR, CUSOLVER_QRIteration, CUSOLVER_SVDPolar, CUSOLVER_Jacobi, CUSOLVER_Randomized,
37+
CUSOLVER_HouseholderQR, CUSOLVER_QRIteration, CUSOLVER_SVDPolar, CUSOLVER_Jacobi, CUSOLVER_Randomized, CUSOLVER_DivideAndConquer,
3838
ROCSOLVER_HouseholderQR, ROCSOLVER_QRIteration, ROCSOLVER_Jacobi
3939
export truncrank, trunctol, truncabove, TruncationKeepSorted, TruncationKeepFiltered
4040

src/implementations/eigh.jl

Lines changed: 43 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -61,11 +61,7 @@ function eigh_full!(A::AbstractMatrix, DV, alg::LAPACK_EighAlgorithm)
6161
YALAPACK.heevx!(A, Dd, V; alg.kwargs...)
6262
end
6363
# TODO: make this controllable using a `gaugefix` keyword argument
64-
for j in 1:size(V, 2)
65-
v = view(V, :, j)
66-
s = conj(sign(argmax(abs, v)))
67-
v .*= s
68-
end
64+
V = gaugefix!(V)
6965
return D, V
7066
end
7167

@@ -88,3 +84,45 @@ function eigh_trunc!(A::AbstractMatrix, DV, alg::TruncatedAlgorithm)
8884
D, V = eigh_full!(A, DV, alg.alg)
8985
return truncate!(eigh_trunc!, (D, V), alg.trunc)
9086
end
87+
88+
_gpu_heevj!(A::AbstractMatrix, Dd::AbstractVector, V::AbstractMatrix; kwargs...) = throw(MethodError(_gpu_heevj!, (A, Dd, V)))
89+
_gpu_heevd!(A::AbstractMatrix, Dd::AbstractVector, V::AbstractMatrix; kwargs...) = throw(MethodError(_gpu_heevd!, (A, Dd, V)))
90+
_gpu_heev!(A::AbstractMatrix, Dd::AbstractVector, V::AbstractMatrix; kwargs...) = throw(MethodError(_gpu_heev!, (A, Dd, V)))
91+
_gpu_heevx!(A::AbstractMatrix, Dd::AbstractVector, V::AbstractMatrix; kwargs...) = throw(MethodError(_gpu_heevx!, (A, Dd, V)))
92+
93+
function eigh_full!(A::AbstractMatrix, DV, alg::GPU_EighAlgorithm)
94+
check_input(eigh_full!, A, DV, alg)
95+
D, V = DV
96+
Dd = D.diag
97+
if alg isa GPU_Jacobi
98+
_gpu_heevj!(A, Dd, V; alg.kwargs...)
99+
elseif alg isa GPU_DivideAndConquer
100+
_gpu_heevd!(A, Dd, V; alg.kwargs...)
101+
elseif alg isa GPU_QRIteration # alg isa GPU_QRIteration == GPU_Simple
102+
_gpu_heev!(A, Dd, V; alg.kwargs...)
103+
elseif alg isa GPU_Bisection # alg isa GPU_Bisection == GPU_Expert
104+
_gpu_heevx!(A, Dd, V; alg.kwargs...)
105+
else
106+
throw(ArgumentError("Unsupported eigh algorithm"))
107+
end
108+
# TODO: make this controllable using a `gaugefix` keyword argument
109+
V = gaugefix!(V)
110+
return D, V
111+
end
112+
113+
function eigh_vals!(A::AbstractMatrix, D, alg::GPU_EighAlgorithm)
114+
check_input(eigh_vals!, A, D, alg)
115+
V = similar(A, (size(A, 1), 0))
116+
if alg isa GPU_Jacobi
117+
_gpu_heevj!(A, D, V; alg.kwargs...)
118+
elseif alg isa GPU_DivideAndConquer
119+
_gpu_heevd!(A, D, V; alg.kwargs...)
120+
elseif alg isa GPU_QRIteration
121+
_gpu_heev!(A, D, V; alg.kwargs...)
122+
elseif alg isa GPU_Bisection
123+
_gpu_heevx!(A, D, V; alg.kwargs...)
124+
else
125+
throw(ArgumentError("Unsupported eigh algorithm"))
126+
end
127+
return D
128+
end

src/implementations/svd.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -215,9 +215,7 @@ const ROCSOLVER_SVDAlgorithm = Union{ROCSOLVER_QRIteration,
215215
ROCSOLVER_Jacobi}
216216
const GPU_SVDAlgorithm = Union{CUSOLVER_SVDAlgorithm, ROCSOLVER_SVDAlgorithm}
217217

218-
const GPU_QRIteration = Union{CUSOLVER_QRIteration, ROCSOLVER_QRIteration}
219218
const GPU_SVDPolar = Union{CUSOLVER_SVDPolar}
220-
const GPU_Jacobi = Union{CUSOLVER_Jacobi, ROCSOLVER_Jacobi}
221219
const GPU_Randomized = Union{CUSOLVER_Randomized}
222220

223221
function check_input(::typeof(svd_trunc!), A::AbstractMatrix, USVᴴ, alg::CUSOLVER_Randomized)

src/interface/decompositions.jl

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,16 @@ eigenvalue decomposition of a matrix.
173173
@algdef CUSOLVER_Simple
174174

175175
const CUSOLVER_EigAlgorithm = Union{CUSOLVER_Simple}
176+
177+
"""
178+
CUSOLVER_DivideAndConquer()
179+
180+
Algorithm type to denote the CUSOLVER driver for computing the eigenvalue decomposition of a
181+
Hermitian matrix, or the singular value decomposition of a general matrix using the
182+
Divide and Conquer algorithm.
183+
"""
184+
@algdef CUSOLVER_DivideAndConquer
185+
176186
# =========================
177187
# ROCSOLVER ALGORITHMS
178188
# =========================
@@ -202,5 +212,33 @@ a general matrix using the Jacobi algorithm.
202212
"""
203213
@algdef ROCSOLVER_Jacobi
204214

215+
"""
216+
ROCSOLVER_Bisection()
217+
218+
Algorithm type to denote the ROCSOLVER driver for computing the eigenvalue decomposition of a
219+
Hermitian matrix, or the singular value decomposition of a general matrix using the
220+
Bisection algorithm.
221+
"""
222+
@algdef ROCSOLVER_Bisection
223+
224+
"""
225+
ROCSOLVER_DivideAndConquer()
226+
227+
Algorithm type to denote the ROCSOLVER driver for computing the eigenvalue decomposition of a
228+
Hermitian matrix, or the singular value decomposition of a general matrix using the
229+
Divide and Conquer algorithm.
230+
"""
231+
@algdef ROCSOLVER_DivideAndConquer
232+
233+
205234
const GPU_Simple = Union{CUSOLVER_Simple}
206235
const GPU_EigAlgorithm = Union{GPU_Simple}
236+
const GPU_QRIteration = Union{CUSOLVER_QRIteration, ROCSOLVER_QRIteration}
237+
const GPU_Jacobi = Union{CUSOLVER_Jacobi, ROCSOLVER_Jacobi}
238+
const GPU_DivideAndConquer = Union{CUSOLVER_DivideAndConquer, ROCSOLVER_DivideAndConquer}
239+
const GPU_Bisection = Union{ROCSOLVER_Bisection}
240+
const GPU_EighAlgorithm = Union{GPU_QRIteration,
241+
GPU_Jacobi,
242+
GPU_DivideAndConquer,
243+
GPU_Bisection}
244+

test/cuda/eigh.jl

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
using MatrixAlgebraKit
2+
using Test
3+
using TestExtras
4+
using StableRNGs
5+
using LinearAlgebra: LinearAlgebra, Diagonal, I
6+
using MatrixAlgebraKit: TruncatedAlgorithm, diagview
7+
using CUDA
8+
9+
@testset "eigh_full! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64)
10+
rng = StableRNG(123)
11+
m = 54
12+
for alg in (CUSOLVER_DivideAndConquer(),
13+
CUSOLVER_Jacobi(),
14+
)
15+
A = CuArray(randn(rng, T, m, m))
16+
A = (A + A') / 2
17+
18+
D, V = @constinferred eigh_full(A; alg)
19+
@test A * V V * D
20+
@test isunitary(V)
21+
@test all(isreal, D)
22+
23+
D2, V2 = eigh_full!(copy(A), (D, V), alg)
24+
@test D2 === D
25+
@test V2 === V
26+
27+
D3 = @constinferred eigh_vals(A, alg)
28+
@test parent(D) D3
29+
end
30+
end
31+
32+
#=@testset "eigh_trunc! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64)
33+
rng = StableRNG(123)
34+
m = 54
35+
for alg in (CUSOLVER_QRIteration(),
36+
CUSOLVER_DivideAndConquer(),
37+
)
38+
A = CuArray(randn(rng, T, m, m))
39+
A = A * A'
40+
A = (A + A') / 2
41+
Ac = similar(A)
42+
D₀ = reverse(eigh_vals(A))
43+
r = m - 2
44+
s = 1 + sqrt(eps(real(T)))
45+
46+
D1, V1 = @constinferred eigh_trunc(A; alg, trunc=truncrank(r))
47+
@test length(diagview(D1)) == r
48+
@test isisometry(V1)
49+
@test A * V1 ≈ V1 * D1
50+
@test LinearAlgebra.opnorm(A - V1 * D1 * V1') ≈ D₀[r + 1]
51+
52+
trunc = trunctol(s * D₀[r + 1])
53+
D2, V2 = @constinferred eigh_trunc(A; alg, trunc)
54+
@test length(diagview(D2)) == r
55+
@test isisometry(V2)
56+
@test A * V2 ≈ V2 * D2
57+
58+
# test for same subspace
59+
@test V1 * (V1' * V2) ≈ V2
60+
@test V2 * (V2' * V1) ≈ V1
61+
end
62+
end
63+
64+
@testset "eigh_trunc! specify truncation algorithm T = $T" for T in
65+
(Float32, Float64,
66+
ComplexF32,
67+
ComplexF64)
68+
rng = StableRNG(123)
69+
m = 4
70+
V = qr_compact(CuArray(randn(rng, T, m, m)))[1]
71+
D = Diagonal([0.9, 0.3, 0.1, 0.01])
72+
A = V * D * V'
73+
A = (A + A') / 2
74+
alg = TruncatedAlgorithm(CUSOLVER_QRIteration(), truncrank(2))
75+
D2, V2 = @constinferred eigh_trunc(A; alg)
76+
@test diagview(D2) ≈ diagview(D)[1:2] rtol = sqrt(eps(real(T)))
77+
@test_throws ArgumentError eigh_trunc(A; alg, trunc=(; maxrank=2))
78+
end=#

test/runtests.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,9 @@ if CUDA.functional()
6666
@safetestset "CUDA General Eigenvalue Decomposition" begin
6767
include("cuda/eig.jl")
6868
end
69+
@safetestset "CUDA Hermitian Eigenvalue Decomposition" begin
70+
include("cuda/eigh.jl")
71+
end
6972
end
7073

7174
using AMDGPU

0 commit comments

Comments
 (0)