Skip to content

Commit e0c8aa5

Browse files
committed
Support eigh for CUDA
1 parent cbecd7b commit e0c8aa5

8 files changed

Lines changed: 255 additions & 43 deletions

File tree

ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@ using MatrixAlgebraKit: @algdef, Algorithm, check_input
55
using MatrixAlgebraKit: one!, zero!, uppertriangular!, lowertriangular!
66
using MatrixAlgebraKit: diagview, sign_safe
77
using MatrixAlgebraKit: LQViaTransposedQR
8-
using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm
8+
using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eigh_algorithm
99
import MatrixAlgebraKit: _gpu_geqrf!, _gpu_ungqr!, _gpu_unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_Xgesvdr!, _gpu_gesvdj!
10+
import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd!
1011
using CUDA
1112
using LinearAlgebra
1213
using LinearAlgebra: BlasFloat
@@ -23,6 +24,9 @@ end
2324
function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T<:StridedCuMatrix}
2425
return CUSOLVER_QRIteration(; kwargs...)
2526
end
27+
function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {T<:StridedCuMatrix}
28+
return CUSOLVER_DivideAndConquer(; kwargs...)
29+
end
2630

2731

2832
_gpu_geqrf!(A::StridedCuMatrix) = YACUSOLVER.geqrf!(A)
@@ -33,4 +37,7 @@ _gpu_Xgesvdp!(A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::
3337
_gpu_Xgesvdr!(A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix; kwargs...) = YACUSOLVER.Xgesvdr!(A, S, U, Vᴴ; kwargs...)
3438
_gpu_gesvdj!(A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix; kwargs...) = YACUSOLVER.gesvdj!(A, S, U, Vᴴ; kwargs...)
3539

40+
_gpu_heevj!(A::StridedCuMatrix, Dd::StridedCuVector, V::StridedCuMatrix; kwargs...) = YACUSOLVER.heevj!(A, Dd, V; kwargs...)
41+
_gpu_heevd!(A::StridedCuMatrix, Dd::StridedCuVector, V::StridedCuMatrix; kwargs...) = YACUSOLVER.heevd!(A, Dd, V; kwargs...)
42+
3643
end

ext/MatrixAlgebraKitCUDAExt/yacusolver.jl

Lines changed: 84 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
module YACUSOLVER
22

33
using LinearAlgebra
4-
using LinearAlgebra: BlasInt, BlasFloat, checksquare, chkstride1, require_one_based_indexing
4+
using LinearAlgebra: BlasInt, BlasFloat, BlasReal, checksquare, chkstride1, require_one_based_indexing
55
using LinearAlgebra.LAPACK: chkargsok, chklapackerror, chktrans, chkside, chkdiag, chkuplo
66

77
using CUDA
@@ -612,42 +612,92 @@ end
612612
# return X, info
613613
# end
614614

615-
# for (jname, bname, fname, elty, relty) in
616-
# ((:syevd!, :cusolverDnSsyevd_bufferSize, :cusolverDnSsyevd, :Float32, :Float32),
617-
# (:syevd!, :cusolverDnDsyevd_bufferSize, :cusolverDnDsyevd, :Float64, :Float64),
618-
# (:heevd!, :cusolverDnCheevd_bufferSize, :cusolverDnCheevd, :ComplexF32, :Float32),
619-
# (:heevd!, :cusolverDnZheevd_bufferSize, :cusolverDnZheevd, :ComplexF64, :Float64))
620-
# @eval begin
621-
# function $jname(jobz::Char,
622-
# uplo::Char,
623-
# A::StridedCuMatrix{$elty})
624-
# chkuplo(uplo)
625-
# n = checksquare(A)
626-
# lda = max(1, stride(A, 2))
627-
# W = CuArray{$relty}(undef, n)
628-
# dh = dense_handle()
615+
for (bname, fname, elty, relty) in ((:(CUSOLVER.cusolverDnSsyevj_bufferSize), :(CUSOLVER.cusolverDnSsyevj), :Float32, :Float32),
616+
(:(CUSOLVER.cusolverDnDsyevj_bufferSize), :(CUSOLVER.cusolverDnDsyevj), :Float64, :Float64),
617+
(:(CUSOLVER.cusolverDnCheevj_bufferSize), :(CUSOLVER.cusolverDnCheevj), :ComplexF32, :Float32),
618+
(:(CUSOLVER.cusolverDnZheevj_bufferSize), :(CUSOLVER.cusolverDnZheevj), :ComplexF64, :Float64))
619+
@eval begin
620+
function heevj!(A::StridedCuMatrix{$elty},
621+
W::StridedCuVector{$relty},
622+
V::StridedCuMatrix{$elty};
623+
uplo::Char='U',
624+
tol::$relty=eps($relty),
625+
max_sweeps::Int=100
626+
)
627+
chkuplo(uplo)
628+
n = checksquare(A)
629+
lda = max(1, stride(A, 2))
630+
dh = CUSOLVER.dense_handle()
631+
length(W) == n || throw(DimensionMismatch("size mismatch between A and W"))
632+
if length(V) == 0
633+
jobz = 'N'
634+
else
635+
size(V) == (n, n) || throw(DimensionMismatch("size mismatch between A and V"))
636+
jobz = 'V'
637+
end
638+
params = Ref{CUSOLVER.syevjInfo_t}(C_NULL)
639+
CUSOLVER.cusolverDnCreateSyevjInfo(params)
640+
CUSOLVER.cusolverDnXsyevjSetTolerance(params[], tol)
641+
CUSOLVER.cusolverDnXsyevjSetMaxSweeps(params[], max_sweeps)
642+
function bufferSize()
643+
out = Ref{Cint}(0)
644+
$bname(dh, jobz, uplo, n, A, lda, W, out, params[])
645+
return out[] * sizeof($elty)
646+
end
647+
CUDA.with_workspace(dh.workspace_gpu, bufferSize) do buffer
648+
$fname(dh, jobz, uplo, n, A, lda, W, buffer,
649+
sizeof(buffer) ÷ sizeof($elty), dh.info, params[])
650+
end
629651

630-
# function bufferSize()
631-
# out = Ref{Cint}(0)
632-
# $bname(dh, jobz, uplo, n, A, lda, W, out)
633-
# return out[] * sizeof($elty)
634-
# end
652+
info = @allowscalar dh.info[1]
653+
chkargsok(BlasInt(info))
635654

636-
# with_workspace(dh.workspace_gpu, bufferSize) do buffer
637-
# return $fname(dh, jobz, uplo, n, A, lda, W,
638-
# buffer, sizeof(buffer) ÷ sizeof($elty), dh.info)
639-
# end
655+
if jobz == 'V' && V !== A
656+
copy!(V, A)
657+
end
658+
return W, V
659+
end
660+
end
661+
end
640662

641-
# info = @allowscalar dh.info[1]
642-
# chkargsok(BlasInt(info))
663+
function heevd!(A::StridedCuMatrix{T},
664+
W::StridedCuVector{Tr},
665+
V::StridedCuMatrix{T};
666+
uplo::Char='U') where {T<:BlasFloat, Tr<:BlasReal}
667+
chkuplo(uplo)
668+
n = checksquare(A)
669+
lda = max(1, stride(A, 2))
670+
dh = CUSOLVER.dense_handle()
671+
length(W) == n || throw(DimensionMismatch("size mismatch between A and W"))
672+
if length(V) == 0
673+
jobz = 'N'
674+
else
675+
size(V) == (n, n) || throw(DimensionMismatch("size mismatch between A and V"))
676+
jobz = 'V'
677+
end
643678

644-
# if jobz == 'N'
645-
# return W
646-
# elseif jobz == 'V'
647-
# return W, A
648-
# end
649-
# end
650-
# end
651-
# end
679+
params = CUSOLVER.CuSolverParameters()
680+
function bufferSize()
681+
out_cpu = Ref{Csize_t}(0)
682+
out_gpu = Ref{Csize_t}(0)
683+
CUSOLVER.cusolverDnXsyevd_bufferSize(dh, params, jobz, uplo, n, T, A, lda, Tr, W, T, out_gpu, out_cpu)
684+
return out_gpu[], out_cpu[]
685+
end
686+
687+
CUSOLVER.with_workspaces(dh.workspace_gpu, dh.workspace_cpu,
688+
bufferSize()...) do buffer_gpu, buffer_cpu
689+
return CUSOLVER.cusolverDnXsyevd(dh, params, jobz, uplo, n, T, A, lda, Tr, W,
690+
T, buffer_gpu, sizeof(buffer_gpu), buffer_cpu,
691+
sizeof(buffer_cpu), dh.info)
692+
end
693+
694+
info = @allowscalar dh.info[1]
695+
chkargsok(BlasInt(info))
696+
697+
if jobz == 'V' && V !== A
698+
copy!(V, A)
699+
end
700+
return W, V
701+
end
652702

653703
end

src/MatrixAlgebraKit.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ export LAPACK_HouseholderQR, LAPACK_HouseholderLQ,
3333
LAPACK_QRIteration, LAPACK_Bisection, LAPACK_MultipleRelativelyRobustRepresentations,
3434
LAPACK_DivideAndConquer, LAPACK_Jacobi,
3535
LQViaTransposedQR,
36-
CUSOLVER_HouseholderQR, CUSOLVER_QRIteration, CUSOLVER_SVDPolar, CUSOLVER_Jacobi, CUSOLVER_Randomized,
36+
CUSOLVER_HouseholderQR, CUSOLVER_QRIteration, CUSOLVER_SVDPolar, CUSOLVER_Jacobi, CUSOLVER_Randomized, CUSOLVER_DivideAndConquer,
3737
ROCSOLVER_HouseholderQR, ROCSOLVER_QRIteration, ROCSOLVER_Jacobi
3838
export truncrank, trunctol, truncabove, TruncationKeepSorted, TruncationKeepFiltered
3939

src/implementations/eigh.jl

Lines changed: 43 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -61,11 +61,7 @@ function eigh_full!(A::AbstractMatrix, DV, alg::LAPACK_EighAlgorithm)
6161
YALAPACK.heevx!(A, Dd, V; alg.kwargs...)
6262
end
6363
# TODO: make this controllable using a `gaugefix` keyword argument
64-
for j in 1:size(V, 2)
65-
v = view(V, :, j)
66-
s = conj(sign(argmax(abs, v)))
67-
v .*= s
68-
end
64+
V = gaugefix!(V)
6965
return D, V
7066
end
7167

@@ -88,3 +84,45 @@ function eigh_trunc!(A::AbstractMatrix, DV, alg::TruncatedAlgorithm)
8884
D, V = eigh_full!(A, DV, alg.alg)
8985
return truncate!(eigh_trunc!, (D, V), alg.trunc)
9086
end
87+
88+
_gpu_heevj!(A::AbstractMatrix, Dd::AbstractVector, V::AbstractMatrix; kwargs...) = throw(MethodError(_gpu_heevj!, (A, Dd, V)))
89+
_gpu_heevd!(A::AbstractMatrix, Dd::AbstractVector, V::AbstractMatrix; kwargs...) = throw(MethodError(_gpu_heevd!, (A, Dd, V)))
90+
_gpu_heev!(A::AbstractMatrix, Dd::AbstractVector, V::AbstractMatrix; kwargs...) = throw(MethodError(_gpu_heev!, (A, Dd, V)))
91+
_gpu_heevx!(A::AbstractMatrix, Dd::AbstractVector, V::AbstractMatrix; kwargs...) = throw(MethodError(_gpu_heevx!, (A, Dd, V)))
92+
93+
function eigh_full!(A::AbstractMatrix, DV, alg::GPU_EighAlgorithm)
94+
check_input(eigh_full!, A, DV, alg)
95+
D, V = DV
96+
Dd = D.diag
97+
if alg isa GPU_Jacobi
98+
_gpu_heevj!(A, Dd, V; alg.kwargs...)
99+
elseif alg isa GPU_DivideAndConquer
100+
_gpu_heevd!(A, Dd, V; alg.kwargs...)
101+
elseif alg isa GPU_QRIteration # alg isa GPU_QRIteration == GPU_Simple
102+
_gpu_heev!(A, Dd, V; alg.kwargs...)
103+
elseif alg isa GPU_Bisection # alg isa GPU_Bisection == GPU_Expert
104+
_gpu_heevx!(A, Dd, V; alg.kwargs...)
105+
else
106+
throw(ArgumentError("Unsupported eigh algorithm"))
107+
end
108+
# TODO: make this controllable using a `gaugefix` keyword argument
109+
V = gaugefix!(V)
110+
return D, V
111+
end
112+
113+
function eigh_vals!(A::AbstractMatrix, D, alg::GPU_EighAlgorithm)
114+
check_input(eigh_vals!, A, D, alg)
115+
V = similar(A, (size(A, 1), 0))
116+
if alg isa GPU_Jacobi
117+
_gpu_heevj!(A, D, V; alg.kwargs...)
118+
elseif alg isa GPU_DivideAndConquer
119+
_gpu_heevd!(A, D, V; alg.kwargs...)
120+
elseif alg isa GPU_QRIteration
121+
_gpu_heev!(A, D, V; alg.kwargs...)
122+
elseif alg isa GPU_Bisection
123+
_gpu_heevx!(A, D, V; alg.kwargs...)
124+
else
125+
throw(ArgumentError("Unsupported eigh algorithm"))
126+
end
127+
return D
128+
end

src/implementations/svd.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -215,9 +215,7 @@ const ROCSOLVER_SVDAlgorithm = Union{ROCSOLVER_QRIteration,
215215
ROCSOLVER_Jacobi}
216216
const GPU_SVDAlgorithm = Union{CUSOLVER_SVDAlgorithm, ROCSOLVER_SVDAlgorithm}
217217

218-
const GPU_QRIteration = Union{CUSOLVER_QRIteration, ROCSOLVER_QRIteration}
219218
const GPU_SVDPolar = Union{CUSOLVER_SVDPolar}
220-
const GPU_Jacobi = Union{CUSOLVER_Jacobi, ROCSOLVER_Jacobi}
221219
const GPU_Randomized = Union{CUSOLVER_Randomized}
222220

223221
function check_input(::typeof(svd_trunc!), A::AbstractMatrix, USVᴴ, alg::CUSOLVER_Randomized)

src/interface/decompositions.jl

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,15 @@ a general matrix using the randomized SVD algorithm.
164164
"""
165165
@algdef CUSOLVER_Randomized
166166

167+
"""
168+
CUSOLVER_DivideAndConquer()
169+
170+
Algorithm type to denote the CUSOLVER driver for computing the eigenvalue decomposition of a
171+
Hermitian matrix, or the singular value decomposition of a general matrix using the
172+
Divide and Conquer algorithm.
173+
"""
174+
@algdef CUSOLVER_DivideAndConquer
175+
167176
# =========================
168177
# ROCSOLVER ALGORITHMS
169178
# =========================
@@ -192,3 +201,32 @@ Algorithm type to denote the ROCSOLVER driver for computing the singular value d
192201
a general matrix using the Jacobi algorithm.
193202
"""
194203
@algdef ROCSOLVER_Jacobi
204+
205+
"""
206+
ROCSOLVER_Bisection()
207+
208+
Algorithm type to denote the ROCSOLVER driver for computing the eigenvalue decomposition of a
209+
Hermitian matrix, or the singular value decomposition of a general matrix using the
210+
Bisection algorithm.
211+
"""
212+
@algdef ROCSOLVER_Bisection
213+
214+
"""
215+
ROCSOLVER_DivideAndConquer()
216+
217+
Algorithm type to denote the ROCSOLVER driver for computing the eigenvalue decomposition of a
218+
Hermitian matrix, or the singular value decomposition of a general matrix using the
219+
Divide and Conquer algorithm.
220+
"""
221+
@algdef ROCSOLVER_DivideAndConquer
222+
223+
224+
const GPU_QRIteration = Union{CUSOLVER_QRIteration, ROCSOLVER_QRIteration}
225+
const GPU_Jacobi = Union{CUSOLVER_Jacobi, ROCSOLVER_Jacobi}
226+
const GPU_DivideAndConquer = Union{CUSOLVER_DivideAndConquer, ROCSOLVER_DivideAndConquer}
227+
const GPU_Bisection = Union{ROCSOLVER_Bisection}
228+
const GPU_EighAlgorithm = Union{GPU_QRIteration,
229+
GPU_Jacobi,
230+
GPU_DivideAndConquer,
231+
GPU_Bisection}
232+

test/cuda/eigh.jl

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
using MatrixAlgebraKit
2+
using Test
3+
using TestExtras
4+
using StableRNGs
5+
using LinearAlgebra: LinearAlgebra, Diagonal, I
6+
using MatrixAlgebraKit: TruncatedAlgorithm, diagview
7+
using CUDA
8+
9+
@testset "eigh_full! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64)
10+
rng = StableRNG(123)
11+
m = 54
12+
for alg in (CUSOLVER_DivideAndConquer(),
13+
CUSOLVER_Jacobi(),
14+
)
15+
A = CuArray(randn(rng, T, m, m))
16+
A = (A + A') / 2
17+
18+
D, V = @constinferred eigh_full(A; alg)
19+
@test A * V V * D
20+
@test isunitary(V)
21+
@test all(isreal, D)
22+
23+
D2, V2 = eigh_full!(copy(A), (D, V), alg)
24+
@test D2 === D
25+
@test V2 === V
26+
27+
D3 = @constinferred eigh_vals(A, alg)
28+
@test parent(D) D3
29+
end
30+
end
31+
32+
#=@testset "eigh_trunc! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64)
33+
rng = StableRNG(123)
34+
m = 54
35+
for alg in (CUSOLVER_QRIteration(),
36+
CUSOLVER_DivideAndConquer(),
37+
)
38+
A = CuArray(randn(rng, T, m, m))
39+
A = A * A'
40+
A = (A + A') / 2
41+
Ac = similar(A)
42+
D₀ = reverse(eigh_vals(A))
43+
r = m - 2
44+
s = 1 + sqrt(eps(real(T)))
45+
46+
D1, V1 = @constinferred eigh_trunc(A; alg, trunc=truncrank(r))
47+
@test length(diagview(D1)) == r
48+
@test isisometry(V1)
49+
@test A * V1 ≈ V1 * D1
50+
@test LinearAlgebra.opnorm(A - V1 * D1 * V1') ≈ D₀[r + 1]
51+
52+
trunc = trunctol(s * D₀[r + 1])
53+
D2, V2 = @constinferred eigh_trunc(A; alg, trunc)
54+
@test length(diagview(D2)) == r
55+
@test isisometry(V2)
56+
@test A * V2 ≈ V2 * D2
57+
58+
# test for same subspace
59+
@test V1 * (V1' * V2) ≈ V2
60+
@test V2 * (V2' * V1) ≈ V1
61+
end
62+
end
63+
64+
@testset "eigh_trunc! specify truncation algorithm T = $T" for T in
65+
(Float32, Float64,
66+
ComplexF32,
67+
ComplexF64)
68+
rng = StableRNG(123)
69+
m = 4
70+
V = qr_compact(CuArray(randn(rng, T, m, m)))[1]
71+
D = Diagonal([0.9, 0.3, 0.1, 0.01])
72+
A = V * D * V'
73+
A = (A + A') / 2
74+
alg = TruncatedAlgorithm(CUSOLVER_QRIteration(), truncrank(2))
75+
D2, V2 = @constinferred eigh_trunc(A; alg)
76+
@test diagview(D2) ≈ diagview(D)[1:2] rtol = sqrt(eps(real(T)))
77+
@test_throws ArgumentError eigh_trunc(A; alg, trunc=(; maxrank=2))
78+
end=#

test/runtests.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,9 @@ if CUDA.functional()
6363
@safetestset "CUDA SVD" begin
6464
include("cuda/svd.jl")
6565
end
66+
@safetestset "CUDA Hermitian Eigenvalue Decomposition" begin
67+
include("cuda/eigh.jl")
68+
end
6669
end
6770

6871
using AMDGPU

0 commit comments

Comments
 (0)