@@ -8,6 +8,8 @@ function copy_input(::typeof(eigh_vals), A::AbstractMatrix)
88end
99copy_input (:: typeof (eigh_trunc), A) = copy_input (eigh_full, A)
1010
11+ copy_input (:: typeof (eigh_full), A:: Diagonal ) = copy (A)
12+
1113function check_input (:: typeof (eigh_full!), A:: AbstractMatrix , DV, :: AbstractAlgorithm )
1214 m, n = size (A)
1315 m == n || throw (DimensionMismatch (" square input matrix expected" ))
@@ -21,6 +23,27 @@ function check_input(::typeof(eigh_full!), A::AbstractMatrix, DV, ::AbstractAlgo
2123end
2224function check_input (:: typeof (eigh_vals!), A:: AbstractMatrix , D, :: AbstractAlgorithm )
2325 m, n = size (A)
26+ m == n || throw (DimensionMismatch (" square input matrix expected" ))
27+ @assert D isa AbstractVector
28+ @check_size (D, (n,))
29+ @check_scalar (D, A, real)
30+ return nothing
31+ end
32+
33+ function check_input (:: typeof (eigh_full!), A:: AbstractMatrix , DV, :: DiagonalAlgorithm )
34+ m, n = size (A)
35+ @assert m == n && isdiag (A)
36+ D, V = DV
37+ @assert D isa Diagonal && V isa Diagonal
38+ @check_size (D, (m, m))
39+ @check_scalar (D, A, real)
40+ @check_size (V, (m, m))
41+ @check_scalar (V, A)
42+ return nothing
43+ end
44+ function check_input (:: typeof (eigh_vals!), A:: AbstractMatrix , D, :: DiagonalAlgorithm )
45+ m, n = size (A)
46+ @assert m == n && isdiag (A)
2447 @assert D isa AbstractVector
2548 @check_size (D, (n,))
2649 @check_scalar (D, A, real)
@@ -45,6 +68,13 @@ function initialize_output(::typeof(eigh_trunc!), A::AbstractMatrix,
4568 return initialize_output (eigh_full!, A, alg. alg)
4669end
4770
71+ function initialize_output (:: typeof (eigh_full!), A:: Diagonal , :: DiagonalAlgorithm )
72+ return eltype (A) <: Real ? A : similar (A, real (eltype (A))), similar (A)
73+ end
74+ function initialize_output (:: typeof (eigh_vals!), A:: Diagonal , :: DiagonalAlgorithm )
75+ return eltype (A) <: Real ? diagview (A) : similar (A, real (eltype (A)), size (A, 1 ))
76+ end
77+
4878# Implementation
4979# --------------
5080function eigh_full! (A:: AbstractMatrix , DV, alg:: LAPACK_EighAlgorithm )
@@ -85,6 +115,25 @@ function eigh_trunc!(A::AbstractMatrix, DV, alg::TruncatedAlgorithm)
85115 return truncate! (eigh_trunc!, (D, V), alg. trunc)
86116end
87117
118+ # Diagonal logic
119+ # --------------
120+ function eigh_full! (A:: Diagonal , DV, alg:: DiagonalAlgorithm )
121+ check_input (eigh_full!, A, DV, alg)
122+ D, V = DV
123+ D === A || (diagview (D) .= real .(diagview (A)))
124+ one! (V)
125+ return D, V
126+ end
127+
128+ function eigh_vals! (A:: Diagonal , D, alg:: DiagonalAlgorithm )
129+ check_input (eigh_vals!, A, D, alg)
130+ Ad = diagview (A)
131+ D === Ad || (D .= real .(Ad))
132+ return D
133+ end
134+
135+ # GPU logic
136+ # ---------
88137_gpu_heevj! (A:: AbstractMatrix , Dd:: AbstractVector , V:: AbstractMatrix ; kwargs... ) = throw (MethodError (_gpu_heevj!, (A, Dd, V)))
89138_gpu_heevd! (A:: AbstractMatrix , Dd:: AbstractVector , V:: AbstractMatrix ; kwargs... ) = throw (MethodError (_gpu_heevd!, (A, Dd, V)))
90139_gpu_heev! (A:: AbstractMatrix , Dd:: AbstractVector , V:: AbstractMatrix ; kwargs... ) = throw (MethodError (_gpu_heev!, (A, Dd, V)))
0 commit comments