Skip to content

Commit c1753b8

Browse files
committed
Attempt at wrapping AMDGPU eigh
1 parent 49cdee0 commit c1753b8

5 files changed

Lines changed: 227 additions & 38 deletions

File tree

ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@ using MatrixAlgebraKit: @algdef, Algorithm, check_input
55
using MatrixAlgebraKit: one!, zero!, uppertriangular!, lowertriangular!
66
using MatrixAlgebraKit: diagview, sign_safe
77
using MatrixAlgebraKit: LQViaTransposedQR
8-
using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm
9-
import MatrixAlgebraKit: _gpu_geqrf!, _gpu_ungqr!, _gpu_unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_gesvdj!
8+
using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eigh_algorithm
9+
import MatrixAlgebraKit: _gpu_geqrf!, _gpu_ungqr!, _gpu_unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_gesvdj!
10+
import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd!, _gpu_heev!, _gpu_heevx!
1011
using AMDGPU
1112
using LinearAlgebra
1213
using LinearAlgebra: BlasFloat
@@ -23,6 +24,9 @@ end
2324
# Default SVD algorithm for strided ROCm matrix types: a `ROCSOLVER_QRIteration`
# constructed from whatever keyword arguments the caller supplies.
MatrixAlgebraKit.default_svd_algorithm(::Type{<:StridedROCMatrix}; kwargs...) =
    ROCSOLVER_QRIteration(; kwargs...)
27+
# Default Hermitian eigenvalue algorithm for strided ROCm matrix types: a
# `ROCSOLVER_DivideAndConquer` constructed from the caller's keyword arguments.
MatrixAlgebraKit.default_eigh_algorithm(::Type{<:StridedROCMatrix}; kwargs...) =
    ROCSOLVER_DivideAndConquer(; kwargs...)
2630

2731
# Forward `_gpu_geqrf!` on a strided ROCm matrix to `YArocSOLVER.geqrf!`.
function _gpu_geqrf!(A::StridedROCMatrix)
    return YArocSOLVER.geqrf!(A)
end

# Forward `_gpu_ungqr!` on a strided ROCm matrix and vector `τ` to `YArocSOLVER.ungqr!`.
function _gpu_ungqr!(A::StridedROCMatrix, τ::StridedROCVector)
    return YArocSOLVER.ungqr!(A, τ)
