Skip to content

Commit 36ef228

Browse files
kshyattKatharine Hyatt
andauthored
Initial attempt to wrap CUSOLVER.Xgeev (#47)
* Initial attempt to wrap CUSOLVER.Xgeev * Remove unneeded using * Fixup re comments --------- Co-authored-by: Katharine Hyatt <katharine.s.hyatt@gmail.com>
1 parent cbecd7b commit 36ef228

7 files changed

Lines changed: 232 additions & 3 deletions

File tree

ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@ using MatrixAlgebraKit: @algdef, Algorithm, check_input
55
using MatrixAlgebraKit: one!, zero!, uppertriangular!, lowertriangular!
66
using MatrixAlgebraKit: diagview, sign_safe
77
using MatrixAlgebraKit: LQViaTransposedQR
8-
using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm
9-
import MatrixAlgebraKit: _gpu_geqrf!, _gpu_ungqr!, _gpu_unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_Xgesvdr!, _gpu_gesvdj!
8+
using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eig_algorithm
9+
import MatrixAlgebraKit: _gpu_geqrf!, _gpu_ungqr!, _gpu_unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_Xgesvdr!, _gpu_gesvdj!, _gpu_geev!
1010
using CUDA
1111
using LinearAlgebra
1212
using LinearAlgebra: BlasFloat
@@ -23,8 +23,12 @@ end
2323
function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T<:StridedCuMatrix}
2424
return CUSOLVER_QRIteration(; kwargs...)
2525
end
26+
function MatrixAlgebraKit.default_eig_algorithm(::Type{T}; kwargs...) where {T<:StridedCuMatrix}
27+
return CUSOLVER_Simple(; kwargs...)
28+
end
2629

2730

31+
_gpu_geev!(A::StridedCuMatrix, D::StridedCuVector, V::StridedCuMatrix) = YACUSOLVER.Xgeev!(A, D, V)
2832
_gpu_geqrf!(A::StridedCuMatrix) = YACUSOLVER.geqrf!(A)
2933
_gpu_ungqr!(A::StridedCuMatrix, τ::StridedCuVector) = YACUSOLVER.ungqr!(A, τ)
3034
_gpu_unmqr!(side::AbstractChar, trans::AbstractChar, A::StridedCuMatrix, τ::StridedCuVector, C::StridedCuVecOrMat) = YACUSOLVER.unmqr!(side, trans, A, τ, C)

ext/MatrixAlgebraKitCUDAExt/yacusolver.jl

Lines changed: 121 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ using LinearAlgebra: BlasInt, BlasFloat, checksquare, chkstride1, require_one_ba
55
using LinearAlgebra.LAPACK: chkargsok, chklapackerror, chktrans, chkside, chkdiag, chkuplo
66

77
using CUDA
8-
using CUDA: @allowscalar
8+
using CUDA: @allowscalar, i32
99
using CUDA.CUSOLVER
1010

