|
| 1 | +# Inputs |
| 2 | +# ------ |
| 3 | +function copy_input(::typeof(gen_eig_full), A::AbstractMatrix, B::AbstractMatrix) |
| 4 | + return copy!(similar(A, float(eltype(A))), A), copy!(similar(B, float(eltype(B))), B) |
| 5 | +end |
| 6 | +function copy_input(::typeof(gen_eig_vals), A::AbstractMatrix, B::AbstractMatrix) |
| 7 | + return copy_input(gen_eig_full, A, B) |
| 8 | +end |
| 9 | + |
| 10 | +function check_input(::typeof(gen_eig_full!), A::AbstractMatrix, B::AbstractMatrix, WV) |
| 11 | + ma, na = size(A) |
| 12 | + mb, nb = size(B) |
| 13 | + ma == na || throw(DimensionMismatch("square input matrix A expected")) |
| 14 | + mb == nb || throw(DimensionMismatch("square input matrix B expected")) |
| 15 | + ma == mb || throw(DimensionMismatch("first dimension of input matrices expected to match")) |
| 16 | + na == nb || throw(DimensionMismatch("second dimension of input matrices expected to match")) |
| 17 | + W, V = WV |
| 18 | + @assert W isa Diagonal && V isa AbstractMatrix |
| 19 | + @check_size(W, (ma, ma)) |
| 20 | + @check_scalar(W, A, complex) |
| 21 | + @check_scalar(W, B, complex) |
| 22 | + @check_size(V, (ma, ma)) |
| 23 | + @check_scalar(V, A, complex) |
| 24 | + @check_scalar(V, B, complex) |
| 25 | + return nothing |
| 26 | +end |
| 27 | +function check_input(::typeof(gen_eig_vals!), A::AbstractMatrix, B::AbstractMatrix, W) |
| 28 | + ma, na = size(A) |
| 29 | + mb, nb = size(B) |
| 30 | + ma == na || throw(DimensionMismatch("square input matrix A expected")) |
| 31 | + mb == nb || throw(DimensionMismatch("square input matrix B expected")) |
| 32 | + ma == mb || throw(DimensionMismatch("first dimension of input matrices expected to match")) |
| 33 | + na == nb || throw(DimensionMismatch("second dimension of input matrices expected to match")) |
| 34 | + @assert W isa AbstractVector |
| 35 | + @check_size(W, (na,)) |
| 36 | + @check_scalar(W, A, complex) |
| 37 | + @check_scalar(W, B, complex) |
| 38 | + return nothing |
| 39 | +end |
| 40 | + |
| 41 | +# Outputs |
| 42 | +# ------- |
| 43 | +function initialize_output(::typeof(gen_eig_full!), A::AbstractMatrix, B::AbstractMatrix, ::LAPACK_EigAlgorithm) |
| 44 | + n = size(A, 1) # square check will happen later |
| 45 | + Tc = complex(eltype(A)) |
| 46 | + W = Diagonal(similar(A, Tc, n)) |
| 47 | + V = similar(A, Tc, (n, n)) |
| 48 | + return (W, V) |
| 49 | +end |
| 50 | +function initialize_output(::typeof(gen_eig_vals!), A::AbstractMatrix, B::AbstractMatrix, ::LAPACK_EigAlgorithm) |
| 51 | + n = size(A, 1) # square check will happen later |
| 52 | + Tc = complex(eltype(A)) |
| 53 | + D = similar(A, Tc, n) |
| 54 | + return D |
| 55 | +end |
| 56 | + |
| 57 | +# Implementation |
| 58 | +# -------------- |
| 59 | +# actual implementation |
| 60 | +function gen_eig_full!(A::AbstractMatrix, B::AbstractMatrix, WV, alg::LAPACK_EigAlgorithm) |
| 61 | + check_input(gen_eig_full!, A, B, WV) |
| 62 | + W, V = WV |
| 63 | + if alg isa LAPACK_Simple |
| 64 | + isempty(alg.kwargs) || |
| 65 | + throw(ArgumentError("LAPACK_Simple (ggev) does not accept any keyword arguments")) |
| 66 | + YALAPACK.ggev!(A, B, W.diag, V, similar(W.diag, eltype(A))) |
| 67 | + else # alg isa LAPACK_Expert |
| 68 | + throw(ArgumentError("LAPACK_Expert is not supported for ggev")) |
| 69 | + end |
| 70 | + # TODO: make this controllable using a `gaugefix` keyword argument |
| 71 | + for j in 1:size(V, 2) |
| 72 | + v = view(V, :, j) |
| 73 | + s = conj(sign(argmax(abs, v))) |
| 74 | + v .*= s |
| 75 | + end |
| 76 | + return W, V |
| 77 | +end |
| 78 | + |
| 79 | +function gen_eig_vals!(A::AbstractMatrix, B::AbstractMatrix, W, alg::LAPACK_EigAlgorithm) |
| 80 | + check_input(gen_eig_vals!, A, B, W) |
| 81 | + V = similar(A, complex(eltype(A)), (size(A, 1), 0)) |
| 82 | + if alg isa LAPACK_Simple |
| 83 | + isempty(alg.kwargs) || |
| 84 | + throw(ArgumentError("LAPACK_Simple (ggev) does not accept any keyword arguments")) |
| 85 | + YALAPACK.ggev!(A, B, W, V, similar(W, eltype(A))) |
| 86 | + else # alg isa LAPACK_Expert |
| 87 | + throw(ArgumentError("LAPACK_Expert is not supported for ggev")) |
| 88 | + end |
| 89 | + return W |
| 90 | +end |
0 commit comments