Skip to content

Commit 1bf4dbf

Browse files
author
Katharine Hyatt
committed
Add generalized eigenvalue decomposition, fix some bugs
1 parent 4d091cb commit 1bf4dbf

8 files changed

Lines changed: 312 additions & 18 deletions

File tree

src/MatrixAlgebraKit.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ export eigh_full, eigh_vals, eigh_trunc
1919
export eigh_full!, eigh_vals!, eigh_trunc!
2020
export eig_full, eig_vals, eig_trunc
2121
export eig_full!, eig_vals!, eig_trunc!
22+
export gen_eig_full, gen_eig_vals
23+
export gen_eig_full!, gen_eig_vals!
2224
export schur_full, schur_vals
2325
export schur_full!, schur_vals!
2426
export left_polar, right_polar
@@ -54,6 +56,7 @@ include("interface/lq.jl")
5456
include("interface/svd.jl")
5557
include("interface/eig.jl")
5658
include("interface/eigh.jl")
59+
include("interface/gen_eig.jl")
5760
include("interface/schur.jl")
5861
include("interface/polar.jl")
5962
include("interface/orthnull.jl")
@@ -64,6 +67,7 @@ include("implementations/lq.jl")
6467
include("implementations/svd.jl")
6568
include("implementations/eig.jl")
6669
include("implementations/eigh.jl")
70+
include("implementations/gen_eig.jl")
6771
include("implementations/schur.jl")
6872
include("implementations/polar.jl")
6973
include("implementations/orthnull.jl")

src/algorithms.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,10 +106,14 @@ explicitly.
106106
New types should prefer to register their default algorithms in the type domain.
107107
""" default_algorithm
108108
default_algorithm(f::F, A; kwargs...) where {F} = default_algorithm(f, typeof(A); kwargs...)
109+
default_algorithm(f::F, A, B; kwargs...) where {F} = default_algorithm(f, typeof(A), typeof(B); kwargs...)
109110
# avoid infinite recursion:
110111
function default_algorithm(f::F, ::Type{T}; kwargs...) where {F,T}
111112
throw(MethodError(default_algorithm, (f, T)))
112113
end
114+
function default_algorithm(f::F, ::Type{TA}, ::Type{TB}; kwargs...) where {F,TA,TB}
115+
throw(MethodError(default_algorithm, (f, TA, TB)))
116+
end
113117

114118
@doc """
115119
copy_input(f, A)
@@ -177,6 +181,8 @@ macro functiondef(f)
177181
# out of place to inplace
178182
$f(A; kwargs...) = $f!(copy_input($f, A); kwargs...)
179183
$f(A, alg::AbstractAlgorithm) = $f!(copy_input($f, A), alg)
184+
$f(A, B; kwargs...) = $f!(copy_input($f, A, B)...; kwargs...)
185+
$f(A, B, alg::AbstractAlgorithm) = $f!(copy_input($f, A, B)..., alg)
180186

181187
# fill in arguments
182188
function $f!(A; alg=nothing, kwargs...)
@@ -185,6 +191,12 @@ macro functiondef(f)
185191
function $f!(A, out; alg=nothing, kwargs...)
186192
return $f!(A, out, select_algorithm($f!, A, alg; kwargs...))
187193
end
194+
function $f!(A, B, out; alg=nothing, kwargs...)
195+
return $f!(A, B, out, select_algorithm($f!, (A, B), alg; kwargs...))
196+
end
197+
function $f!(A, B, alg::AbstractAlgorithm)
198+
return $f!(A, B, initialize_output($f!, A, B, alg), alg)
199+
end
188200
function $f!(A, alg::AbstractAlgorithm)
189201
return $f!(A, initialize_output($f!, A, alg), alg)
190202
end
@@ -198,6 +210,9 @@ macro functiondef(f)
198210
@inline function default_algorithm(::typeof($f), A; kwargs...)
199211
return default_algorithm($f!, A; kwargs...)
200212
end
213+
@inline function default_algorithm(::typeof($f), A, B; kwargs...)
214+
return default_algorithm($f!, A, B; kwargs...)
215+
end
201216
# define default algorithm fallbacks for out-of-place functions
202217
# in terms of the corresponding in-place function for types,
203218
# in principle this is covered by the definition above but
@@ -211,6 +226,9 @@ macro functiondef(f)
211226
@inline function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A}
212227
return default_algorithm($f!, A; kwargs...)
213228
end
229+
@inline function default_algorithm(::typeof($f), ::Type{A}, ::Type{B}; kwargs...) where {A, B}
230+
return default_algorithm($f!, A, B; kwargs...)
231+
end
214232

