Skip to content

Commit 60b48c0

Browse files
committed
Make eigh_vals/full sort for Diagonal
1 parent 29bc9df commit 60b48c0

1 file changed

Lines changed: 15 additions & 4 deletions

File tree

src/implementations/eigh.jl

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ function check_input(::typeof(eigh_full!), A::AbstractMatrix, DV, alg::DiagonalA
4343
@assert isdiag(A)
4444
m = size(A, 1)
4545
D, V = DV
46-
@assert D isa Diagonal && V isa Diagonal
46+
@assert D isa Diagonal
4747
@check_size(D, (m, m))
4848
@check_scalar(D, A, real)
4949
@check_size(V, (m, m))
@@ -79,7 +79,7 @@ function initialize_output(::Union{typeof(eigh_trunc!), typeof(eigh_trunc_no_err
7979
end
8080

8181
function initialize_output(::typeof(eigh_full!), A::Diagonal, ::DiagonalAlgorithm)
82-
return eltype(A) <: Real ? A : similar(A, real(eltype(A))), similar(A)
82+
return eltype(A) <: Real ? A : similar(A, real(eltype(A))), similar(A, size(A)...)
8383
end
8484
function initialize_output(::typeof(eigh_vals!), A::Diagonal, ::DiagonalAlgorithm)
8585
return eltype(A) <: Real ? diagview(A) : similar(A, real(eltype(A)), size(A, 1))
@@ -146,15 +146,26 @@ end
146146
function eigh_full!(A::Diagonal, DV, alg::DiagonalAlgorithm)
147147
check_input(eigh_full!, A, DV, alg)
148148
D, V = DV
149-
D === A || (diagview(D) .= real.(diagview(A)))
149+
I = sortperm(real.(diagview(A)))
150+
if D === A
151+
sort!(diagview(A))
152+
else
153+
diagview(D) .= real.(diagview(A))[I]
154+
end
150155
one!(V)
156+
Base.permutecols!!(V, I)
151157
return D, V
152158
end
153159

154160
function eigh_vals!(A::Diagonal, D, alg::DiagonalAlgorithm)
155161
check_input(eigh_vals!, A, D, alg)
156162
Ad = diagview(A)
157-
D === Ad || (D .= real.(Ad))
163+
if D === Ad
164+
sort!(Ad)
165+
else
166+
I = sortperm(real.(Ad))
167+
D .= real.(Ad[I])
168+
end
158169
return D
159170
end
160171

0 commit comments

Comments
 (0)