Skip to content

Commit 421e97d

Browse files
committed
refactor check_hermitian
1 parent 0da1e3e commit 421e97d

1 file changed

Lines changed: 9 additions & 9 deletions

File tree

src/implementations/eigh.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,18 +8,18 @@ copy_input(::typeof(eigh_trunc), A) = copy_input(eigh_full, A)
88

99
copy_input(::typeof(eigh_full), A::Diagonal) = copy(A)
1010

11-
_hermitian_tol(A, alg) = default_hermitian_tol(A)
12-
_hermitian_tol(A, alg::Algorithm) = get(alg.kwargs, :hermitian_tol, default_hermitian_tol(A))
13-
function check_hermitian(A, alg, context::Symbol)
11+
check_hermitian(A, ::AbstractAlgorithm) = check_hermitian(A)
12+
check_hermitian(A, alg::Algorithm) = check_hermitian(A; atol = get(alg.kwargs, :hermitian_tol, default_hermitian_tol(A)))
13+
function check_hermitian(A; atol::Real = default_hermitian_tol(A), rtol::Real = 0)
1414
m, n = size(A)
1515
m == n || throw(DimensionMismatch("square input matrix expected"))
16-
ishermitian(A; atol = _hermitian_tol(A, alg)) ||
17-
throw(DomainError(A, "`eigh_$(context)!(A)` was called on a non-hermitian input matrix `A`. Try `eig_$(context)!(A)` or `eigh_$(context)(project_hermitian(A))` instead."))
16+
ishermitian(A; atol, rtol) ||
17+
throw(DomainError(A, "Hermitian matrix was expected. Use `project_hermitian` to project onto the nearest hermitian matrix."))
1818
return nothing
1919
end
2020

2121
function check_input(::typeof(eigh_full!), A::AbstractMatrix, DV, alg::AbstractAlgorithm)
22-
check_hermitian(A, alg, :full)
22+
check_hermitian(A, alg)
2323
D, V = DV
2424
m = size(A, 1)
2525
@assert D isa Diagonal && V isa AbstractMatrix
@@ -30,7 +30,7 @@ function check_input(::typeof(eigh_full!), A::AbstractMatrix, DV, alg::AbstractA
3030
return nothing
3131
end
3232
function check_input(::typeof(eigh_vals!), A::AbstractMatrix, D, alg::AbstractAlgorithm)
33-
check_hermitian(A, alg, :vals)
33+
check_hermitian(A, alg)
3434
m = size(A, 1)
3535
@assert D isa AbstractVector
3636
@check_size(D, (m,))
@@ -39,7 +39,7 @@ function check_input(::typeof(eigh_vals!), A::AbstractMatrix, D, alg::AbstractAl
3939
end
4040

4141
function check_input(::typeof(eigh_full!), A::AbstractMatrix, DV, alg::DiagonalAlgorithm)
42-
check_hermitian(A, alg, :full)
42+
check_hermitian(A, alg)
4343
@assert isdiag(A)
4444
m = size(A, 1)
4545
D, V = DV
@@ -51,7 +51,7 @@ function check_input(::typeof(eigh_full!), A::AbstractMatrix, DV, alg::DiagonalA
5151
return nothing
5252
end
5353
function check_input(::typeof(eigh_vals!), A::AbstractMatrix, D, alg::DiagonalAlgorithm)
54-
check_hermitian(A, alg, :vals)
54+
check_hermitian(A, alg)
5555
@assert isdiag(A)
5656
m = size(A, 1)
5757
@assert D isa AbstractVector

0 commit comments

Comments
 (0)