end
@@ -32,4 +36,8 @@ _gpu_gesvd!(A::StridedROCMatrix, S::StridedROCVector, U::StridedROCMatrix, Vᴴ:
3236
#_gpu_Xgesvdp!(A::StridedROCMatrix, S::StridedROCVector, U::StridedROCMatrix, Vᴴ::StridedROCMatrix; kwargs...) = YArocSOLVER.Xgesvdp!(A, S, U, Vᴴ; kwargs...)
3337
# Forward `_gpu_gesvdj!` on strided ROCm arrays to `YArocSOLVER.gesvdj!`, passing every
# keyword argument through unchanged.
function _gpu_gesvdj!(
        A::StridedROCMatrix, S::StridedROCVector, U::StridedROCMatrix, Vᴴ::StridedROCMatrix;
        kwargs...
    )
    return YArocSOLVER.gesvdj!(A, S, U, Vᴴ; kwargs...)
end
3438

39+
# Hermitian eigensolver hooks for strided ROCm arrays. Each method hands its arguments
# (input matrix `A`, vector `Dd`, matrix `V`) and all keyword arguments to the
# identically named routine in `YArocSOLVER`.
function _gpu_heevj!(A::StridedROCMatrix, Dd::StridedROCVector, V::StridedROCMatrix; kwargs...)
    return YArocSOLVER.heevj!(A, Dd, V; kwargs...)
end

function _gpu_heevd!(A::StridedROCMatrix, Dd::StridedROCVector, V::StridedROCMatrix; kwargs...)
    return YArocSOLVER.heevd!(A, Dd, V; kwargs...)
end

function _gpu_heev!(A::StridedROCMatrix, Dd::StridedROCVector, V::StridedROCMatrix; kwargs...)
    return YArocSOLVER.heev!(A, Dd, V; kwargs...)
end

function _gpu_heevx!(A::StridedROCMatrix, Dd::StridedROCVector, V::StridedROCMatrix; kwargs...)
    return YArocSOLVER.heevx!(A, Dd, V; kwargs...)
end
3543
end

ext/MatrixAlgebraKitAMDGPUExt/yarocsolver.jl

Lines changed: 132 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
module YArocSOLVER
22

33
using LinearAlgebra
4-
using LinearAlgebra: BlasInt, BlasFloat, checksquare, chkstride1, require_one_based_indexing
4+
using LinearAlgebra: BlasInt, BlasReal, BlasFloat, checksquare, chkstride1, require_one_based_indexing
55
using LinearAlgebra.LAPACK: chkargsok, chklapackerror, chktrans, chkside, chkdiag, chkuplo
66

77
using AMDGPU
@@ -475,42 +475,140 @@ end
475475
# return X, info
476476
# end
477477

478-
# for (jname, bname, fname, elty, relty) in
479-
# ((:syevd!, :rocsolverDnSsyevd_bufferSize, :rocsolverDnSsyevd, :Float32, :Float32),
480-
# (:syevd!, :rocsolverDnDsyevd_bufferSize, :rocsolverDnDsyevd, :Float64, :Float64),
481-
# (:heevd!, :rocsolverDnCheevd_bufferSize, :rocsolverDnCheevd, :ComplexF32, :Float32),
482-
# (:heevd!, :rocsolverDnZheevd_bufferSize, :rocsolverDnZheevd, :ComplexF64, :Float64))
483-
# @eval begin
484-
# function $jname(jobz::Char,
485-
# uplo::Char,
486-
# A::StridedROCMatrix{$elty})
487-
# chkuplo(uplo)
488-
# n = checksquare(A)
489-
# lda = max(1, stride(A, 2))
490-
# W = CuArray{$relty}(undef, n)
491-
# dh = rocBLAS.handle()
478+
# Select the job character passed to rocSOLVER from the eigenvector output `V`:
# an empty `V` requests eigenvalues only ('N'); otherwise `V` must be n×n and 'O' is used.
function _eigvec_job(V, n)
    length(V) == 0 && return 'N'
    size(V) == (n, n) || throw(DimensionMismatch("size mismatch between A and V"))
    return 'O'
end

# Hermitian / real-symmetric eigensolvers backed by rocSOLVER.
#
# One method of `heevd!`, `heev!`, `heevx!` and `heevj!` is generated per element type.
# Each tuple lists the four rocSOLVER routines for that element type, followed by the
# element type of `A`/`V` and the real type used for the eigenvalue vector `W`.
#
# Conventions shared by all four functions:
#   * `A` must be square, `W` must have length `size(A, 1)`.
#   * `uplo` is validated with `chkuplo` and forwarded to the routine.
#   * An empty `V` means "eigenvalues only"; otherwise `V` must match the size of `A`.
#   * The device-side `info` value is read back and checked with `chkargsok`.
#
# NOTE(review): the job / range / uplo arguments are forwarded as `Char`s; confirm that
# the rocSOLVER bindings accept (or convert) characters for their enum arguments.
for (heevd, heev, heevx, heevj, elty, relty) in (
        (
            :(rocSOLVER.rocsolver_ssyevd), :(rocSOLVER.rocsolver_ssyev),
            :(rocSOLVER.rocsolver_ssyevx), :(rocSOLVER.rocsolver_ssyevj),
            :Float32, :Float32,
        ),
        (
            :(rocSOLVER.rocsolver_dsyevd), :(rocSOLVER.rocsolver_dsyev),
            :(rocSOLVER.rocsolver_dsyevx), :(rocSOLVER.rocsolver_dsyevj),
            :Float64, :Float64,
        ),
        (
            :(rocSOLVER.rocsolver_cheevd), :(rocSOLVER.rocsolver_cheev),
            :(rocSOLVER.rocsolver_cheevx), :(rocSOLVER.rocsolver_cheevj),
            :ComplexF32, :Float32,
        ),
        (
            :(rocSOLVER.rocsolver_zheevd), :(rocSOLVER.rocsolver_zheev),
            :(rocSOLVER.rocsolver_zheevx), :(rocSOLVER.rocsolver_zheevj),
            :ComplexF64, :Float64,
        ),
    )
    @eval begin
        # Eigen-decomposition through the `*syevd` / `*heevd` routine.
        # Returns `(W, V)`; with job 'O' the contents left in `A` are copied into `V`.
        function heevd!(
                A::StridedROCMatrix{$elty},
                W::StridedROCVector{$relty},
                V::StridedROCMatrix{$elty};
                uplo::Char = 'U',
            )
            chkuplo(uplo)
            n = checksquare(A)
            lda = max(1, stride(A, 2))
            length(W) == n || throw(DimensionMismatch("size mismatch between A and W"))
            jobz = _eigvec_job(V, n)
            dh = rocBLAS.handle()
            work = ROCVector{$relty}(undef, n)
            dev_info = ROCVector{Cint}(undef, 1)
            $heevd(dh, jobz, uplo, n, A, lda, W, work, dev_info)

            info = @allowscalar dev_info[1]
            chkargsok(BlasInt(info))

            # Skip the copy when `V` is the very same array as `A`.
            if jobz == 'O' && V !== A
                copy!(V, A)
            end
            return W, V
        end

        # Eigen-decomposition through the `*syev` / `*heev` routine.
        # Returns `(W, V)`; with job 'O' the contents left in `A` are copied into `V`.
        function heev!(
                A::StridedROCMatrix{$elty},
                W::StridedROCVector{$relty},
                V::StridedROCMatrix{$elty};
                uplo::Char = 'U',
            )
            chkuplo(uplo)
            n = checksquare(A)
            lda = max(1, stride(A, 2))
            length(W) == n || throw(DimensionMismatch("size mismatch between A and W"))
            jobz = _eigvec_job(V, n)
            dh = rocBLAS.handle()
            work = ROCVector{$relty}(undef, n)
            dev_info = ROCVector{Cint}(undef, 1)
            $heev(dh, jobz, uplo, n, A, lda, W, work, dev_info)

            info = @allowscalar dev_info[1]
            chkargsok(BlasInt(info))

            # Skip the copy when `V` is the very same array as `A`.
            if jobz == 'O' && V !== A
                copy!(V, A)
            end
            return W, V
        end

        # Eigen-decomposition of a selected part of the spectrum through the
        # `*syevx` / `*heevx` routine. Returns `(W, V, m)` where `m` is the value the
        # routine stored in its count argument.
        #
        # Keywords (besides `uplo`):
        #   * `irange`   -> index selection, range character 'I', bounds `first`/`last`.
        #   * `vl`, `vu` -> value selection, range character 'V'; a missing bound
        #                   defaults to -Inf / +Inf.
        #   * neither    -> range character 'A'.
        # `irange` takes precedence when both kinds of selection are given.
        function heevx!(
                A::StridedROCMatrix{$elty},
                W::StridedROCVector{$relty},
                V::StridedROCMatrix{$elty};
                uplo::Char = 'U',
                kwargs...
            )
            chkuplo(uplo)
            n = checksquare(A)
            lda = max(1, stride(A, 2))
            length(W) == n || throw(DimensionMismatch("size mismatch between A and W"))
            if haskey(kwargs, :irange)
                il = first(kwargs[:irange])
                iu = last(kwargs[:irange])
                vl = vu = zero($relty)
                erange = 'I'
            elseif haskey(kwargs, :vl) || haskey(kwargs, :vu)
                vl = convert($relty, get(kwargs, :vl, -Inf))
                vu = convert($relty, get(kwargs, :vu, +Inf))
                il = iu = 0
                erange = 'V'
            else
                il = iu = 0
                vl = vu = zero($relty)
                erange = 'A'
            end
            jobz = _eigvec_job(V, n)
            dh = rocBLAS.handle()
            # presumably a negative tolerance selects the library default -- TODO confirm
            abstol = -one($relty)
            # NOTE(review): `m` is a host-side `Ref` and `ifail` uses `BlasInt` elements;
            # confirm both match what the rocSOLVER binding expects for these arguments.
            m = Ref{BlasInt}()
            ldv = max(1, stride(V, 2))
            ifail = ROCVector{BlasInt}(undef, n)
            dev_info = ROCVector{Cint}(undef, 1)
            $heevx(
                dh, jobz, erange, uplo, n, A, lda, vl, vu, il, iu, abstol,
                m, W, V, ldv, ifail, dev_info,
            )

            info = @allowscalar dev_info[1]
            chkargsok(BlasInt(info))
            return W, V, m[]
        end

        # Eigen-decomposition through the `*syevj` / `*heevj` routine.
        # `tol` and `max_sweeps` are forwarded to the routine. Returns `(W, V)`; with
        # job 'O' the contents left in `A` are copied into `V`.
        function heevj!(
                A::StridedROCMatrix{$elty},
                W::StridedROCVector{$relty},
                V::StridedROCMatrix{$elty};
                uplo::Char = 'U',
                tol::$relty = eps($relty),
                max_sweeps::Int = 100,
            )
            chkuplo(uplo)
            n = checksquare(A)
            lda = max(1, stride(A, 2))
            length(W) == n || throw(DimensionMismatch("size mismatch between A and W"))
            jobz = _eigvec_job(V, n)
            dh = rocBLAS.handle()
            dev_info = ROCVector{Cint}(undef, 1)
            residual = ROCVector{$relty}(undef, 1)
            n_sweeps = ROCVector{Cint}(undef, 1)
            # NOTE(review): the rocSOLVER C API for syevj/heevj documents a leading
            # `esort` argument before the job argument; confirm the binding's signature.
            $heevj(
                dh, jobz, uplo, n, A, lda, tol, residual, max_sweeps, n_sweeps, W, dev_info,
            )

            info = @allowscalar dev_info[1]
            chkargsok(BlasInt(info))

            # Skip the copy when `V` is the very same array as `A`.
            if jobz == 'O' && V !== A
                copy!(V, A)
            end
            return W, V
        end
    end
end
515613

516614
end

src/yalapack.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -921,8 +921,8 @@ for (heev, heevx, heevr, heevd, hegvd, elty, relty) in
921921
end
922922
chkuplofinite(A, uplo)
923923
if haskey(kwargs, :irange)
924-
il = first(irange)
925-
iu = last(irange)
924+
il = first(kwargs[:irange])
925+
iu = last(kwargs[:irange])
926926
vl = vu = zero($relty)
927927
range = 'I'
928928
elseif haskey(kwargs, :vl) || haskey(kwargs, :vu)

test/amd/eigh.jl

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
using MatrixAlgebraKit
2+
using Test
3+
using TestExtras
4+
using StableRNGs
5+
using LinearAlgebra: LinearAlgebra, Diagonal, I
6+
using MatrixAlgebraKit: TruncatedAlgorithm, diagview
7+
using AMDGPU
8+
9+
# Exercises `eigh_full`, `eigh_full!` and `eigh_vals` on `ROCArray` inputs for each
# element type and each ROCSOLVER eigh algorithm listed below.
@testset "eigh_full! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64)
    rng = StableRNG(123)
    m = 54
    for alg in (
            ROCSOLVER_DivideAndConquer(),
            ROCSOLVER_Jacobi(),
            ROCSOLVER_Bisection(),
            ROCSOLVER_QRIteration(),
        )
        A = ROCArray(randn(rng, T, m, m))
        A = (A + A') / 2  # symmetrize so the input is Hermitian

        D, V = @constinferred eigh_full(A; alg)
        @test A * V ≈ V * D
        @test isunitary(V)
        @test all(isreal, D)

        # The in-place variant must return the very output arrays it was handed.
        D2, V2 = eigh_full!(copy(A), (D, V), alg)
        @test D2 === D
        @test V2 === V

        D3 = @constinferred eigh_vals(A, alg)
        @test parent(D) ≈ D3
    end
end
33+
34+
#=@testset "eigh_trunc! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64)
35+
rng = StableRNG(123)
36+
m = 54
37+
for alg in (CUSOLVER_QRIteration(),
38+
CUSOLVER_DivideAndConquer(),
39+
)
40+
A = ROCArray(randn(rng, T, m, m))
41+
A = A * A'
42+
A = (A + A') / 2
43+
Ac = similar(A)
44+
D₀ = reverse(eigh_vals(A))
45+
r = m - 2
46+
s = 1 + sqrt(eps(real(T)))
47+
48+
D1, V1 = @constinferred eigh_trunc(A; alg, trunc=truncrank(r))
49+
@test length(diagview(D1)) == r
50+
@test isisometry(V1)
51+
@test A * V1 ≈ V1 * D1
52+
@test LinearAlgebra.opnorm(A - V1 * D1 * V1') ≈ D₀[r + 1]
53+
54+
trunc = trunctol(s * D₀[r + 1])
55+
D2, V2 = @constinferred eigh_trunc(A; alg, trunc)
56+
@test length(diagview(D2)) == r
57+
@test isisometry(V2)
58+
@test A * V2 ≈ V2 * D2
59+
60+
# test for same subspace
61+
@test V1 * (V1' * V2) ≈ V2
62+
@test V2 * (V2' * V1) ≈ V1
63+
end
64+
end
65+
66+
@testset "eigh_trunc! specify truncation algorithm T = $T" for T in
67+
(Float32, Float64,
68+
ComplexF32,
69+
ComplexF64)
70+
rng = StableRNG(123)
71+
m = 4
72+
V = qr_compact(ROCArray(randn(rng, T, m, m)))[1]
73+
D = Diagonal([0.9, 0.3, 0.1, 0.01])
74+
A = V * D * V'
75+
A = (A + A') / 2
76+
alg = TruncatedAlgorithm(CUSOLVER_QRIteration(), truncrank(2))
77+
D2, V2 = @constinferred eigh_trunc(A; alg)
78+
@test diagview(D2) ≈ diagview(D)[1:2] rtol = sqrt(eps(real(T)))
79+
@test_throws ArgumentError eigh_trunc(A; alg, trunc=(; maxrank=2))
80+
end=#

test/runtests.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,4 +82,7 @@ if AMDGPU.functional()
8282
@safetestset "AMDGPU SVD" begin
8383
include("amd/svd.jl")
8484
end
85+
@safetestset "AMDGPU Hermitian Eigenvalue Decomposition" begin
86+
include("amd/eigh.jl")
87+
end
8588
end

0 commit comments

Comments
 (0)