Skip to content

Commit 1e0610d

Browse files
committed
rework tests
1 parent 421e97d commit 1e0610d

3 files changed

Lines changed: 26 additions & 10 deletions

File tree

src/common/matrixproperties.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ end
8080
ishermitian_exact(A) = A == A'
8181
ishermitian_exact(A::StridedMatrix; kwargs...) = strided_ishermitian_exact(A, Val(false); kwargs...)
8282
function ishermitian_approx(A; atol, rtol, kwargs...)
83-
return 2 * norm(project_antihermitian(A; kwargs...)) max(atol, rtol * norm(A))
83+
return norm(project_antihermitian(A; kwargs...)) max(atol, rtol * norm(A))
8484
end
8585
ishermitian_approx(A::StridedMatrix; kwargs...) = strided_ishermitian_approx(A, Val(false); kwargs...)
8686

@@ -104,7 +104,7 @@ function isantihermitian_exact(A::StridedMatrix; kwargs...)
104104
return strided_ishermitian_exact(A, Val(true); kwargs...)
105105
end
106106
function isantihermitian_approx(A; atol, rtol, kwargs...)
107-
return 2 * norm(project_hermitian(A; kwargs...)) max(atol, rtol * norm(A))
107+
return norm(project_hermitian(A; kwargs...)) max(atol, rtol * norm(A))
108108
end
109109
isantihermitian_approx(A::StridedMatrix; kwargs...) = strided_ishermitian_approx(A, Val(true); kwargs...)
110110

@@ -159,7 +159,7 @@ function strided_ishermitian_approx(
159159
ϵ² < ϵ²max || return false
160160
for i in 1:blocksize:(j - 1)
161161
ib = blocksize
162-
ϵ² += _ishermitian_approx_offdiag(
162+
ϵ² += 2 * _ishermitian_approx_offdiag(
163163
view(A, i:(i + ib - 1), j:(j + jb - 1)),
164164
view(A, j:(j + jb - 1), i:(i + ib - 1)),
165165
anti
@@ -175,7 +175,7 @@ function _ishermitian_approx_diag(A, ::Val{anti}) where {anti}
175175
ϵ² = abs2(zero(eltype(A)))
176176
@inbounds for j in 1:n
177177
@simd for i in 1:j
178-
val = anti ? (A[i, j] + adjoint(A[j, i])) : (A[i, j] - adjoint(A[j, i]))
178+
val = _project_hermitian(A[i, j], A[j, i], !anti)
179179
ϵ² += abs2(val) * (1 + Int(i < j))
180180
end
181181
end
@@ -186,9 +186,9 @@ function _ishermitian_approx_offdiag(Al, Au, ::Val{anti}) where {anti}
186186
ϵ² = abs2(zero(eltype(Al)))
187187
@inbounds for j in 1:n
188188
@simd for i in 1:m
189-
val = anti ? (Al[i, j] + adjoint(Au[j, i])) : (Al[i, j] - adjoint(Au[j, i]))
189+
val = _project_hermitian(Al[i, j], Au[j, i], !anti)
190190
ϵ² += abs2(val)
191191
end
192192
end
193-
return 2ϵ²
193+
return ϵ²
194194
end

src/implementations/projections.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -85,14 +85,16 @@ function project_hermitian_native!(A::AbstractMatrix, B::AbstractMatrix, anti::V
8585
return B
8686
end
8787

88+
@inline function _project_hermitian(Aij::Number, Aji::Number, anti::Bool)
89+
return anti ? (Aij - Aji') / 2 : (Aij + Aji') / 2
90+
end
8891
function _project_hermitian_offdiag!(
8992
Au::AbstractMatrix, Al::AbstractMatrix, Bu::AbstractMatrix, Bl::AbstractMatrix, ::Val{anti}
9093
) where {anti}
91-
9294
m, n = size(Au) # == reverse(size(Au))
9395
return @inbounds for j in 1:n
9496
@simd for i in 1:m
95-
val = anti ? (Au[i, j] - adjoint(Al[j, i])) / 2 : (Au[i, j] + adjoint(Al[j, i])) / 2
97+
val = _project_hermitian(Au[i, j], Al[j, i], anti)
9698
Bu[i, j] = val
9799
aval = adjoint(val)
98100
Bl[j, i] = anti ? -aval : aval
@@ -104,7 +106,7 @@ function _project_hermitian_diag!(A::AbstractMatrix, B::AbstractMatrix, ::Val{an
104106
n = size(A, 1)
105107
@inbounds for j in 1:n
106108
@simd for i in 1:(j - 1)
107-
val = anti ? (A[i, j] - adjoint(A[j, i])) / 2 : (A[i, j] + adjoint(A[j, i])) / 2
109+
val = _project_hermitian(A[i, j], A[j, i], anti)
108110
B[i, j] = val
109111
aval = adjoint(val)
110112
B[j, i] = anti ? -aval : aval

test/projections.jl

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ using MatrixAlgebraKit
22
using Test
33
using TestExtras
44
using StableRNGs
5-
using LinearAlgebra: LinearAlgebra, Diagonal, norm
5+
using LinearAlgebra: LinearAlgebra, Diagonal, norm, normalize!
66

77
const BLASFloats = (Float32, Float64, ComplexF32, ComplexF64)
88

@@ -43,6 +43,20 @@ const BLASFloats = (Float32, Float64, ComplexF32, ComplexF64)
4343
@test isantihermitian(Ba)
4444
@test Ba Aa
4545
end
46+
47+
# test approximate error calculation
48+
A = normalize!(randn(rng, T, m, m))
49+
Ah = project_hermitian(A)
50+
Ah_approx = Ah + noisefactor * Aa
51+
ϵ = norm(project_antihermitian(Ah_approx))
52+
@test !ishermitian(Ah_approx; atol = (999 // 1000) * ϵ)
53+
@test ishermitian(Ah_approx; atol = (1001 // 1000) * ϵ)
54+
55+
Aa = project_antihermitian(A)
56+
Aa_approx = Aa + noisefactor * Ah
57+
ϵ = norm(project_hermitian(Aa_approx))
58+
@test !isantihermitian(Aa_approx; atol = (999 // 1000) * ϵ)
59+
@test isantihermitian(Aa_approx; atol = (1001 // 1000) * ϵ)
4660
end
4761

4862
@testset "project_isometric! for T = $T" for T in BLASFloats

0 commit comments

Comments
 (0)