Skip to content

Commit a8d9401

Browse files
kshyattKatharine Hyattlkdvos
authored
Add generalized eigenvalue decomposition, fix some bugs (#39)
* Add generalized eigenvalue decomposition, fix some bugs * Fix JET by splitting macro up a bit * Coverage improvements * More tests * Split up the functiondef macro * Update src/implementations/gen_eig.jl Co-authored-by: Lukas Devos <ldevos98@gmail.com> * Restore eig file * Make gaugefixation its own function and cleanup exports * Update src/implementations/eig.jl Co-authored-by: Lukas Devos <ldevos98@gmail.com> * Update src/implementations/gen_eig.jl Co-authored-by: Lukas Devos <ldevos98@gmail.com> * Update src/common/gauge.jl Co-authored-by: Lukas Devos <ldevos98@gmail.com> --------- Co-authored-by: Katharine Hyatt <katharine.s.hyatt@gmail.com> Co-authored-by: Lukas Devos <ldevos98@gmail.com>
1 parent 4d091cb commit a8d9401

10 files changed

Lines changed: 395 additions & 54 deletions

File tree

src/MatrixAlgebraKit.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ export eigh_full, eigh_vals, eigh_trunc
1919
export eigh_full!, eigh_vals!, eigh_trunc!
2020
export eig_full, eig_vals, eig_trunc
2121
export eig_full!, eig_vals!, eig_trunc!
22+
export gen_eig_full, gen_eig_vals
23+
export gen_eig_full!, gen_eig_vals!
2224
export schur_full, schur_vals
2325
export schur_full!, schur_vals!
2426
export left_polar, right_polar
@@ -45,6 +47,7 @@ include("common/safemethods.jl")
4547
include("common/view.jl")
4648
include("common/regularinv.jl")
4749
include("common/matrixproperties.jl")
50+
include("common/gauge.jl")
4851

4952
include("yalapack.jl")
5053
include("algorithms.jl")
@@ -54,6 +57,7 @@ include("interface/lq.jl")
5457
include("interface/svd.jl")
5558
include("interface/eig.jl")
5659
include("interface/eigh.jl")
60+
include("interface/gen_eig.jl")
5761
include("interface/schur.jl")
5862
include("interface/polar.jl")
5963
include("interface/orthnull.jl")
@@ -64,6 +68,7 @@ include("implementations/lq.jl")
6468
include("implementations/svd.jl")
6569
include("implementations/eig.jl")
6670
include("implementations/eigh.jl")
71+
include("implementations/gen_eig.jl")
6772
include("implementations/schur.jl")
6873
include("implementations/polar.jl")
6974
include("implementations/orthnull.jl")

src/algorithms.jl

Lines changed: 110 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -106,10 +106,14 @@ explicitly.
106106
New types should prefer to register their default algorithms in the type domain.
107107
""" default_algorithm
108108
default_algorithm(f::F, A; kwargs...) where {F} = default_algorithm(f, typeof(A); kwargs...)
109+
default_algorithm(f::F, A, B; kwargs...) where {F} = default_algorithm(f, typeof(A), typeof(B); kwargs...)
109110
# avoid infinite recursion:
110111
function default_algorithm(f::F, ::Type{T}; kwargs...) where {F,T}
111112
throw(MethodError(default_algorithm, (f, T)))
112113
end
114+
function default_algorithm(f::F, ::Type{TA}, ::Type{TB}; kwargs...) where {F,TA,TB}
115+
throw(MethodError(default_algorithm, (f, TA, TB)))
116+
end
113117

114118
@doc """
115119
copy_input(f, A)
@@ -153,28 +157,8 @@ macro algdef(name)
153157
end)
154158
end
155159

156-
"""
157-
@functiondef f
158-
159-
Convenience macro to define the boilerplate code that dispatches between several versions of `f` and `f!`.
160-
By default, this enables the following signatures to be defined in terms of
161-
the final `f!(A, out, alg::Algorithm)`.
162-
163-
```julia
164-
f(A; kwargs...)
165-
f(A, alg::Algorithm)
166-
f!(A, [out]; kwargs...)
167-
f!(A, alg::Algorithm)
168-
```
169-
170-
See also [`copy_input`](@ref), [`select_algorithm`](@ref) and [`initialize_output`](@ref).
171-
"""
172-
macro functiondef(f)
173-
f isa Symbol || throw(ArgumentError("Unsupported usage of `@functiondef`"))
174-
f! = Symbol(f, :!)
175-
176-
ex = quote
177-
# out of place to inplace
160+
function _arg_expr(::Val{1}, f, f!)
161+
return quote # out of place to inplace
178162
$f(A; kwargs...) = $f!(copy_input($f, A); kwargs...)
179163
$f(A, alg::AbstractAlgorithm) = $f!(copy_input($f, A), alg)
180164

