Skip to content

Commit e35c55e

Browse files
committed
Add Diagonal eigh implementation and tests
1 parent 32f6978 commit e35c55e

3 files changed

Lines changed: 79 additions & 6 deletions

File tree

src/implementations/eigh.jl

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ function copy_input(::typeof(eigh_vals), A::AbstractMatrix)
88
end
99
copy_input(::typeof(eigh_trunc), A) = copy_input(eigh_full, A)
1010

11+
copy_input(::typeof(eigh_full), A::Diagonal) = copy(A)
12+
1113
function check_input(::typeof(eigh_full!), A::AbstractMatrix, DV, ::AbstractAlgorithm)
1214
m, n = size(A)
1315
m == n || throw(DimensionMismatch("square input matrix expected"))
@@ -21,6 +23,27 @@ function check_input(::typeof(eigh_full!), A::AbstractMatrix, DV, ::AbstractAlgo
2123
end
2224
function check_input(::typeof(eigh_vals!), A::AbstractMatrix, D, ::AbstractAlgorithm)
2325
m, n = size(A)
26+
m == n || throw(DimensionMismatch("square input matrix expected"))
27+
@assert D isa AbstractVector
28+
@check_size(D, (n,))
29+
@check_scalar(D, A, real)
30+
return nothing
31+
end
32+
33+
function check_input(::typeof(eigh_full!), A::AbstractMatrix, DV, ::DiagonalAlgorithm)
34+
m, n = size(A)
35+
@assert m == n && isdiag(A)
36+
D, V = DV
37+
@assert D isa Diagonal && V isa Diagonal
38+
@check_size(D, (m, m))
39+
@check_scalar(D, A, real)
40+
@check_size(V, (m, m))
41+
@check_scalar(V, A)
42+
return nothing
43+
end
44+
function check_input(::typeof(eigh_vals!), A::AbstractMatrix, D, ::DiagonalAlgorithm)
45+
m, n = size(A)
46+
@assert m == n && isdiag(A)
2447
@assert D isa AbstractVector
2548
@check_size(D, (n,))
2649
@check_scalar(D, A, real)
@@ -45,6 +68,13 @@ function initialize_output(::typeof(eigh_trunc!), A::AbstractMatrix,
4568
return initialize_output(eigh_full!, A, alg.alg)
4669
end
4770

71+
function initialize_output(::typeof(eigh_full!), A::Diagonal, ::DiagonalAlgorithm)
72+
return eltype(A) <: Real ? A : similar(A, real(eltype(A))), similar(A)
73+
end
74+
function initialize_output(::typeof(eigh_vals!), A::Diagonal, ::DiagonalAlgorithm)
75+
return eltype(A) <: Real ? diagview(A) : similar(A, real(eltype(A)), size(A, 1))
76+
end
77+
4878
# Implementation
4979
# --------------
5080
function eigh_full!(A::AbstractMatrix, DV, alg::LAPACK_EighAlgorithm)
@@ -85,6 +115,25 @@ function eigh_trunc!(A::AbstractMatrix, DV, alg::TruncatedAlgorithm)
85115
return truncate!(eigh_trunc!, (D, V), alg.trunc)
86116
end
87117

118+
# Diagonal logic
119+
# --------------
120+
function eigh_full!(A::Diagonal, DV, alg::DiagonalAlgorithm)
121+
check_input(eigh_full!, A, DV, alg)
122+
D, V = DV
123+
D === A || (diagview(D) .= real.(diagview(A)))
124+
one!(V)
125+
return D, V
126+
end
127+
128+
function eigh_vals!(A::Diagonal, D, alg::DiagonalAlgorithm)
129+
check_input(eigh_vals!, A, D, alg)
130+
Ad = diagview(A)
131+
D === Ad || (D .= real.(Ad))
132+
return D
133+
end
134+
135+
# GPU logic
136+
# ---------
88137
_gpu_heevj!(A::AbstractMatrix, Dd::AbstractVector, V::AbstractMatrix; kwargs...) = throw(MethodError(_gpu_heevj!, (A, Dd, V)))
89138
_gpu_heevd!(A::AbstractMatrix, Dd::AbstractVector, V::AbstractMatrix; kwargs...) = throw(MethodError(_gpu_heevd!, (A, Dd, V)))
90139
_gpu_heev!(A::AbstractMatrix, Dd::AbstractVector, V::AbstractMatrix; kwargs...) = throw(MethodError(_gpu_heev!, (A, Dd, V)))

src/interface/eigh.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,9 @@ end
9393
function default_eigh_algorithm(::Type{T}; kwargs...) where {T<:YALAPACK.BlasMat}
9494
return LAPACK_MultipleRelativelyRobustRepresentations(; kwargs...)
9595
end
96+
function default_eigh_algorithm(::Type{T}; kwargs...) where {T<:Diagonal}
97+
return DiagonalAlgorithm(; kwargs...)
98+
end
9699

97100
for f in (:eigh_full!, :eigh_vals!)
98101
@eval function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A}

test/eigh.jl

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@ using StableRNGs
55
using LinearAlgebra: LinearAlgebra, Diagonal, I
66
using MatrixAlgebraKit: TruncatedAlgorithm, diagview
77

8-
@testset "eigh_full! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64)
8+
const BLASFloats = (Float32, Float64, ComplexF32, ComplexF64)
9+
10+
@testset "eigh_full! for T = $T" for T in BLASFloats
911
rng = StableRNG(123)
1012
m = 54
1113
for alg in (LAPACK_MultipleRelativelyRobustRepresentations(),
@@ -29,7 +31,7 @@ using MatrixAlgebraKit: TruncatedAlgorithm, diagview
2931
end
3032
end
3133

32-
@testset "eigh_trunc! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64)
34+
@testset "eigh_trunc! for T = $T" for T in BLASFloats
3335
rng = StableRNG(123)
3436
m = 54
3537
for alg in (LAPACK_QRIteration(),
@@ -62,10 +64,7 @@ end
6264
end
6365
end
6466

65-
@testset "eigh_trunc! specify truncation algorithm T = $T" for T in
66-
(Float32, Float64,
67-
ComplexF32,
68-
ComplexF64)
67+
@testset "eigh_trunc! specify truncation algorithm T = $T" for T in BLASFloats
6968
rng = StableRNG(123)
7069
m = 4
7170
V = qr_compact(randn(rng, T, m, m))[1]
@@ -77,3 +76,25 @@ end
7776
@test diagview(D2) diagview(D)[1:2] rtol = sqrt(eps(real(T)))
7877
@test_throws ArgumentError eigh_trunc(A; alg, trunc=(; maxrank=2))
7978
end
79+
80+
@testset "eigh for Diagonal{$T}" for T in BLASFloats
81+
rng = StableRNG(123)
82+
m = 54
83+
Ad = randn(rng, T, m)
84+
Ad .+= conj.(Ad)
85+
A = Diagonal(Ad)
86+
87+
D, V = @constinferred eigh_full(A)
88+
@test D isa Diagonal{real(T)} && size(D) == size(A)
89+
@test V isa Diagonal{T} && size(V) == size(A)
90+
@test A * V V * D
91+
92+
D2 = @constinferred eigh_vals(A)
93+
@test D2 isa AbstractVector{real(T)} && length(D2) == m
94+
@test diagview(D) D2
95+
96+
A2 = Diagonal(T[0.9, 0.3, 0.1, 0.01])
97+
alg = TruncatedAlgorithm(DiagonalAlgorithm(), truncrank(2))
98+
D2, V2 = @constinferred eigh_trunc(A2; alg)
99+
@test diagview(D2) diagview(A2)[1:2]
100+
end

0 commit comments

Comments
 (0)