215233
# copy documentation to both functions
216234
Core.@__doc__ $f, $f!

src/implementations/gen_eig.jl

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
# Inputs
2+
# ------
3+
function copy_input(::typeof(gen_eig_full), A::AbstractMatrix, B::AbstractMatrix)
4+
return copy!(similar(A, float(eltype(A))), A), copy!(similar(B, float(eltype(B))), B)
5+
end
6+
function copy_input(::typeof(gen_eig_vals), A::AbstractMatrix, B::AbstractMatrix)
7+
return copy_input(gen_eig_full, A, B)
8+
end
9+
10+
function check_input(::typeof(gen_eig_full!), A::AbstractMatrix, B::AbstractMatrix, WV)
11+
ma, na = size(A)
12+
mb, nb = size(B)
13+
ma == na || throw(DimensionMismatch("square input matrix A expected"))
14+
mb == nb || throw(DimensionMismatch("square input matrix B expected"))
15+
ma == mb || throw(DimensionMismatch("first dimension of input matrices expected to match"))
16+
na == nb || throw(DimensionMismatch("second dimension of input matrices expected to match"))
17+
W, V = WV
18+
@assert W isa Diagonal && V isa AbstractMatrix
19+
@check_size(W, (ma, ma))
20+
@check_scalar(W, A, complex)
21+
@check_scalar(W, B, complex)
22+
@check_size(V, (ma, ma))
23+
@check_scalar(V, A, complex)
24+
@check_scalar(V, B, complex)
25+
return nothing
26+
end
27+
function check_input(::typeof(gen_eig_vals!), A::AbstractMatrix, B::AbstractMatrix, W)
28+
ma, na = size(A)
29+
mb, nb = size(B)
30+
ma == na || throw(DimensionMismatch("square input matrix A expected"))
31+
mb == nb || throw(DimensionMismatch("square input matrix B expected"))
32+
ma == mb || throw(DimensionMismatch("first dimension of input matrices expected to match"))
33+
na == nb || throw(DimensionMismatch("second dimension of input matrices expected to match"))
34+
@assert W isa AbstractVector
35+
@check_size(W, (na,))
36+
@check_scalar(W, A, complex)
37+
@check_scalar(W, B, complex)
38+
return nothing
39+
end
40+
41+
# Outputs
42+
# -------
43+
function initialize_output(::typeof(gen_eig_full!), A::AbstractMatrix, B::AbstractMatrix, ::LAPACK_EigAlgorithm)
44+
n = size(A, 1) # square check will happen later
45+
Tc = complex(eltype(A))
46+
W = Diagonal(similar(A, Tc, n))
47+
V = similar(A, Tc, (n, n))
48+
return (W, V)
49+
end
50+
function initialize_output(::typeof(gen_eig_vals!), A::AbstractMatrix, B::AbstractMatrix, ::LAPACK_EigAlgorithm)
51+
n = size(A, 1) # square check will happen later
52+
Tc = complex(eltype(A))
53+
D = similar(A, Tc, n)
54+
return D
55+
end
56+
57+
# Implementation
58+
# --------------
59+
# actual implementation
60+
function gen_eig_full!(A::AbstractMatrix, B::AbstractMatrix, WV, alg::LAPACK_EigAlgorithm)
61+
check_input(gen_eig_full!, A, B, WV)
62+
W, V = WV
63+
if alg isa LAPACK_Simple
64+
isempty(alg.kwargs) ||
65+
throw(ArgumentError("LAPACK_Simple (ggev) does not accept any keyword arguments"))
66+
YALAPACK.ggev!(A, B, W.diag, V, similar(W.diag, eltype(A)))
67+
else # alg isa LAPACK_Expert
68+
throw(ArgumentError("LAPACK_Expert is not supported for ggev"))
69+
end
70+
# TODO: make this controllable using a `gaugefix` keyword argument
71+
for j in 1:size(V, 2)
72+
v = view(V, :, j)
73+
s = conj(sign(argmax(abs, v)))
74+
v .*= s
75+
end
76+
return W, V
77+
end
78+
79+
function gen_eig_vals!(A::AbstractMatrix, B::AbstractMatrix, W, alg::LAPACK_EigAlgorithm)
80+
check_input(gen_eig_vals!, A, B, W)
81+
V = similar(A, complex(eltype(A)), (size(A, 1), 0))
82+
if alg isa LAPACK_Simple
83+
isempty(alg.kwargs) ||
84+
throw(ArgumentError("LAPACK_Simple (ggev) does not accept any keyword arguments"))
85+
YALAPACK.ggev!(A, B, W, V, similar(W, eltype(A)))
86+
else # alg isa LAPACK_Expert
87+
throw(ArgumentError("LAPACK_Expert is not supported for ggev"))
88+
end
89+
return W
90+
end

