Skip to content

Commit 4008d91

Browse files
authored
Consistent checksquare usage (#154)
* consistent checksquare usage * remove unused function * fix typo * fix one more typo
1 parent 517cc85 commit 4008d91

4 files changed

Lines changed: 20 additions & 30 deletions

File tree

src/implementations/eig.jl

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,7 @@ copy_input(::Union{typeof(eig_trunc), typeof(eig_trunc_no_error)}, A) = copy_inp
99
copy_input(::typeof(eig_full), A::Diagonal) = copy(A)
1010

1111
function check_input(::typeof(eig_full!), A::AbstractMatrix, DV, ::AbstractAlgorithm)
12-
m, n = size(A)
13-
m == n || throw(DimensionMismatch("square input matrix expected"))
12+
m = LinearAlgebra.checksquare(A)
1413
D, V = DV
1514
@assert D isa Diagonal && V isa AbstractMatrix
1615
@check_size(D, (m, m))
@@ -20,17 +19,16 @@ function check_input(::typeof(eig_full!), A::AbstractMatrix, DV, ::AbstractAlgor
2019
return nothing
2120
end
2221
function check_input(::typeof(eig_vals!), A::AbstractMatrix, D, ::AbstractAlgorithm)
23-
m, n = size(A)
24-
m == n || throw(DimensionMismatch("square input matrix expected"))
22+
m = LinearAlgebra.checksquare(A)
2523
@assert D isa AbstractVector
26-
@check_size(D, (n,))
24+
@check_size(D, (m,))
2725
@check_scalar(D, A, complex)
2826
return nothing
2927
end
3028

3129
function check_input(::typeof(eig_full!), A::AbstractMatrix, DV, ::DiagonalAlgorithm)
32-
m, n = size(A)
33-
((m == n) && isdiag(A)) || throw(DimensionMismatch("diagonal input matrix expected"))
30+
m = LinearAlgebra.checksquare(A)
31+
isdiag(A) || throw(DimensionMismatch("diagonal input matrix expected"))
3432
D, V = DV
3533
@assert D isa Diagonal && V isa AbstractMatrix
3634
@check_size(D, (m, m))
@@ -40,10 +38,10 @@ function check_input(::typeof(eig_full!), A::AbstractMatrix, DV, ::DiagonalAlgor
4038
return nothing
4139
end
4240
function check_input(::typeof(eig_vals!), A::AbstractMatrix, D, ::DiagonalAlgorithm)
43-
m, n = size(A)
44-
((m == n) && isdiag(A)) || throw(DimensionMismatch("diagonal input matrix expected"))
41+
m = LinearAlgebra.checksquare(A)
42+
isdiag(A) || throw(DimensionMismatch("diagonal input matrix expected"))
4543
@assert D isa AbstractVector
46-
@check_size(D, (n,))
44+
@check_size(D, (m,))
4745
@check_scalar(D, A, complex)
4846
return nothing
4947
end

src/implementations/eigh.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,7 @@ copy_input(::typeof(eigh_full), A::Diagonal) = copy(A)
1111
check_hermitian(A, ::AbstractAlgorithm) = check_hermitian(A)
1212
check_hermitian(A, alg::Algorithm) = check_hermitian(A; atol = get(alg.kwargs, :hermitian_tol, default_hermitian_tol(A)))
1313
function check_hermitian(A; atol::Real = default_hermitian_tol(A), rtol::Real = 0)
14-
m, n = size(A)
15-
m == n || throw(DimensionMismatch("square input matrix expected"))
14+
LinearAlgebra.checksquare(A)
1615
ishermitian(A; atol, rtol) ||
1716
throw(DomainError(A, "Hermitian matrix was expected. Use `project_hermitian` to project onto the nearest hermitian matrix."))
1817
return nothing

src/implementations/gen_eig.jl

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,9 @@ end
66
copy_input(::typeof(gen_eig_vals), A, B) = copy_input(gen_eig_full, A, B)
77

88
function check_input(::typeof(gen_eig_full!), A::AbstractMatrix, B::AbstractMatrix, WV, ::AbstractAlgorithm)
9-
ma, na = size(A)
10-
mb, nb = size(B)
11-
ma == na || throw(DimensionMismatch("square input matrix A expected"))
12-
mb == nb || throw(DimensionMismatch("square input matrix B expected"))
13-
ma == mb || throw(DimensionMismatch("first dimension of input matrices expected to match"))
14-
na == nb || throw(DimensionMismatch("second dimension of input matrices expected to match"))
9+
ma = LinearAlgebra.checksquare(A)
10+
mb = LinearAlgebra.checksquare(B)
11+
ma == mb || throw(DimensionMismatch(lazy"Expected matching input sizes, dimensions are $ma and $mb"))
1512
W, V = WV
1613
@assert W isa Diagonal && V isa AbstractMatrix
1714
@check_size(W, (ma, ma))
@@ -23,13 +20,11 @@ function check_input(::typeof(gen_eig_full!), A::AbstractMatrix, B::AbstractMatr
2320
return nothing
2421
end
2522
function check_input(::typeof(gen_eig_vals!), A::AbstractMatrix, B::AbstractMatrix, W, ::AbstractAlgorithm)
26-
ma, na = size(A)
27-
mb, nb = size(B)
28-
ma == na || throw(DimensionMismatch("square input matrix A expected"))
29-
mb == nb || throw(DimensionMismatch("square input matrix B expected"))
30-
ma == mb || throw(DimensionMismatch("dimension of input matrices expected to match"))
23+
ma = LinearAlgebra.checksquare(A)
24+
mb = LinearAlgebra.checksquare(B)
25+
ma == mb || throw(DimensionMismatch(lazy"Expected matching input sizes, dimensions are $ma and $mb"))
3126
@assert W isa AbstractVector
32-
@check_size(W, (na,))
27+
@check_size(W, (ma,))
3328
@check_scalar(W, A, complex)
3429
@check_scalar(W, B, complex)
3530
return nothing

src/implementations/schur.jl

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,23 +5,21 @@ copy_input(::typeof(schur_vals), A) = copy_input(eig_vals, A)
55

66
# check input
77
function check_input(::typeof(schur_full!), A::AbstractMatrix, TZv, ::AbstractAlgorithm)
8-
m, n = size(A)
9-
m == n || throw(DimensionMismatch("square input matrix expected"))
8+
m = LinearAlgebra.checksquare(A)
109
T, Z, vals = TZv
1110
@assert T isa AbstractMatrix && Z isa AbstractMatrix && vals isa AbstractVector
1211
@check_size(T, (m, m))
1312
@check_scalar(T, A)
1413
@check_size(Z, (m, m))
1514
@check_scalar(Z, A)
16-
@check_size(vals, (n,))
15+
@check_size(vals, (m,))
1716
@check_scalar(vals, A, complex)
1817
return nothing
1918
end
2019
function check_input(::typeof(schur_vals!), A::AbstractMatrix, vals, ::AbstractAlgorithm)
21-
m, n = size(A)
22-
m == n || throw(DimensionMismatch("square input matrix expected"))
20+
m = LinearAlgebra.checksquare(A)
2321
@assert vals isa AbstractVector
24-
@check_size(vals, (n,))
22+
@check_size(vals, (m,))
2523
@check_scalar(vals, A, complex)
2624
return nothing
2725
end

0 commit comments

Comments
 (0)