Skip to content

Commit 66f7e53

Browse files
lkdvosJutho
andauthored
Improve consistency of gauge fixing in eigenvalue pullback functions (#221)
* make `ind` functions consistent * first set of fixes and consistencies * more updates to pullbacks of eig and eigh * Apply suggestions from code review Co-authored-by: Lukas Devos <ldevos98@gmail.com> Co-authored-by: Jutho <Jutho@users.noreply.github.com> * bypass issue with `zero!` and `view` --------- Co-authored-by: Jutho Haegeman <jutho.haegeman@ugent.be> Co-authored-by: Jutho <Jutho@users.noreply.github.com>
1 parent 3b65398 commit 66f7e53

4 files changed

Lines changed: 216 additions & 178 deletions

File tree

src/pullbacks/eig.jl

100644100755
Lines changed: 99 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,47 @@
1-
function check_eig_cotangents(
2-
D, VᴴΔV;
3-
degeneracy_atol::Real = default_pullback_rank_atol(D),
4-
gauge_atol::Real = default_pullback_gauge_atol(VᴴΔV)
1+
function check_and_prepare_eig_cotangents(
2+
D, V, ViG, ΔDmat, ΔV, ind = Colon();
3+
degeneracy_atol::Real = default_pullback_rank_atol(S),
4+
gauge_atol::Real = default_pullback_gauge_atol(ΔDmat, ΔV)
55
)
6-
mask = abs.(transpose(D) .- D) .< degeneracy_atol
7-
Δgauge = norm(view(VᴴΔV, mask))
6+
7+
n, p = size(V)
8+
indD = axes(D, 1)[ind]
9+
indV = axes(V, 2)[ind]
10+
if !iszerotangent(ΔV)
11+
n == size(ΔV, 1) || throw(DimensionMismatch())
12+
length(indV) == size(ΔV, 2) || throw(DimensionMismatch())
13+
ΔV₁ = zero(V)
14+
ΔV₁[:, indV] = ΔV
15+
VᴴΔV₁ = V' * ΔV₁
16+
if p == n
17+
ΔV₊ = zero!(ΔV₁)
18+
else
19+
ΔV₊ = mul!(ΔV₁, ViG, VᴴΔV₁, -1, 1)
20+
end
21+
else
22+
ΔV₊ = nothing
23+
VᴴΔV₁ = zero!(similar(V, (p, p)))
24+
end
25+
bc = Base.broadcasted(transpose(D), D, VᴴΔV₁) do d₁, d₂, v
26+
return abs(d₁ - d₂) < degeneracy_atol ? v : zero(v)
27+
end
28+
Δgauge = norm(bc, Inf)
29+
830
Δgauge gauge_atol ||
931
@warn "`eig` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
10-
return
32+
33+
VᴴΔV₁ .*= conj.(inv_safe.(transpose(D) .- D, degeneracy_atol))
34+
VᴴAΔV = VᴴΔV₁
35+
36+
if !iszerotangent(ΔDmat)
37+
ΔD = diagview(ΔDmat)
38+
length(indD) == length(ΔD) || throw(DimensionMismatch())
39+
view(diagview(VᴴAΔV), indD) .+= ΔD
40+
else
41+
ΔD = nothing
42+
end
43+
44+
return VᴴAΔV, ΔV₊
1145
end
1246

1347
"""
@@ -39,51 +73,24 @@ function eig_pullback!(
3973

4074
# Basic size checks and determination
4175
Dmat, V = DV
42-
D = diagview(Dmat)
43-
ΔDmat, ΔV = ΔDV
4476
n = LinearAlgebra.checksquare(V)
77+
D = diagview(Dmat)
4578
n == length(D) || throw(DimensionMismatch())
4679
(n, n) == size(ΔA) || throw(DimensionMismatch())
80+
ViG = inv(V)'
4781

48-
if !iszerotangent(ΔV)
49-
n == size(ΔV, 1) || throw(DimensionMismatch())
50-
pV = size(ΔV, 2)
51-
VᴴΔV = fill!(similar(V), 0)
52-
indV = axes(V, 2)[ind]
53-
length(indV) == pV || throw(DimensionMismatch())
54-
mul!(view(VᴴΔV, :, indV), V', ΔV)
55-
56-
check_eig_cotangents(D, VᴴΔV; degeneracy_atol, gauge_atol)
57-
58-
VᴴΔV .*= conj.(inv_safe.(transpose(D) .- D, degeneracy_atol))
59-
60-
if !iszerotangent(ΔDmat)
61-
ΔDvec = diagview(ΔDmat)
62-
pD = length(ΔDvec)
63-
indD = axes(D, 1)[ind]
64-
length(indD) == pD || throw(DimensionMismatch())
65-
view(diagview(VᴴΔV), indD) .+= ΔDvec
66-
end
67-
PΔV = V' \ VᴴΔV
68-
if eltype(ΔA) <: Real
69-
ΔAc = mul!(VᴴΔV, PΔV, V') # recycle VdΔV memory
70-
ΔA .+= real.(ΔAc)
71-
else
72-
ΔA = mul!(ΔA, PΔV, V', 1, 1)
73-
end
74-
elseif !iszerotangent(ΔDmat)
75-
ΔDvec = diagview(ΔDmat)
76-
pD = length(ΔDvec)
77-
indD = axes(D, 1)[ind]
78-
length(indD) == pD || throw(DimensionMismatch())
79-
Vp = view(V, :, indD)
80-
PΔV = Vp' \ Diagonal(ΔDvec)
81-
if eltype(ΔA) <: Real
82-
ΔAc = PΔV * Vp'
83-
ΔA .+= real.(ΔAc)
84-
else
85-
ΔA = mul!(ΔA, PΔV, V', 1, 1)
86-
end
82+
ΔDmat, ΔV = ΔDV
83+
VᴴΔAV, = check_and_prepare_eig_cotangents(
84+
D, V, ViG, ΔDmat, ΔV, ind; degeneracy_atol, gauge_atol
85+
)
86+
87+
if eltype(ΔA) <: Real
88+
Z = ViG * VᴴΔAV
89+
ΔAc = mul!(VᴴΔAV, Z, V') # recycle VᴴΔAV
90+
ΔA .+= real.(ΔAc)
91+
else
92+
Z = ViG * VᴴΔAV
93+
ΔA = mul!(ΔA, Z, V', 1, 1)
8794
end
8895
return ΔA
8996
end
@@ -123,44 +130,56 @@ not small compared to `gauge_atol`.
123130
function eig_trunc_pullback!(
124131
ΔA::AbstractMatrix, A, DV, ΔDV;
125132
degeneracy_atol::Real = default_pullback_rank_atol(DV[1]),
126-
gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2])
133+
gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2]),
134+
maxiter::Int = 100 # TODO: better default, depending on expected number of steps using quadratic convergence?
127135
)
128136

129137
# Basic size checks and determination
130138
Dmat, V = DV
131-
D = diagview(Dmat)
132-
ΔDmat, ΔV = ΔDV
133139
(n, p) = size(V)
134-
p == length(D) || throw(DimensionMismatch())
135140
(n, n) == size(ΔA) || throw(DimensionMismatch())
141+
D = diagview(Dmat)
142+
p == length(D) || throw(DimensionMismatch())
136143
G = V' * V
144+
ViG = V / LinearAlgebra.cholesky!(G)
137145

138-
if !iszerotangent(ΔV)
139-
(n, p) == size(ΔV) || throw(DimensionMismatch())
140-
VᴴΔV = V' * ΔV
141-
check_eig_cotangents(D, VᴴΔV; degeneracy_atol, gauge_atol)
142-
143-
ΔVperp = ΔV - V * inv(G) * VᴴΔV
144-
VᴴΔV .*= conj.(inv_safe.(transpose(D) .- D, degeneracy_atol))
145-
else
146-
VᴴΔV = zero(G)
147-
end
148-
149-
if !iszerotangent(ΔDmat)
150-
ΔDvec = diagview(ΔDmat)
151-
p == length(ΔDvec) || throw(DimensionMismatch())
152-
diagview(VᴴΔV) .+= ΔDvec
153-
end
154-
Z = V' \ VᴴΔV
146+
ΔDmat, ΔV = ΔDV
147+
VᴴΔAV, ΔV₊ = check_and_prepare_eig_cotangents(
148+
D, V, ViG, ΔDmat, ΔV; degeneracy_atol, gauge_atol
149+
)
150+
Z = ViG * VᴴΔAV
155151

156152
# add contribution from orthogonal complement
157-
PA = A - (A * V) / V
158-
Y = mul!(ΔVperp, PA', Z, 1, 1)
159-
X = _sylvester(PA', -Dmat', Y)
160-
Z .+= X
161-
153+
AP = mul!(complex.(A), V * Dmat, ViG', -1, 1)
154+
X₀ = iszerotangent(ΔV₊) ? AP' * Z : mul!(ΔV₊, AP', Z, 1, 1)
155+
X₀ ./= D'
156+
dabsmax = maximum(abs, D)
157+
AP ./= dabsmax
158+
D̄⁻¹ = dabsmax ./ conj.(D)
159+
X₁ = rmul!(AP' * X₀, Diagonal(D̄⁻¹))
160+
X₁ .+= X₀
161+
Xₖ, Xₖ₊₁ = X₁, X₀
162+
APₖ, APₖ₊₁ = AP * AP, AP
163+
D̄⁻¹ₖ, D̄⁻¹ₖ₊₁ = D̄⁻¹ .^ 2, D̄⁻¹
164+
for k in 1:maxiter
165+
Xₖ₊₁ = rmul!(mul!(Xₖ₊₁, APₖ', Xₖ), Diagonal(D̄⁻¹ₖ))
166+
if norm(Xₖ₊₁, Inf) < degeneracy_atol
167+
break
168+
end
169+
Xₖ₊₁ .+= Xₖ
170+
if k == maxiter
171+
@warn "Sylvester iteration did not converge after $k iterations, final norm of X: $(norm(Xₖ₊₁, Inf)))"
172+
break
173+
end
174+
D̄⁻¹ₖ₊₁ .= D̄⁻¹ₖ .^ 2
175+
APₖ₊₁ = mul!(APₖ₊₁, APₖ, APₖ)
176+
Xₖ, Xₖ₊₁ = Xₖ₊₁, Xₖ
177+
APₖ, APₖ₊₁ = APₖ₊₁, APₖ
178+
D̄⁻¹ₖ, D̄⁻¹ₖ₊₁ = D̄⁻¹ₖ₊₁, D̄⁻¹ₖ
179+
end
180+
Z .+= Xₖ
162181
if eltype(ΔA) <: Real
163-
ΔAc = Z * V'
182+
ΔAc = mul!(AP, Z, V') # recycle AP
164183
ΔA .+= real.(ΔAc)
165184
else
166185
ΔA = mul!(ΔA, Z, V', 1, 1)
@@ -211,15 +230,13 @@ across eigenvectors associated with degenerate eigenvalues), so the correspondin
211230
`ΔV` are projected out.
212231
"""
213232
function remove_eig_gauge_dependence!(
214-
ΔV, D, V, ind = axes(ΔV, 2);
233+
ΔV, D, V;
215234
degeneracy_atol = MatrixAlgebraKit.default_pullback_gauge_atol(D)
216235
)
217-
length(ind) == size(ΔV, 2) || throw(DimensionMismatch())
218-
indV = axes(V, 2)[ind]
219-
Vp = view(V, :, indV)
220-
Ddiag = view(diagview(D), indV)
221-
gaugepart = Vp' * ΔV
236+
Ddiag = diagview(D)
237+
gaugepart = V' * ΔV
222238
gaugepart[abs.(transpose(Ddiag) .- Ddiag) .>= degeneracy_atol] .= 0
223-
mul!(ΔV, Vp / (Vp' * Vp), gaugepart, -1, 1)
239+
ViG = V / LinearAlgebra.cholesky!(V' * V)
240+
mul!(ΔV, ViG, gaugepart, -1, 1)
224241
return ΔV
225242
end

0 commit comments

Comments
 (0)