src/interface/eig.jl

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,21 @@ and the diagonal matrix `D` contains the associated eigenvalues.
3434
The bang method `eig_full!` optionally accepts the output structure and
3535
possibly destroys the input matrix `A`. Always use the return value of the function
3636
as it may not always be possible to use the provided `DV` as output.
37+
38+
eig_full(A, B; kwargs...) -> W, V
39+
eig_full(A, B, alg::AbstractAlgorithm) -> W, V
40+
eig_full!(A, B, [WV]; kwargs...) -> W, V
41+
eig_full!(A, B, [WV], alg::AbstractAlgorithm) -> W, V
42+
43+
Compute the full generalized eigenvalue decomposition of the square matrices `A` and `B`,
44+
such that `A * V = B * V * W`, where the invertible matrix `V` contains the generalized eigenvectors
45+
and the diagonal matrix `W` contains the associated generalized eigenvalues.
46+
47+
!!! note
48+
The bang method `eig_full!` optionally accepts the output structure and
49+
possibly destroys the input matrices `A` and `B`.
50+
Always use the return value of the function as it may not always be
51+
possible to use the provided `WV` as output.
3752
3853
!!! note
3954
$(docs_eig_note)
@@ -72,11 +87,19 @@ See also [`eig_full(!)`](@ref eig_full) and [`eig_vals(!)`](@ref eig_vals).
7287
eig_vals!(A, [D], alg::AbstractAlgorithm) -> D
7388
7489
Compute the list of eigenvalues of `A`.
90+
91+
eig_vals(A, B; kwargs...) -> W
92+
eig_vals(A, B, alg::AbstractAlgorithm) -> W
93+
eig_vals!(A, B, [W]; kwargs...) -> W
94+
eig_vals!(A, B, [W], alg::AbstractAlgorithm) -> W
95+
96+
Compute the list of generalized eigenvalues of `A` and `B`.
7597
7698
!!! note
7799
The bang method `eig_vals!` optionally accepts the output structure and
78-
possibly destroys the input matrix `A`. Always use the return value of the function
79-
as it may not always be possible to use the provided `D` as output.
100+
possibly destroys the input matrices `A` and `B`. Always use the return
101+
value of the function as it may not always be possible to use the
102+
provided `W` as output.
80103
81104
!!! note
82105
$(docs_eig_note)
@@ -92,11 +115,17 @@ default_eig_algorithm(T::Type; kwargs...) = throw(MethodError(default_eig_algori
92115
function default_eig_algorithm(::Type{T}; kwargs...) where {T<:YALAPACK.BlasMat}
93116
return LAPACK_Expert(; kwargs...)
94117
end
118+
function default_eig_algorithm(::Type{TA}, ::Type{TB}; kwargs...) where {TA<:YALAPACK.BlasMat,TB<:YALAPACK.BlasMat}
119+
return LAPACK_Simple(; kwargs...)
120+
end
95121

96122
for f in (:eig_full!, :eig_vals!)
97123
@eval function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A}
98124
return default_eig_algorithm(A; kwargs...)
99125
end
126+
@eval function default_algorithm(::typeof($f), ::Type{A}, ::Type{B}; kwargs...) where {A, B}
127+
return default_eig_algorithm(A, B; kwargs...)
128+
end
100129
end
101130

102131
function select_algorithm(::typeof(eig_trunc!), A, alg; trunc=nothing, kwargs...)

src/interface/gen_eig.jl

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
# Gen Eig API
2+
# -------
3+
function gen_eig!(A::AbstractMatrix, B::AbstractMatrix, args...; kwargs...)
4+
return gen_eig_full!(A, B, args...; kwargs...)
5+
end
6+
function gen_eig(A::AbstractMatrix, B::AbstractMatrix, args...; kwargs...)
7+
return gen_eig_full(A, B, args...; kwargs...)
8+
end
9+
10+
# Gen Eig functions
11+
# -------------
12+
13+
# TODO: kwargs for sorting eigenvalues?
14+
15+
docs_gen_eig_note = """
16+
Note that [`gen_eig_full`](@ref) and its variants do not assume additional structure on the inputs,
17+
and therefore will always return complex eigenvalues and eigenvectors. For the real
18+
generalized eigenvalue decomposition is not yet supported.
19+
"""
20+
21+
# TODO: do we need "full"?
22+
"""
23+
gen_eig_full(A, B; kwargs...) -> W, V
24+
gen_eig_full(A, B, alg::AbstractAlgorithm) -> W, V
25+
gen_eig_full!(A, B, [WV]; kwargs...) -> W, V
26+
gen_eig_full!(A, B, [WV], alg::AbstractAlgorithm) -> W, V
27+
28+
Compute the full generalized eigenvalue decomposition of the square matrices `A` and `B`,
29+
such that `A * V = B * V * W`, where the invertible matrix `V` contains the generalized eigenvectors
30+
and the diagonal matrix `W` contains the associated generalized eigenvalues.
31+
32+
!!! note
33+
The bang method `gen_eig_full!` optionally accepts the output structure and
34+
possibly destroys the input matrices `A` and `B`.
35+
Always use the return value of the function as it may not always be
36+
possible to use the provided `WV` as output.
37+
38+
!!! note
39+
$(docs_gen_eig_note)
40+
41+
See also [`gen_eig_vals(!)`](@ref eig_vals).
42+
"""
43+
@functiondef gen_eig_full
44+
45+
"""
46+
gen_eig_vals(A, B; kwargs...) -> W
47+
gen_eig_vals(A, B, alg::AbstractAlgorithm) -> W
48+
gen_eig_vals!(A, B, [W]; kwargs...) -> W
49+
gen_eig_vals!(A, B, [W], alg::AbstractAlgorithm) -> W
50+
51+
Compute the list of generalized eigenvalues of `A` and `B`.
52+
53+
!!! note
54+
The bang method `gen_eig_vals!` optionally accepts the output structure and
55+
possibly destroys the input matrices `A` and `B`. Always use the return
56+
value of the function as it may not always be possible to use the
57+
provided `W` as output.
58+
59+
!!! note
60+
$(docs_gen_eig_note)
61+
62+
See also [`gen_eig_full(!)`](@ref gen_eig_full).
63+
"""
64+
@functiondef gen_eig_vals
65+
66+
# Algorithm selection
67+
# -------------------
68+
default_gen_eig_algorithm(A, B; kwargs...) = default_gen_eig_algorithm(typeof(A), typeof(B); kwargs...)
69+
default_gen_eig_algorithm(::Type{TA}, ::Type{TB}; kwargs...) where {TA, TB} = throw(MethodError(default_gen_eig_algorithm, (TA,TB)))
70+
function default_gen_eig_algorithm(::Type{TA}, ::Type{TB}; kwargs...) where {TA<:YALAPACK.BlasMat,TB<:YALAPACK.BlasMat}
71+
return LAPACK_Simple(; kwargs...)
72+
end
73+
74+
for f in (:gen_eig_full!, :gen_eig_vals!)
75+
@eval function default_algorithm(::typeof($f), ::Tuple{A, B}; kwargs...) where {A, B}
76+
return default_gen_eig_algorithm(A, B; kwargs...)
77+
end
78+
end

0 commit comments

Comments
 (0)