1111
# QR methods are implemented with full access to allocated arrays, so we do not need to redo this:
@@ -306,6 +306,73 @@ function Xgesvdr!(A::StridedCuMatrix{T},
306306
return S, U, Vᴴ
307307
end
308308

309+
# Wrapper for general eigensolver
310+
for (celty, elty) in ((:ComplexF32, :Float32), (:ComplexF64, :Float64), (:ComplexF32, :ComplexF32), (:ComplexF64, :ComplexF64))
311+
@eval begin
312+
function Xgeev!(A::StridedCuMatrix{$elty}, D::StridedCuVector{$celty}, V::StridedCuMatrix{$celty})
313+
require_one_based_indexing(A, V, D)
314+
chkstride1(A, V, D)
315+
n = checksquare(A)
316+
# TODO GPU appropriate version
317+
#chkfinite(A) # balancing routines don't support NaNs and Infs
318+
n == length(D) || throw(DimensionMismatch("length mismatch between A and D"))
319+
if length(V) == 0
320+
jobvr = 'N'
321+
elseif length(V) == n*n
322+
jobvr = 'V'
323+
else
324+
throw(DimensionMismatch("size of VR must match size of A"))
325+
end
326+
jobvl = 'N' # required by API for now (https://docs.nvidia.com/cuda/cusolver/index.html#cusolverdnxgeev)
327+
#=if length(VL) == 0
328+
jobvl = 'N'
329+
elseif length(VL) == n*n
330+
jobvl = 'V'
331+
else
332+
throw(DimensionMismatch("size of VL must match size of A"))
333+
end=#
334+
VL = similar(A, n, 0)
335+
lda = max(1, stride(A, 2))
336+
ldvl = max(1, stride(VL, 2))
337+
params = CUSOLVER.CuSolverParameters()
338+
dh = CUSOLVER.dense_handle()
339+
340+
if $elty <: Real
341+
D2 = reinterpret($elty, D)
342+
# reuse memory, we will have to reorder afterwards to bring real and imaginary
343+
# components in the order as required for the Complex type
344+
VR = reinterpret($elty, V)
345+
else
346+
D2 = D
347+
VR = V
348+
end
349+
ldvr = max(1, stride(VR, 2))
350+
function bufferSize()
351+
out_cpu = Ref{Csize_t}(0)
352+
out_gpu = Ref{Csize_t}(0)
353+
CUSOLVER.cusolverDnXgeev_bufferSize(dh, params, jobvl, jobvr, n, $elty, A,
354+
lda, $elty, D2, $elty, VL, ldvl, $elty, VR, ldvr,
355+
$elty, out_gpu, out_cpu)
356+
out_gpu[], out_cpu[]
357+
end
358+
CUDA.with_workspaces(dh.workspace_gpu, dh.workspace_cpu, bufferSize()...) do buffer_gpu, buffer_cpu
359+
CUSOLVER.cusolverDnXgeev(dh, params, jobvl, jobvr, n, $elty, A, lda, $elty,
360+
D2, $elty, VL, ldvl, $elty, VR, ldvr, $elty, buffer_gpu,
361+
sizeof(buffer_gpu), buffer_cpu, sizeof(buffer_cpu), dh.info)
362+
end
363+
flag = @allowscalar dh.info[1]
364+
CUSOLVER.chkargsok(BlasInt(flag))
365+
if eltype(A) <: Real
366+
work = CuVector{$elty}(undef, n)
367+
DR = view(D2, 1:n)
368+
DI = view(D2, (n + 1):(2n))
369+
_reorder_realeigendecomposition!(D, DR, DI, work, VR, jobvr)
370+
end
371+
return D, V
372+
end
373+
end
374+
end
375+
309376
# for (jname, bname, fname, elty, relty) in
310377
# ((:sygvd!, :cusolverDnSsygvd_bufferSize, :cusolverDnSsygvd, :Float32, :Float32),
311378
# (:sygvd!, :cusolverDnDsygvd_bufferSize, :cusolverDnDsygvd, :Float64, :Float64),
@@ -650,4 +717,57 @@ end
650717
# end
651718
# end
652719

720+
# device code is unreachable by coverage right now
721+
# COV_EXCL_START
722+
# TODO use a shmem array here
723+
function _reorder_kernel_real(real_ev_ixs, VR::CuDeviceArray{T}, n::Int) where {T}
724+
grid_idx = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x
725+
@inbounds if grid_idx <= length(real_ev_ixs)
726+
i = real_ev_ixs[grid_idx]
727+
for j in n:-1:1
728+
VR[2 * j, i] = zero(T)
729+
VR[2 * j - 1, i] = VR[j, i]
730+
end
731+
end
732+
return
733+
end
734+
735+
function _reorder_kernel_complex(complex_ev_ixs, VR::CuDeviceArray{T}, n::Int) where {T}
736+
grid_idx = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x
737+
@inbounds if grid_idx <= length(complex_ev_ixs)
738+
i = complex_ev_ixs[grid_idx]
739+
for j in n:-1:1
740+
VR[2 * j, i] = VR[j, i + 1]
741+
VR[2 * j - 1, i] = VR[j, i]
742+
VR[2 * j, i + 1] = -VR[j, i + 1]
743+
VR[2 * j - 1, i + 1] = VR[j, i]
744+
end
745+
end
746+
return
747+
end
748+
# COV_EXCL_STOP
749+
750+
function _reorder_realeigendecomposition!(W, WR, WI, work, VR, jobvr)
751+
# first reorder eigenvalues and recycle work as temporary buffer to efficiently implement the permutation
752+
copy!(work, WI)
753+
n = size(W, 1)
754+
@. W[1:n] = WR[1:n] + im * work[1:n]
755+
T = eltype(WR)
756+
if jobvr == 'V' # also reorganise vectors
757+
real_ev_ixs = findall(isreal, W)
758+
_cmplx_ev_ixs = findall(!isreal, W) # these come in pairs, choose only the first of each pair
759+
complex_ev_ixs = view(_cmplx_ev_ixs, 1:2:length(_cmplx_ev_ixs))
760+
if !isempty(real_ev_ixs)
761+
real_threads = 128
762+
real_blocks = max(1, div(length(real_ev_ixs), real_threads))
763+
@cuda threads=real_threads blocks=real_blocks _reorder_kernel_real(real_ev_ixs, VR, n)
764+
end
765+
if !isempty(complex_ev_ixs)
766+
complex_threads = 128
767+
complex_blocks = max(1, div(length(complex_ev_ixs), complex_threads))
768+
@cuda threads=complex_threads blocks=complex_blocks _reorder_kernel_complex(complex_ev_ixs, VR, n)
769+
end
770+
end
771+
end
772+
653773
end

src/MatrixAlgebraKit.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ export LAPACK_HouseholderQR, LAPACK_HouseholderLQ,
3333
LAPACK_QRIteration, LAPACK_Bisection, LAPACK_MultipleRelativelyRobustRepresentations,
3434
LAPACK_DivideAndConquer, LAPACK_Jacobi,
3535
LQViaTransposedQR,
36+
CUSOLVER_Simple,
3637
CUSOLVER_HouseholderQR, CUSOLVER_QRIteration, CUSOLVER_SVDPolar, CUSOLVER_Jacobi, CUSOLVER_Randomized,
3738
ROCSOLVER_HouseholderQR, ROCSOLVER_QRIteration, ROCSOLVER_Jacobi
3839
export truncrank, trunctol, truncabove, TruncationKeepSorted, TruncationKeepFiltered

src/implementations/eig.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,3 +82,29 @@ function eig_trunc!(A::AbstractMatrix, DV, alg::TruncatedAlgorithm)
8282
D, V = eig_full!(A, DV, alg.alg)
8383
return truncate!(eig_trunc!, (D, V), alg.trunc)
8484
end
85+
86+
_gpu_geev!(A::AbstractMatrix, D, V) = throw(MethodError(_gpu_geev!, (A, D, V)))
87+
88+
function eig_full!(A::AbstractMatrix, DV, alg::GPU_EigAlgorithm)
89+
check_input(eig_full!, A, DV, alg)
90+
D, V = DV
91+
if alg isa GPU_Simple
92+
isempty(alg.kwargs) ||
93+
throw(ArgumentError("GPU_Simple (geev) does not accept any keyword arguments"))
94+
_gpu_geev!(A, D.diag, V)
95+
end
96+
# TODO: make this controllable using a `gaugefix` keyword argument
97+
V = gaugefix!(V)
98+
return D, V
99+
end
100+
101+
function eig_vals!(A::AbstractMatrix, D, alg::GPU_EigAlgorithm)
102+
check_input(eig_vals!, A, D, alg)
103+
V = similar(A, complex(eltype(A)), (size(A, 1), 0))
104+
if alg isa GPU_Simple
105+
isempty(alg.kwargs) ||
106+
throw(ArgumentError("LAPACK_Simple (geev) does not accept any keyword arguments"))
107+
_gpu_geev!(A, D, V)
108+
end
109+
return D
110+
end

src/interface/decompositions.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,15 @@ a general matrix using the randomized SVD algorithm.
164164
"""
165165
@algdef CUSOLVER_Randomized
166166

167+
"""
168+
CUSOLVER_Simple()
169+
170+
Algorithm type to denote the simple CUSOLVER driver for computing the non-Hermitian
171+
eigenvalue decomposition of a matrix.
172+
"""
173+
@algdef CUSOLVER_Simple
174+
175+
const CUSOLVER_EigAlgorithm = Union{CUSOLVER_Simple}
167176
# =========================
168177
# ROCSOLVER ALGORITHMS
169178
# =========================
@@ -192,3 +201,6 @@ Algorithm type to denote the ROCSOLVER driver for computing the singular value d
192201
a general matrix using the Jacobi algorithm.
193202
"""
194203
@algdef ROCSOLVER_Jacobi
204+
205+
const GPU_Simple = Union{CUSOLVER_Simple}
206+
const GPU_EigAlgorithm = Union{GPU_Simple}

test/cuda/eig.jl

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
using MatrixAlgebraKit
2+
using LinearAlgebra: Diagonal
3+
using Test
4+
using TestExtras
5+
using StableRNGs
6+
using CUDA
7+
8+
include(joinpath("..", "utilities.jl"))
9+
10+
@testset "eig_full! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64)
11+
rng = StableRNG(123)
12+
m = 54
13+
for alg in (CUSOLVER_Simple(), :CUSOLVER_Simple, CUSOLVER_Simple)
14+
A = CuArray(randn(rng, T, m, m))
15+
Tc = complex(T)
16+
17+
D, V = @constinferred eig_full(A; alg=($alg))
18+
@test eltype(D) == eltype(V) == Tc
19+
@test A * V V * D
20+
21+
alg′ = @constinferred MatrixAlgebraKit.select_algorithm(eig_full!, A, $alg)
22+
23+
Ac = similar(A)
24+
D2, V2 = @constinferred eig_full!(copy!(Ac, A), (D, V), alg′)
25+
@test D2 === D
26+
@test V2 === V
27+
@test A * V V * D
28+
29+
Dc = @constinferred eig_vals(A, alg′)
30+
@test eltype(Dc) == Tc
31+
@test parent(D) Dc
32+
end
33+
end
34+
35+
#=
36+
@testset "eig_trunc! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64)
37+
rng = StableRNG(123)
38+
m = 54
39+
for alg in (CUSOLVER_Simple(),)
40+
A = CuArray(randn(rng, T, m, m))
41+
A *= A' # TODO: deal with eigenvalue ordering etc
42+
# eigenvalues are sorted by ascending real component...
43+
D₀ = sort!(eig_vals(A); by=abs, rev=true)
44+
rmin = findfirst(i -> abs(D₀[end - i]) != abs(D₀[end - i - 1]), 1:(m - 2))
45+
r = length(D₀) - rmin
46+
47+
D1, V1 = @constinferred eig_trunc(A; alg, trunc=truncrank(r))
48+
@test length(D1.diag) == r
49+
@test A * V1 ≈ V1 * D1
50+
51+
s = 1 + sqrt(eps(real(T)))
52+
trunc = trunctol(s * abs(D₀[r + 1]))
53+
D2, V2 = @constinferred eig_trunc(A; alg, trunc)
54+
@test length(diagview(D2)) == r
55+
@test A * V2 ≈ V2 * D2
56+
57+
# trunctol keeps order, truncrank might not
58+
# test for same subspace
59+
@test V1 * ((V1' * V1) \ (V1' * V2)) ≈ V2
60+
@test V2 * ((V2' * V2) \ (V2' * V1)) ≈ V1
61+
end
62+
end
63+
=#

test/runtests.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,9 @@ if CUDA.functional()
6363
@safetestset "CUDA SVD" begin
6464
include("cuda/svd.jl")
6565
end
66+
@safetestset "CUDA General Eigenvalue Decomposition" begin
67+
include("cuda/eig.jl")
68+
end
6669
end
6770

6871
using AMDGPU

0 commit comments

Comments
 (0)