@@ -215,7 +199,110 @@ macro functiondef(f)
215199
# copy documentation to both functions
216200
Core.@__doc__ $f, $f!
217201
end
218-
return esc(ex)
202+
end
203+
204+
function _arg_expr(::Val{2}, f, f!)
205+
return quote
206+
# out of place to inplace
207+
$f(A, B; kwargs...) = $f!(copy_input($f, A, B)...; kwargs...)
208+
$f(A, B, alg::AbstractAlgorithm) = $f!(copy_input($f, A, B)..., alg)
209+
210+
# fill in arguments
211+
function $f!(A, B; alg=nothing, kwargs...)
212+
return $f!(A, B, select_algorithm($f!, (A, B), alg; kwargs...))
213+
end
214+
function $f!(A, B, out; alg=nothing, kwargs...)
215+
return $f!(A, B, out, select_algorithm($f!, (A, B), alg; kwargs...))
216+
end
217+
function $f!(A, B, alg::AbstractAlgorithm)
218+
return $f!(A, B, initialize_output($f!, A, B, alg), alg)
219+
end
220+
221+
# define fallbacks for algorithm selection
222+
@inline function select_algorithm(::typeof($f), A, alg::Alg; kwargs...) where {Alg}
223+
return select_algorithm($f!, A, alg; kwargs...)
224+
end
225+
# define default algorithm fallbacks for out-of-place functions
226+
# in terms of the corresponding in-place function
227+
@inline function default_algorithm(::typeof($f), A, B; kwargs...)
228+
return default_algorithm($f!, A, B; kwargs...)
229+
end
230+
# define default algorithm fallbacks for out-of-place functions
231+
# in terms of the corresponding in-place function for types,
232+
# in principle this is covered by the definition above but
233+
# it is necessary to avoid ambiguity errors with the generic definitions:
234+
# ```julia
235+
# default_algorithm(f::F, A; kwargs...) where {F} = default_algorithm(f, typeof(A); kwargs...)
236+
# function default_algorithm(f::F, ::Type{T}; kwargs...) where {F,T}
237+
# throw(MethodError(default_algorithm, (f, T)))
238+
# end
239+
# ```
240+
@inline function default_algorithm(::typeof($f), ::Type{A}, ::Type{B}; kwargs...) where {A, B}
241+
return default_algorithm($f!, A, B; kwargs...)
242+
end
243+
244+
# copy documentation to both functions
245+
Core.@__doc__ $f, $f!
246+
end
247+
end
248+
249+
"""
250+
@functiondef [n_args=1] f
251+
252+
Convenience macro to define the boilerplate code that dispatches between several versions of `f` and `f!`.
253+
By default, `f` accepts a single argument `A`. This enables the following signatures to be defined in terms of
254+
the final `f!(A, out, alg::Algorithm)`.
255+
256+
```julia
257+
f(A; kwargs...)
258+
f(A, alg::Algorithm)
259+
f!(A, [out]; kwargs...)
260+
f!(A, alg::Algorithm)
261+
```
262+
263+
The number of inputs can be set with the `n_args` keyword
264+
argument, so that
265+
266+
```julia
267+
@functiondef n_args=2 f
268+
```
269+
270+
would create
271+
272+
```julia
273+
f(A, B; kwargs...)
274+
f(A, B, alg::Algorithm)
275+
f!(A, B, [out]; kwargs...)
276+
f!(A, B, alg::Algorithm)
277+
```
278+
279+
See also [`copy_input`](@ref), [`select_algorithm`](@ref) and [`initialize_output`](@ref).
280+
"""
281+
macro functiondef(args...)
282+
kwargs = map(args[1:end-1]) do kwarg
283+
if kwarg isa Symbol
284+
:($kwarg = $kwarg)
285+
elseif Meta.isexpr(kwarg, :(=))
286+
kwarg
287+
else
288+
throw(ArgumentError("Invalid keyword argument '$kwarg'"))
289+
end
290+
end
291+
isempty(kwargs) || length(kwargs) == 1 || throw(ArgumentError("Only one keyword argument to `@functiondef` is supported"))
292+
f_n_args = 1 # default
293+
if length(kwargs) == 1
294+
kwarg = only(kwargs) # only one kwarg is currently supported, TODO modify if we support more
295+
key::Symbol, val = kwarg.args
296+
key === :n_args || throw(ArgumentError("Unsupported keyword argument $key to `@functiondef`"))
297+
(isa(val, Integer) && val > 0) || throw(ArgumentError("`n_args` keyword argument to `@functiondef` should be an integer > 0"))
298+
f_n_args = val
299+
end
300+
301+
f = args[end]
302+
f isa Symbol || throw(ArgumentError("Unsupported usage of `@functiondef`"))
303+
f! = Symbol(f, :!)
304+
305+
return esc(_arg_expr(Val(f_n_args), f, f!))
219306
end
220307

