|
1 | | -function check_eig_cotangents( |
2 | | - D, VᴴΔV; |
3 | | - degeneracy_atol::Real = default_pullback_rank_atol(D), |
4 | | - gauge_atol::Real = default_pullback_gauge_atol(VᴴΔV) |
| 1 | +function check_and_prepare_eig_cotangents( |
| 2 | + D, V, ViG, ΔDmat, ΔV, ind = Colon(); |
| 3 | + degeneracy_atol::Real = default_pullback_rank_atol(S), |
| 4 | + gauge_atol::Real = default_pullback_gauge_atol(ΔDmat, ΔV) |
5 | 5 | ) |
6 | | - mask = abs.(transpose(D) .- D) .< degeneracy_atol |
7 | | - Δgauge = norm(view(VᴴΔV, mask)) |
| 6 | + |
| 7 | + n, p = size(V) |
| 8 | + indD = axes(D, 1)[ind] |
| 9 | + indV = axes(V, 2)[ind] |
| 10 | + if !iszerotangent(ΔV) |
| 11 | + n == size(ΔV, 1) || throw(DimensionMismatch()) |
| 12 | + length(indV) == size(ΔV, 2) || throw(DimensionMismatch()) |
| 13 | + ΔV₁ = zero(V) |
| 14 | + ΔV₁[:, indV] = ΔV |
| 15 | + VᴴΔV₁ = V' * ΔV₁ |
| 16 | + if p == n |
| 17 | + ΔV₊ = zero!(ΔV₁) |
| 18 | + else |
| 19 | + ΔV₊ = mul!(ΔV₁, ViG, VᴴΔV₁, -1, 1) |
| 20 | + end |
| 21 | + else |
| 22 | + ΔV₊ = nothing |
| 23 | + VᴴΔV₁ = zero!(similar(V, (p, p))) |
| 24 | + end |
| 25 | + bc = Base.broadcasted(transpose(D), D, VᴴΔV₁) do d₁, d₂, v |
| 26 | + return abs(d₁ - d₂) < degeneracy_atol ? v : zero(v) |
| 27 | + end |
| 28 | + Δgauge = norm(bc, Inf) |
| 29 | + |
8 | 30 | Δgauge ≤ gauge_atol || |
9 | 31 | @warn "`eig` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" |
10 | | - return |
| 32 | + |
| 33 | + VᴴΔV₁ .*= conj.(inv_safe.(transpose(D) .- D, degeneracy_atol)) |
| 34 | + VᴴAΔV = VᴴΔV₁ |
| 35 | + |
| 36 | + if !iszerotangent(ΔDmat) |
| 37 | + ΔD = diagview(ΔDmat) |
| 38 | + length(indD) == length(ΔD) || throw(DimensionMismatch()) |
| 39 | + view(diagview(VᴴAΔV), indD) .+= ΔD |
| 40 | + else |
| 41 | + ΔD = nothing |
| 42 | + end |
| 43 | + |
| 44 | + return VᴴAΔV, ΔV₊ |
11 | 45 | end |
12 | 46 |
|
13 | 47 | """ |
@@ -39,51 +73,24 @@ function eig_pullback!( |
39 | 73 |
|
40 | 74 | # Basic size checks and determination |
41 | 75 | Dmat, V = DV |
42 | | - D = diagview(Dmat) |
43 | | - ΔDmat, ΔV = ΔDV |
44 | 76 | n = LinearAlgebra.checksquare(V) |
| 77 | + D = diagview(Dmat) |
45 | 78 | n == length(D) || throw(DimensionMismatch()) |
46 | 79 | (n, n) == size(ΔA) || throw(DimensionMismatch()) |
| 80 | + ViG = inv(V)' |
47 | 81 |
|
48 | | - if !iszerotangent(ΔV) |
49 | | - n == size(ΔV, 1) || throw(DimensionMismatch()) |
50 | | - pV = size(ΔV, 2) |
51 | | - VᴴΔV = fill!(similar(V), 0) |
52 | | - indV = axes(V, 2)[ind] |
53 | | - length(indV) == pV || throw(DimensionMismatch()) |
54 | | - mul!(view(VᴴΔV, :, indV), V', ΔV) |
55 | | - |
56 | | - check_eig_cotangents(D, VᴴΔV; degeneracy_atol, gauge_atol) |
57 | | - |
58 | | - VᴴΔV .*= conj.(inv_safe.(transpose(D) .- D, degeneracy_atol)) |
59 | | - |
60 | | - if !iszerotangent(ΔDmat) |
61 | | - ΔDvec = diagview(ΔDmat) |
62 | | - pD = length(ΔDvec) |
63 | | - indD = axes(D, 1)[ind] |
64 | | - length(indD) == pD || throw(DimensionMismatch()) |
65 | | - view(diagview(VᴴΔV), indD) .+= ΔDvec |
66 | | - end |
67 | | - PΔV = V' \ VᴴΔV |
68 | | - if eltype(ΔA) <: Real |
69 | | - ΔAc = mul!(VᴴΔV, PΔV, V') # recycle VdΔV memory |
70 | | - ΔA .+= real.(ΔAc) |
71 | | - else |
72 | | - ΔA = mul!(ΔA, PΔV, V', 1, 1) |
73 | | - end |
74 | | - elseif !iszerotangent(ΔDmat) |
75 | | - ΔDvec = diagview(ΔDmat) |
76 | | - pD = length(ΔDvec) |
77 | | - indD = axes(D, 1)[ind] |
78 | | - length(indD) == pD || throw(DimensionMismatch()) |
79 | | - Vp = view(V, :, indD) |
80 | | - PΔV = Vp' \ Diagonal(ΔDvec) |
81 | | - if eltype(ΔA) <: Real |
82 | | - ΔAc = PΔV * Vp' |
83 | | - ΔA .+= real.(ΔAc) |
84 | | - else |
85 | | - ΔA = mul!(ΔA, PΔV, V', 1, 1) |
86 | | - end |
| 82 | + ΔDmat, ΔV = ΔDV |
| 83 | + VᴴΔAV, = check_and_prepare_eig_cotangents( |
| 84 | + D, V, ViG, ΔDmat, ΔV, ind; degeneracy_atol, gauge_atol |
| 85 | + ) |
| 86 | + |
| 87 | + if eltype(ΔA) <: Real |
| 88 | + Z = ViG * VᴴΔAV |
| 89 | + ΔAc = mul!(VᴴΔAV, Z, V') # recycle VᴴΔAV |
| 90 | + ΔA .+= real.(ΔAc) |
| 91 | + else |
| 92 | + Z = ViG * VᴴΔAV |
| 93 | + ΔA = mul!(ΔA, Z, V', 1, 1) |
87 | 94 | end |
88 | 95 | return ΔA |
89 | 96 | end |
@@ -123,44 +130,56 @@ not small compared to `gauge_atol`. |
123 | 130 | function eig_trunc_pullback!( |
124 | 131 | ΔA::AbstractMatrix, A, DV, ΔDV; |
125 | 132 | degeneracy_atol::Real = default_pullback_rank_atol(DV[1]), |
126 | | - gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2]) |
| 133 | + gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2]), |
| 134 | + maxiter::Int = 100 # TODO: better default, depending on expected number of steps using quadratic convergence? |
127 | 135 | ) |
128 | 136 |
|
129 | 137 | # Basic size checks and determination |
130 | 138 | Dmat, V = DV |
131 | | - D = diagview(Dmat) |
132 | | - ΔDmat, ΔV = ΔDV |
133 | 139 | (n, p) = size(V) |
134 | | - p == length(D) || throw(DimensionMismatch()) |
135 | 140 | (n, n) == size(ΔA) || throw(DimensionMismatch()) |
| 141 | + D = diagview(Dmat) |
| 142 | + p == length(D) || throw(DimensionMismatch()) |
136 | 143 | G = V' * V |
| 144 | + ViG = V / LinearAlgebra.cholesky!(G) |
137 | 145 |
|
138 | | - if !iszerotangent(ΔV) |
139 | | - (n, p) == size(ΔV) || throw(DimensionMismatch()) |
140 | | - VᴴΔV = V' * ΔV |
141 | | - check_eig_cotangents(D, VᴴΔV; degeneracy_atol, gauge_atol) |
142 | | - |
143 | | - ΔVperp = ΔV - V * inv(G) * VᴴΔV |
144 | | - VᴴΔV .*= conj.(inv_safe.(transpose(D) .- D, degeneracy_atol)) |
145 | | - else |
146 | | - VᴴΔV = zero(G) |
147 | | - end |
148 | | - |
149 | | - if !iszerotangent(ΔDmat) |
150 | | - ΔDvec = diagview(ΔDmat) |
151 | | - p == length(ΔDvec) || throw(DimensionMismatch()) |
152 | | - diagview(VᴴΔV) .+= ΔDvec |
153 | | - end |
154 | | - Z = V' \ VᴴΔV |
| 146 | + ΔDmat, ΔV = ΔDV |
| 147 | + VᴴΔAV, ΔV₊ = check_and_prepare_eig_cotangents( |
| 148 | + D, V, ViG, ΔDmat, ΔV; degeneracy_atol, gauge_atol |
| 149 | + ) |
| 150 | + Z = ViG * VᴴΔAV |
155 | 151 |
|
156 | 152 | # add contribution from orthogonal complement |
157 | | - PA = A - (A * V) / V |
158 | | - Y = mul!(ΔVperp, PA', Z, 1, 1) |
159 | | - X = _sylvester(PA', -Dmat', Y) |
160 | | - Z .+= X |
161 | | - |
| 153 | + AP = mul!(complex.(A), V * Dmat, ViG', -1, 1) |
| 154 | + X₀ = iszerotangent(ΔV₊) ? AP' * Z : mul!(ΔV₊, AP', Z, 1, 1) |
| 155 | + X₀ ./= D' |
| 156 | + dabsmax = maximum(abs, D) |
| 157 | + AP ./= dabsmax |
| 158 | + D̄⁻¹ = dabsmax ./ conj.(D) |
| 159 | + X₁ = rmul!(AP' * X₀, Diagonal(D̄⁻¹)) |
| 160 | + X₁ .+= X₀ |
| 161 | + Xₖ, Xₖ₊₁ = X₁, X₀ |
| 162 | + APₖ, APₖ₊₁ = AP * AP, AP |
| 163 | + D̄⁻¹ₖ, D̄⁻¹ₖ₊₁ = D̄⁻¹ .^ 2, D̄⁻¹ |
| 164 | + for k in 1:maxiter |
| 165 | + Xₖ₊₁ = rmul!(mul!(Xₖ₊₁, APₖ', Xₖ), Diagonal(D̄⁻¹ₖ)) |
| 166 | + if norm(Xₖ₊₁, Inf) < degeneracy_atol |
| 167 | + break |
| 168 | + end |
| 169 | + Xₖ₊₁ .+= Xₖ |
| 170 | + if k == maxiter |
| 171 | + @warn "Sylvester iteration did not converge after $k iterations, final norm of X: $(norm(Xₖ₊₁, Inf)))" |
| 172 | + break |
| 173 | + end |
| 174 | + D̄⁻¹ₖ₊₁ .= D̄⁻¹ₖ .^ 2 |
| 175 | + APₖ₊₁ = mul!(APₖ₊₁, APₖ, APₖ) |
| 176 | + Xₖ, Xₖ₊₁ = Xₖ₊₁, Xₖ |
| 177 | + APₖ, APₖ₊₁ = APₖ₊₁, APₖ |
| 178 | + D̄⁻¹ₖ, D̄⁻¹ₖ₊₁ = D̄⁻¹ₖ₊₁, D̄⁻¹ₖ |
| 179 | + end |
| 180 | + Z .+= Xₖ |
162 | 181 | if eltype(ΔA) <: Real |
163 | | - ΔAc = Z * V' |
| 182 | + ΔAc = mul!(AP, Z, V') # recycle AP |
164 | 183 | ΔA .+= real.(ΔAc) |
165 | 184 | else |
166 | 185 | ΔA = mul!(ΔA, Z, V', 1, 1) |
@@ -211,15 +230,13 @@ across eigenvectors associated with degenerate eigenvalues), so the correspondin |
211 | 230 | `ΔV` are projected out. |
212 | 231 | """ |
213 | 232 | function remove_eig_gauge_dependence!( |
214 | | - ΔV, D, V, ind = axes(ΔV, 2); |
| 233 | + ΔV, D, V; |
215 | 234 | degeneracy_atol = MatrixAlgebraKit.default_pullback_gauge_atol(D) |
216 | 235 | ) |
217 | | - length(ind) == size(ΔV, 2) || throw(DimensionMismatch()) |
218 | | - indV = axes(V, 2)[ind] |
219 | | - Vp = view(V, :, indV) |
220 | | - Ddiag = view(diagview(D), indV) |
221 | | - gaugepart = Vp' * ΔV |
| 236 | + Ddiag = diagview(D) |
| 237 | + gaugepart = V' * ΔV |
222 | 238 | gaugepart[abs.(transpose(Ddiag) .- Ddiag) .>= degeneracy_atol] .= 0 |
223 | | - mul!(ΔV, Vp / (Vp' * Vp), gaugepart, -1, 1) |
| 239 | + ViG = V / LinearAlgebra.cholesky!(V' * V) |
| 240 | + mul!(ΔV, ViG, gaugepart, -1, 1) |
224 | 241 | return ΔV |
225 | 242 | end |
0 commit comments