Skip to content

Commit 07f1ee2

Browse files
lkdvosJutho
andauthored
[Feature] Add ishermitian, isantihermitian, hermitianpart(!) and antihermitianpart(!) (#64)
* Add `hermitianpart!` and `antihermitianpart!` * Add `ishermitian` and `isantihermitian` * Start using in pullbacks * blocked (anti)hermitian * rename to project_ * also change includes * fix polar pullback * add blocked ishermitian * improve matrixproperties * increase coverage * Update src/MatrixAlgebraKit.jl Co-authored-by: Jutho <Jutho@users.noreply.github.com> * small rename to strided_ishermitian_exact * add back missing import --------- Co-authored-by: Jutho Haegeman <jutho.haegeman@ugent.be> Co-authored-by: Jutho <Jutho@users.noreply.github.com>
1 parent 4ab245c commit 07f1ee2

11 files changed

Lines changed: 300 additions & 19 deletions

File tree

src/MatrixAlgebraKit.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,15 @@ using LinearAlgebra: LinearAlgebra
44
using LinearAlgebra: norm # TODO: eleminate if we use VectorInterface.jl?
55
using LinearAlgebra: mul!, rmul!, lmul!, adjoint!, rdiv!, ldiv!
66
using LinearAlgebra: sylvester
7-
using LinearAlgebra: isposdef, ishermitian, issymmetric
7+
using LinearAlgebra: isposdef, issymmetric
88
using LinearAlgebra: Diagonal, diag, diagind, isdiag
99
using LinearAlgebra: UpperTriangular, LowerTriangular
1010
using LinearAlgebra: BlasFloat, BlasReal, BlasComplex, BlasInt
1111

12-
export isisometry, isunitary
12+
export isisometry, isunitary, ishermitian, isantihermitian
1313

14+
export project_hermitian, project_antihermitian
15+
export project_hermitian!, project_antihermitian!
1416
export qr_compact, qr_full, qr_null, lq_compact, lq_full, lq_null
1517
export qr_compact!, qr_full!, qr_null!, lq_compact!, lq_full!, lq_null!
1618
export svd_compact, svd_full, svd_vals, svd_trunc
@@ -33,6 +35,7 @@ export LAPACK_HouseholderQR, LAPACK_HouseholderLQ, LAPACK_Simple, LAPACK_Expert,
3335
LAPACK_DivideAndConquer, LAPACK_Jacobi
3436
export LQViaTransposedQR
3537
export DiagonalAlgorithm
38+
export NativeBlocked
3639
export CUSOLVER_Simple, CUSOLVER_HouseholderQR, CUSOLVER_QRIteration, CUSOLVER_SVDPolar,
3740
CUSOLVER_Jacobi, CUSOLVER_Randomized, CUSOLVER_DivideAndConquer
3841
export ROCSOLVER_HouseholderQR, ROCSOLVER_QRIteration, ROCSOLVER_Jacobi,
@@ -74,6 +77,7 @@ include("common/gauge.jl")
7477

7578
include("yalapack.jl")
7679
include("algorithms.jl")
80+
include("interface/projections.jl")
7781
include("interface/decompositions.jl")
7882
include("interface/truncation.jl")
7983
include("interface/qr.jl")
@@ -86,6 +90,7 @@ include("interface/schur.jl")
8690
include("interface/polar.jl")
8791
include("interface/orthnull.jl")
8892

93+
include("implementations/projections.jl")
8994
include("implementations/truncation.jl")
9095
include("implementations/qr.jl")
9196
include("implementations/lq.jl")

src/common/matrixproperties.jl

Lines changed: 91 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,10 @@ function isunitary(A; isapprox_kwargs...)
3131
return is_left_isometry(A; isapprox_kwargs...) &&
3232
is_right_isometry(A; isapprox_kwargs...)
3333
end
34+
function isunitary(A::AbstractMatrix; isapprox_kwargs...)
35+
size(A, 1) == size(A, 2) || return false
36+
return is_left_isometry(A; isapprox_kwargs...)
37+
end
3438

3539
@doc """
3640
is_left_isometry(A; isapprox_kwargs...) -> Bool
@@ -41,8 +45,11 @@ The `isapprox_kwargs` can be used to control the tolerances of the equality.
4145
See also [`isisometry`](@ref) and [`is_right_isometry`](@ref).
4246
""" is_left_isometry
4347

44-
function is_left_isometry(A::AbstractMatrix; isapprox_kwargs...)
45-
return isapprox(A' * A, LinearAlgebra.I; isapprox_kwargs...)
48+
function is_left_isometry(A::AbstractMatrix; atol::Real = 0, rtol::Real = defaulttol(A), norm = LinearAlgebra.norm)
49+
P = A' * A
50+
nP = norm(P) # isapprox would use `rtol * max(norm(P), norm(I))`
51+
diagview(P) .-= 1
52+
return norm(P) <= max(atol, rtol * nP) # assume that the norm of I is `sqrt(n)`
4653
end
4754

4855
@doc """
@@ -54,6 +61,86 @@ The `isapprox_kwargs` can be used to control the tolerances of the equality.
5461
See also [`isisometry`](@ref) and [`is_left_isometry`](@ref).
5562
""" is_right_isometry
5663

57-
function is_right_isometry(A::AbstractMatrix; isapprox_kwargs...)
58-
return isapprox(A * A', LinearAlgebra.I; isapprox_kwargs...)
64+
function is_right_isometry(A::AbstractMatrix; atol::Real = 0, rtol::Real = defaulttol(A), norm = LinearAlgebra.norm)
65+
P = A * A'
66+
nP = norm(P) # isapprox would use `rtol * max(norm(P), norm(I))`
67+
diagview(P) .-= 1
68+
return norm(P) <= max(atol, rtol * nP) # assume that the norm of I is `sqrt(n)`
69+
end
70+
71+
"""
72+
ishermitian(A; isapprox_kwargs...)
73+
74+
Test whether a linear map is Hermitian, i.e. `A = A'`.
75+
The `isapprox_kwargs` can be used to control the tolerances of the equality.
76+
"""
77+
function ishermitian(A; atol::Real = 0, rtol::Real = 0, norm = LinearAlgebra.norm, kwargs...)
78+
if iszero(atol) && iszero(rtol)
79+
return ishermitian_exact(A; kwargs...)
80+
else
81+
return 2 * norm(project_antihermitian(A; kwargs...)) max(atol, rtol * norm(A))
82+
end
83+
end
84+
function ishermitian_exact(A)
85+
return A == A'
86+
end
87+
function ishermitian_exact(A::StridedMatrix; kwargs...)
88+
return strided_ishermitian_exact(A, Val(false); kwargs...)
89+
end
90+
91+
"""
92+
isantihermitian(A; isapprox_kwargs...)
93+
94+
Test whether a linear map is anti-Hermitian, i.e. `A = -A'`.
95+
The `isapprox_kwargs` can be used to control the tolerances of the equality.
96+
"""
97+
function isantihermitian(A; atol::Real = 0, rtol::Real = 0, norm = LinearAlgebra.norm, kwargs...)
98+
if iszero(atol) && iszero(rtol)
99+
return isantihermitian_exact(A; kwargs...)
100+
else
101+
return 2 * norm(project_hermitian(A; kwargs...)) max(atol, rtol * norm(A))
102+
end
103+
end
104+
function isantihermitian_exact(A)
105+
return A == -A'
106+
end
107+
function isantihermitian_exact(A::StridedMatrix; kwargs...)
108+
return strided_ishermitian_exact(A, Val(true); kwargs...)
109+
end
110+
111+
# blocked implementation of exact checks for strided matrices
112+
# -----------------------------------------------------------
113+
function strided_ishermitian_exact(A::AbstractMatrix, anti::Val; blocksize = 32)
114+
n = size(A, 1)
115+
for j in 1:blocksize:n
116+
jb = min(blocksize, n - j + 1)
117+
_ishermitian_exact_diag(view(A, j:(j + jb - 1), j:(j + jb - 1)), anti) || return false
118+
for i in 1:blocksize:(j - 1)
119+
ib = blocksize
120+
_ishermitian_exact_offdiag(
121+
view(A, i:(i + ib - 1), j:(j + jb - 1)),
122+
view(A, j:(j + jb - 1), i:(i + ib - 1)),
123+
anti
124+
) || return false
125+
end
126+
end
127+
return true
128+
end
129+
function _ishermitian_exact_diag(A, ::Val{anti}) where {anti}
130+
n = size(A, 1)
131+
@inbounds for j in 1:n
132+
@simd for i in 1:j
133+
A[i, j] == (anti ? -adjoint(A[j, i]) : adjoint(A[j, i])) || return false
134+
end
135+
end
136+
return true
137+
end
138+
function _ishermitian_exact_offdiag(Al, Au, ::Val{anti}) where {anti}
139+
m, n = size(Al) # == reverse(size(Al))
140+
@inbounds for j in 1:n
141+
@simd for i in 1:m
142+
Al[i, j] == (anti ? -adjoint(Au[j, i]) : adjoint(Au[j, i])) || return false
143+
end
144+
end
145+
return true
59146
end

src/implementations/projections.jl

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
# Inputs
2+
# ------
3+
function copy_input(::typeof(project_hermitian), A::AbstractMatrix)
4+
return copy!(similar(A, float(eltype(A))), A)
5+
end
6+
copy_input(::typeof(project_antihermitian), A) = copy_input(project_hermitian, A)
7+
8+
function check_input(::typeof(project_hermitian!), A::AbstractMatrix, B::AbstractMatrix, ::AbstractAlgorithm)
9+
LinearAlgebra.checksquare(A)
10+
n = Base.require_one_based_indexing(A)
11+
B === A || @check_size(B, (n, n))
12+
return nothing
13+
end
14+
function check_input(::typeof(project_antihermitian!), A::AbstractMatrix, B::AbstractMatrix, ::AbstractAlgorithm)
15+
LinearAlgebra.checksquare(A)
16+
n = Base.require_one_based_indexing(A)
17+
B === A || @check_size(B, (n, n))
18+
return nothing
19+
end
20+
21+
# Outputs
22+
# -------
23+
function initialize_output(::typeof(project_hermitian!), A::AbstractMatrix, ::NativeBlocked)
24+
return A
25+
end
26+
function initialize_output(::typeof(project_antihermitian!), A::AbstractMatrix, ::NativeBlocked)
27+
return A
28+
end
29+
30+
# Implementation
31+
# --------------
32+
function project_hermitian!(A::AbstractMatrix, B, alg::NativeBlocked)
33+
check_input(project_hermitian!, A, B, alg)
34+
return project_hermitian_native!(A, B, Val(false); alg.kwargs...)
35+
end
36+
function project_antihermitian!(A::AbstractMatrix, B, alg::NativeBlocked)
37+
check_input(project_antihermitian!, A, B, alg)
38+
return project_hermitian_native!(A, B, Val(true); alg.kwargs...)
39+
end
40+
41+
function project_hermitian_native!(A::AbstractMatrix, B::AbstractMatrix, anti::Val; blocksize = 32)
42+
n = size(A, 1)
43+
for j in 1:blocksize:n
44+
for i in 1:blocksize:(j - 1)
45+
jb = min(blocksize, n - j + 1)
46+
ib = blocksize
47+
_project_hermitian_offdiag!(
48+
view(A, i:(i + ib - 1), j:(j + jb - 1)),
49+
view(A, j:(j + jb - 1), i:(i + ib - 1)),
50+
view(B, i:(i + ib - 1), j:(j + jb - 1)),
51+
view(B, j:(j + jb - 1), i:(i + ib - 1)),
52+
anti
53+
)
54+
end
55+
jb = min(blocksize, n - j + 1)
56+
_project_hermitian_diag!(
57+
view(A, j:(j + jb - 1), j:(j + jb - 1)),
58+
view(B, j:(j + jb - 1), j:(j + jb - 1)),
59+
anti
60+
)
61+
end
62+
return B
63+
end
64+
65+
function _project_hermitian_offdiag!(
66+
Au::AbstractMatrix, Al::AbstractMatrix, Bu::AbstractMatrix, Bl::AbstractMatrix, ::Val{anti}
67+
) where {anti}
68+
69+
m, n = size(Au) # == reverse(size(Au))
70+
return @inbounds for j in 1:n
71+
@simd for i in 1:m
72+
val = anti ? (Au[i, j] - adjoint(Al[j, i])) / 2 : (Au[i, j] + adjoint(Al[j, i])) / 2
73+
Bu[i, j] = val
74+
aval = adjoint(val)
75+
Bl[j, i] = anti ? -aval : aval
76+
end
77+
end
78+
return nothing
79+
end
80+
function _project_hermitian_diag!(A::AbstractMatrix, B::AbstractMatrix, ::Val{anti}) where {anti}
81+
n = size(A, 1)
82+
@inbounds for j in 1:n
83+
@simd for i in 1:(j - 1)
84+
val = anti ? (A[i, j] - adjoint(A[j, i])) / 2 : (A[i, j] + adjoint(A[j, i])) / 2
85+
B[i, j] = val
86+
aval = adjoint(val)
87+
B[j, i] = anti ? -aval : aval
88+
end
89+
B[j, j] = anti ? _imimag(A[j, j]) : real(A[j, j])
90+
end
91+
return nothing
92+
end
93+
94+
_imimag(x::Real) = zero(x)
95+
_imimag(x::Complex) = im * imag(x)

src/interface/projections.jl

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
@doc """
2+
project_hermitian(A; kwargs...)
3+
project_hermitian(A, alg)
4+
project_hermitian!(A; kwargs...)
5+
project_hermitian!(A, alg)
6+
7+
Compute the hermitian part of a (square) matrix `A`, defined as `(A + A') / 2`.
8+
For real matrices this corresponds to the symmetric part of `A`.
9+
10+
See also [`project_antihermitian`](@ref).
11+
"""
12+
@functiondef project_hermitian
13+
14+
@doc """
15+
project_antihermitian(A; kwargs...)
16+
project_antihermitian(A, alg)
17+
project_antihermitian!(A; kwargs...)
18+
project_antihermitian!(A, alg)
19+
20+
Compute the anti-hermitian part of a (square) matrix `A`, defined as `(A - A') / 2`.
21+
For real matrices this corresponds to the antisymmetric part of `A`.
22+
23+
See also [`project_hermitian`](@ref).
24+
"""
25+
@functiondef project_antihermitian
26+
27+
"""
28+
NativeBlocked(; blocksize = 32)
29+
30+
Algorithm type to denote a native blocked algorithm with given `blocksize` for computing
31+
the hermitian or anti-hermitian part of a matrix.
32+
"""
33+
@algdef NativeBlocked
34+
# TODO: multithreaded? numthreads keyword?
35+
36+
default_hermitian_algorithm(A; kwargs...) = default_hermitian_algorithm(typeof(A); kwargs...)
37+
function default_hermitian_algorithm(::Type{A}; kwargs...) where {A <: AbstractMatrix}
38+
return NativeBlocked(; kwargs...)
39+
end
40+
41+
for f in (:project_hermitian!, :project_antihermitian!)
42+
@eval function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A}
43+
return default_hermitian_algorithm(A; kwargs...)
44+
end
45+
end

src/pullbacks/eigh.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ function eigh_pullback!(
4242
indV = axes(V, 2)[ind]
4343
length(indV) == pV || throw(DimensionMismatch())
4444
mul!(view(VᴴΔV, :, indV), V', ΔV)
45-
aVᴴΔV = rmul!(VᴴΔV - VᴴΔV', 1 / 2)
45+
aVᴴΔV = project_antihermitian(VᴴΔV) # can't use in-place or recycling doesn't work
4646

4747
mask = abs.(D' .- D) .< degeneracy_atol
4848
Δgauge = norm(view(aVᴴΔV, mask))
@@ -58,7 +58,7 @@ function eigh_pullback!(
5858
length(indD) == pD || throw(DimensionMismatch())
5959
view(diagview(aVᴴΔV), indD) .+= real.(ΔDvec)
6060
end
61-
# recylce VdΔV space
61+
# recycle VdΔV space
6262
ΔA = mul!(ΔA, mul!(VᴴΔV, V, aVᴴΔV), V', 1, 1)
6363
elseif !iszerotangent(ΔDmat)
6464
ΔDvec = diagview(ΔDmat)
@@ -112,7 +112,7 @@ function eigh_trunc_pullback!(
112112
if !iszerotangent(ΔV)
113113
(n, p) == size(ΔV) || throw(DimensionMismatch())
114114
VᴴΔV = V' * ΔV
115-
aVᴴΔV = rmul!(VᴴΔV - VᴴΔV', 1 / 2)
115+
aVᴴΔV = project_antihermitian!(VᴴΔV)
116116

117117
mask = abs.(D' .- D) .< degeneracy_atol
118118
Δgauge = norm(view(aVᴴΔV, mask))

src/pullbacks/lq.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,8 +118,8 @@ function lq_null_pullback!(
118118
gauge_atol::Real = tol
119119
)
120120
if !iszerotangent(ΔNᴴ) && size(Nᴴ, 1) > 0
121-
NᴴΔN = Nᴴ * ΔNᴴ'
122-
Δgauge = norm((NᴴΔN .- NᴴΔN') ./ 2)
121+
aNᴴΔN = project_antihermitian!(Nᴴ * ΔNᴴ')
122+
Δgauge = norm(aNᴴΔN)
123123
Δgauge < tol ||
124124
@warn "`lq_null` cotangent sensitive to gauge choice: (|Δgauge| = $Δgauge)"
125125
L, Q = lq_compact(A; positive = true) # should we be able to provide algorithm here?

src/pullbacks/polar.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ function left_polar_pullback!(ΔA::AbstractMatrix, A, WP, ΔWP)
1111
# Extract and check the cotangents
1212
ΔW, ΔP = ΔWP
1313
if !iszerotangent(ΔP)
14-
ΔP = (ΔP + ΔP') / 2
14+
ΔP = project_hermitian(ΔP)
1515
end
1616
M = zero(P)
1717
!iszerotangent(ΔW) && mul!(M, W', ΔW, 1, 1)
@@ -41,7 +41,7 @@ function right_polar_pullback!(ΔA::AbstractMatrix, A, PWᴴ, ΔPWᴴ)
4141
# Extract and check the cotangents
4242
ΔP, ΔWᴴ = ΔPWᴴ
4343
if !iszerotangent(ΔP)
44-
ΔP = (ΔP + ΔP') / 2
44+
ΔP = project_hermitian(ΔP)
4545
end
4646
M = zero(P)
4747
!iszerotangent(ΔWᴴ) && mul!(M, ΔWᴴ, Wᴴ', 1, 1)

src/pullbacks/qr.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,8 +117,8 @@ function qr_null_pullback!(
117117
gauge_atol::Real = tol
118118
)
119119
if !iszerotangent(ΔN) && size(N, 2) > 0
120-
NᴴΔN = N' * ΔN
121-
Δgauge = norm((NᴴΔN .- NᴴΔN') ./ 2)
120+
aNᴴΔN = project_antihermitian!(N' * ΔN)
121+
Δgauge = norm(aNᴴΔN)
122122
Δgauge < tol ||
123123
@warn "`qr_null` cotangent sensitive to gauge choice: (|Δgauge| = $Δgauge)"
124124

src/pullbacks/svd.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,8 @@ function svd_pullback!(
6969
end
7070

7171
# Project onto antihermitian part; hermitian part outside of Grassmann tangent space
72-
aUΔU = rmul!(UΔU - UΔU', 1 / 2)
73-
aVΔV = rmul!(VΔV - VΔV', 1 / 2)
72+
aUΔU = project_antihermitian!(UΔU)
73+
aVΔV = project_antihermitian!(VΔV)
7474

7575
# check whether cotangents arise from gauge-invariance objective function
7676
mask = abs.(Sr' .- Sr) .< degeneracy_atol
@@ -159,8 +159,8 @@ function svd_trunc_pullback!(
159159
end
160160

161161
# Project onto antihermitian part; hermitian part outside of Grassmann tangent space
162-
aUΔU = rmul!(UΔU - UΔU', 1 / 2)
163-
aVΔV = rmul!(VΔV - VΔV', 1 / 2)
162+
aUΔU = project_antihermitian!(UΔU)
163+
aVΔV = project_antihermitian!(VΔV)
164164

165165
# check whether cotangents arise from gauge-invariance objective function
166166
mask = abs.(S' .- S) .< degeneracy_atol

0 commit comments

Comments
 (0)