221308
"""

src/common/gauge.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
function gaugefix!(V::AbstractMatrix)
2+
for j in axes(V, 2)
3+
v = view(V, :, j)
4+
s = conj(sign(argmax(abs, v)))
5+
@inbounds v .*= s
6+
end
7+
return V
8+
end

src/implementations/eig.jl

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -61,11 +61,7 @@ function eig_full!(A::AbstractMatrix, DV, alg::LAPACK_EigAlgorithm)
6161
YALAPACK.geevx!(A, D.diag, V; alg.kwargs...)
6262
end
6363
# TODO: make this controllable using a `gaugefix` keyword argument
64-
for j in 1:size(V, 2)
65-
v = view(V, :, j)
66-
s = conj(sign(argmax(abs, v)))
67-
v .*= s
68-
end
64+
V = gaugefix!(V)
6965
return D, V
7066
end
7167

src/implementations/gen_eig.jl

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
# Inputs
2+
# ------
3+
function copy_input(::typeof(gen_eig_full), A::AbstractMatrix, B::AbstractMatrix)
4+
return copy!(similar(A, float(eltype(A))), A), copy!(similar(B, float(eltype(B))), B)
5+
end
6+
function copy_input(::typeof(gen_eig_vals), A::AbstractMatrix, B::AbstractMatrix)
7+
return copy_input(gen_eig_full, A, B)
8+
end
9+
10+
function check_input(::typeof(gen_eig_full!), A::AbstractMatrix, B::AbstractMatrix, WV)
11+
ma, na = size(A)
12+
mb, nb = size(B)
13+
ma == na || throw(DimensionMismatch("square input matrix A expected"))
14+
mb == nb || throw(DimensionMismatch("square input matrix B expected"))
15+
ma == mb || throw(DimensionMismatch("first dimension of input matrices expected to match"))
16+
na == nb || throw(DimensionMismatch("second dimension of input matrices expected to match"))
17+
W, V = WV
18+
@assert W isa Diagonal && V isa AbstractMatrix
19+
@check_size(W, (ma, ma))
20+
@check_scalar(W, A, complex)
21+
@check_scalar(W, B, complex)
22+
@check_size(V, (ma, ma))
23+
@check_scalar(V, A, complex)
24+
@check_scalar(V, B, complex)
25+
return nothing
26+
end
27+
function check_input(::typeof(gen_eig_vals!), A::AbstractMatrix, B::AbstractMatrix, W)
28+
ma, na = size(A)
29+
mb, nb = size(B)
30+
ma == na || throw(DimensionMismatch("square input matrix A expected"))
31+
mb == nb || throw(DimensionMismatch("square input matrix B expected"))
32+
ma == mb || throw(DimensionMismatch("dimension of input matrices expected to match"))
33+
@assert W isa AbstractVector
34+
@check_size(W, (na,))
35+
@check_scalar(W, A, complex)
36+
@check_scalar(W, B, complex)
37+
return nothing
38+
end
39+
40+
# Outputs
41+
# -------
42+
function initialize_output(::typeof(gen_eig_full!), A::AbstractMatrix, B::AbstractMatrix, ::LAPACK_EigAlgorithm)
43+
n = size(A, 1) # square check will happen later
44+
Tc = complex(eltype(A))
45+
W = Diagonal(similar(A, Tc, n))
46+
V = similar(A, Tc, (n, n))
47+
return (W, V)
48+
end
49+
function initialize_output(::typeof(gen_eig_vals!), A::AbstractMatrix, B::AbstractMatrix, ::LAPACK_EigAlgorithm)
50+
n = size(A, 1) # square check will happen later
51+
Tc = complex(eltype(A))
52+
D = similar(A, Tc, n)
53+
return D
54+
end
55+
56+
# Implementation
57+
# --------------
58+
# actual implementation
59+
function gen_eig_full!(A::AbstractMatrix, B::AbstractMatrix, WV, alg::LAPACK_EigAlgorithm)
60+
check_input(gen_eig_full!, A, B, WV)
61+
W, V = WV
62+
if alg isa LAPACK_Simple
63+
isempty(alg.kwargs) ||
64+
throw(ArgumentError("LAPACK_Simple (ggev) does not accept any keyword arguments"))
65+
YALAPACK.ggev!(A, B, W.diag, V, similar(W.diag, eltype(A)))
66+
else # alg isa LAPACK_Expert
67+
throw(ArgumentError("LAPACK_Expert is not supported for ggev"))
68+
end
69+
# TODO: make this controllable using a `gaugefix` keyword argument
70+
V = gaugefix!(V)
71+
return W, V
72+
end
73+
74+
function gen_eig_vals!(A::AbstractMatrix, B::AbstractMatrix, W, alg::LAPACK_EigAlgorithm)
75+
check_input(gen_eig_vals!, A, B, W)
76+
V = similar(A, complex(eltype(A)), (size(A, 1), 0))
77+
if alg isa LAPACK_Simple
78+
isempty(alg.kwargs) ||
79+
throw(ArgumentError("LAPACK_Simple (ggev) does not accept any keyword arguments"))
80+
YALAPACK.ggev!(A, B, W, V, similar(W, eltype(A)))
81+
else # alg isa LAPACK_Expert
82+
throw(ArgumentError("LAPACK_Expert is not supported for ggev"))
83+
end
84+
return W
85+
end

src/interface/eig.jl

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,3 @@
1-
# Eig API
2-
# -------
3-
# TODO: export? or not export but mark as public ?
4-
function eig!(A::AbstractMatrix, args...; kwargs...)
5-
return eig_full!(A, args...; kwargs...)
6-
end
7-
function eig(A::AbstractMatrix, args...; kwargs...)
8-
return eig_full(A, args...; kwargs...)
9-
end
10-
111
# Eig functions
122
# -------------
133

src/interface/gen_eig.jl

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
# Gen Eig functions
2+
# -------------
3+
4+
# TODO: kwargs for sorting eigenvalues?
5+
6+
docs_gen_eig_note = """
7+
Note that [`gen_eig_full`](@ref) and its variants do not assume additional structure on the inputs,
8+
and therefore will always return complex eigenvalues and eigenvectors. For the real
9+
generalized eigenvalue decomposition is not yet supported.
10+
"""
11+
12+
# TODO: do we need "full"?
13+
"""
14+
gen_eig_full(A, B; kwargs...) -> W, V
15+
gen_eig_full(A, B, alg::AbstractAlgorithm) -> W, V
16+
gen_eig_full!(A, B, [WV]; kwargs...) -> W, V
17+
gen_eig_full!(A, B, [WV], alg::AbstractAlgorithm) -> W, V
18+
19+
Compute the full generalized eigenvalue decomposition of the square matrices `A` and `B`,
20+
such that `A * V = B * V * W`, where the invertible matrix `V` contains the generalized eigenvectors
21+
and the diagonal matrix `W` contains the associated generalized eigenvalues.
22+
23+
!!! note
24+
The bang method `gen_eig_full!` optionally accepts the output structure and
25+
possibly destroys the input matrices `A` and `B`.
26+
Always use the return value of the function as it may not always be
27+
possible to use the provided `WV` as output.
28+
29+
!!! note
30+
$(docs_gen_eig_note)
31+
32+
See also [`gen_eig_vals(!)`](@ref eig_vals).
33+
"""
34+
@functiondef n_args=2 gen_eig_full
35+
36+
"""
37+
gen_eig_vals(A, B; kwargs...) -> W
38+
gen_eig_vals(A, B, alg::AbstractAlgorithm) -> W
39+
gen_eig_vals!(A, B, [W]; kwargs...) -> W
40+
gen_eig_vals!(A, B, [W], alg::AbstractAlgorithm) -> W
41+
42+
Compute the list of generalized eigenvalues of `A` and `B`.
43+
44+
!!! note
45+
The bang method `gen_eig_vals!` optionally accepts the output structure and
46+
possibly destroys the input matrices `A` and `B`. Always use the return
47+
value of the function as it may not always be possible to use the
48+
provided `W` as output.
49+
50+
!!! note
51+
$(docs_gen_eig_note)
52+
53+
See also [`gen_eig_full(!)`](@ref gen_eig_full).
54+
"""
55+
@functiondef n_args=2 gen_eig_vals
56+
57+
# Algorithm selection
58+
# -------------------
59+
default_gen_eig_algorithm(A, B; kwargs...) = default_gen_eig_algorithm(typeof(A), typeof(B); kwargs...)
60+
default_gen_eig_algorithm(::Type{TA}, ::Type{TB}; kwargs...) where {TA, TB} = throw(MethodError(default_gen_eig_algorithm, (TA,TB)))
61+
function default_gen_eig_algorithm(::Type{TA}, ::Type{TB}; kwargs...) where {TA<:YALAPACK.BlasMat,TB<:YALAPACK.BlasMat}
62+
return LAPACK_Simple(; kwargs...)
63+
end
64+
65+
for f in (:gen_eig_full!, :gen_eig_vals!)
66+
@eval function default_algorithm(::typeof($f), ::Tuple{A, B}; kwargs...) where {A, B}
67+
return default_gen_eig_algorithm(A, B; kwargs...)
68+
end
69+
end

0 commit comments

Comments
 (0)