diff --git a/src/pullbacks/eig.jl b/src/pullbacks/eig.jl old mode 100644 new mode 100755 index 7b78121b..bf410169 --- a/src/pullbacks/eig.jl +++ b/src/pullbacks/eig.jl @@ -1,13 +1,53 @@ -function check_eig_cotangents( - D, VᴴΔV; - degeneracy_atol::Real = default_pullback_rank_atol(D), - gauge_atol::Real = default_pullback_gauge_atol(VᴴΔV) +function check_and_prepare_eig_cotangents( + D, V, ViG, ΔDmat, ΔV, ind = Colon(); + degeneracy_atol::Real = default_pullback_rank_atol(S), + gauge_atol::Real = default_pullback_gauge_atol(ΔDmat, ΔV) ) - mask = abs.(transpose(D) .- D) .< degeneracy_atol - Δgauge = norm(view(VᴴΔV, mask)) + + n, p = size(V) + indD = axes(D, 1)[ind] + indV = axes(V, 2)[ind] + if !iszerotangent(ΔV) + n == size(ΔV, 1) || throw(DimensionMismatch()) + length(indV) == size(ΔV, 2) || throw(DimensionMismatch()) + if indV == 1:p + ΔV₁ = copy(ΔV) + else + ΔV₁ = zero(V) + for (j, i) in enumerate(indV) + ΔV₁[:, i] .= view(ΔV, :, j) + end + end + VᴴΔV₁ = V' * ΔV₁ + if p == n + ΔV₊ = zero!(ΔV₁) + else + ΔV₊ = mul!(ΔV₁, ViG, VᴴΔV₁, -1, 1) + end + else + ΔV₊ = nothing + VᴴΔV₁ = zero!(similar(V, (p, p))) + end + bc = Base.broadcasted(transpose(D), D, VᴴΔV₁) do d₁, d₂, v + return abs(d₁ - d₂) < degeneracy_atol ? v : zero(v) + end + Δgauge = norm(bc, Inf) + Δgauge ≤ gauge_atol || @warn "`eig` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" - return + + VᴴΔV₁ .*= conj.(inv_safe.(transpose(D) .- D, degeneracy_atol)) + VᴴAΔV = VᴴΔV₁ + + if !iszerotangent(ΔDmat) + ΔD = diagview(ΔDmat) + length(indD) == length(ΔD) || throw(DimensionMismatch()) + view(diagview(VᴴAΔV), indD) .+= ΔD + else + ΔD = nothing + end + + return VᴴAΔV, ΔV₊ end """ @@ -39,51 +79,24 @@ function eig_pullback!( # Basic size checks and determination Dmat, V = DV - D = diagview(Dmat) - ΔDmat, ΔV = ΔDV n = LinearAlgebra.checksquare(V) + D = diagview(Dmat) n == length(D) || throw(DimensionMismatch()) (n, n) == size(ΔA) || throw(DimensionMismatch()) + ViG = inv(V)' - if !iszerotangent(ΔV) - n == size(ΔV, 1) || throw(DimensionMismatch()) - pV = size(ΔV, 2) - VᴴΔV = fill!(similar(V), 0) - indV = axes(V, 2)[ind] - length(indV) == pV || throw(DimensionMismatch()) - mul!(view(VᴴΔV, :, indV), V', ΔV) - - check_eig_cotangents(D, VᴴΔV; degeneracy_atol, gauge_atol) - - VᴴΔV .*= conj.(inv_safe.(transpose(D) .- D, degeneracy_atol)) - - if !iszerotangent(ΔDmat) - ΔDvec = diagview(ΔDmat) - pD = length(ΔDvec) - indD = axes(D, 1)[ind] - length(indD) == pD || throw(DimensionMismatch()) - view(diagview(VᴴΔV), indD) .+= ΔDvec - end - PΔV = V' \ VᴴΔV - if eltype(ΔA) <: Real - ΔAc = mul!(VᴴΔV, PΔV, V') # recycle VdΔV memory - ΔA .+= real.(ΔAc) - else - ΔA = mul!(ΔA, PΔV, V', 1, 1) - end - elseif !iszerotangent(ΔDmat) - ΔDvec = diagview(ΔDmat) - pD = length(ΔDvec) - indD = axes(D, 1)[ind] - length(indD) == pD || throw(DimensionMismatch()) - Vp = view(V, :, indD) - PΔV = Vp' \ Diagonal(ΔDvec) - if eltype(ΔA) <: Real - ΔAc = PΔV * Vp' - ΔA .+= real.(ΔAc) - else - ΔA = mul!(ΔA, PΔV, V', 1, 1) - end + ΔDmat, ΔV = ΔDV + VᴴΔAV, = check_and_prepare_eig_cotangents( + D, V, ViG, ΔDmat, ΔV, ind; degeneracy_atol, gauge_atol + ) + + if eltype(ΔA) <: Real + Z = ViG * VᴴΔAV + ΔAc = mul!(VᴴΔAV, Z, V') # recycle VᴴΔAV + ΔA .+= real.(ΔAc) + else + Z = ViG * VᴴΔAV + ΔA = mul!(ΔA, Z, V', 1, 1) end return ΔA end @@ -123,44 +136,56 @@ not small compared to `gauge_atol`. function eig_trunc_pullback!( ΔA::AbstractMatrix, A, DV, ΔDV; degeneracy_atol::Real = default_pullback_rank_atol(DV[1]), - gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2]) + gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2]), + maxiter::Int = 100 # TODO: better default, depending on expected number of steps using quadratic convergence? ) # Basic size checks and determination Dmat, V = DV - D = diagview(Dmat) - ΔDmat, ΔV = ΔDV (n, p) = size(V) - p == length(D) || throw(DimensionMismatch()) (n, n) == size(ΔA) || throw(DimensionMismatch()) + D = diagview(Dmat) + p == length(D) || throw(DimensionMismatch()) G = V' * V + ViG = V / LinearAlgebra.cholesky!(G) - if !iszerotangent(ΔV) - (n, p) == size(ΔV) || throw(DimensionMismatch()) - VᴴΔV = V' * ΔV - check_eig_cotangents(D, VᴴΔV; degeneracy_atol, gauge_atol) - - ΔVperp = ΔV - V * inv(G) * VᴴΔV - VᴴΔV .*= conj.(inv_safe.(transpose(D) .- D, degeneracy_atol)) - else - VᴴΔV = zero(G) - end - - if !iszerotangent(ΔDmat) - ΔDvec = diagview(ΔDmat) - p == length(ΔDvec) || throw(DimensionMismatch()) - diagview(VᴴΔV) .+= ΔDvec - end - Z = V' \ VᴴΔV + ΔDmat, ΔV = ΔDV + VᴴΔAV, ΔV₊ = check_and_prepare_eig_cotangents( + D, V, ViG, ΔDmat, ΔV; degeneracy_atol, gauge_atol + ) + Z = ViG * VᴴΔAV # add contribution from orthogonal complement - PA = A - (A * V) / V - Y = mul!(ΔVperp, PA', Z, 1, 1) - X = _sylvester(PA', -Dmat', Y) - Z .+= X - + AP = mul!(complex.(A), V * Dmat, ViG', -1, 1) + X₀ = iszerotangent(ΔV₊) ? AP' * Z : mul!(ΔV₊, AP', Z, 1, 1) + X₀ ./= D' + dabsmax = maximum(abs, D) + AP ./= dabsmax + D̄⁻¹ = dabsmax ./ conj.(D) + X₁ = rmul!(AP' * X₀, Diagonal(D̄⁻¹)) + X₁ .+= X₀ + Xₖ, Xₖ₊₁ = X₁, X₀ + APₖ, APₖ₊₁ = AP * AP, AP + D̄⁻¹ₖ, D̄⁻¹ₖ₊₁ = D̄⁻¹ .^ 2, D̄⁻¹ + for k in 1:maxiter + Xₖ₊₁ = rmul!(mul!(Xₖ₊₁, APₖ', Xₖ), Diagonal(D̄⁻¹ₖ)) + if norm(Xₖ₊₁, Inf) < degeneracy_atol + break + end + Xₖ₊₁ .+= Xₖ + if k == maxiter + @warn "Sylvester iteration did not converge after $k iterations, final norm of X: $(norm(Xₖ₊₁, Inf)))" + break + end + D̄⁻¹ₖ₊₁ .= D̄⁻¹ₖ .^ 2 + APₖ₊₁ = mul!(APₖ₊₁, APₖ, APₖ) + Xₖ, Xₖ₊₁ = Xₖ₊₁, Xₖ + APₖ, APₖ₊₁ = APₖ₊₁, APₖ + D̄⁻¹ₖ, D̄⁻¹ₖ₊₁ = D̄⁻¹ₖ₊₁, D̄⁻¹ₖ + end + Z .+= Xₖ if eltype(ΔA) <: Real - ΔAc = Z * V' + ΔAc = mul!(AP, Z, V') # recycle AP ΔA .+= real.(ΔAc) else ΔA = mul!(ΔA, Z, V', 1, 1) @@ -211,15 +236,13 @@ across eigenvectors associated with degenerate eigenvalues), so the correspondin `ΔV` are projected out. """ function remove_eig_gauge_dependence!( - ΔV, D, V, ind = axes(ΔV, 2); + ΔV, D, V; degeneracy_atol = MatrixAlgebraKit.default_pullback_gauge_atol(D) ) - length(ind) == size(ΔV, 2) || throw(DimensionMismatch()) - indV = axes(V, 2)[ind] - Vp = view(V, :, indV) - Ddiag = view(diagview(D), indV) - gaugepart = Vp' * ΔV + Ddiag = diagview(D) + gaugepart = V' * ΔV gaugepart[abs.(transpose(Ddiag) .- Ddiag) .>= degeneracy_atol] .= 0 - mul!(ΔV, Vp / (Vp' * Vp), gaugepart, -1, 1) + ViG = V / LinearAlgebra.cholesky!(V' * V) + mul!(ΔV, ViG, gaugepart, -1, 1) return ΔV end diff --git a/src/pullbacks/eigh.jl b/src/pullbacks/eigh.jl old mode 100644 new mode 100755 index 3b517b97..fc8d86e2 --- a/src/pullbacks/eigh.jl +++ b/src/pullbacks/eigh.jl @@ -1,13 +1,54 @@ -function check_eigh_cotangents( - D, aVᴴΔV; - degeneracy_atol::Real = default_pullback_rank_atol(D), - gauge_atol::Real = default_pullback_gauge_atol(aVᴴΔV) +function check_and_prepare_eigh_cotangents( + D, V, ΔDmat, ΔV, ind = Colon(); + degeneracy_atol::Real = default_pullback_rank_atol(S), + gauge_atol::Real = default_pullback_gauge_atol(ΔDmat, ΔV) ) - mask = abs.(D' .- D) .< degeneracy_atol - Δgauge = norm(view(aVᴴΔV, mask)) + + n, p = size(V) + indD = axes(D, 1)[ind] + indV = axes(V, 2)[ind] + if !iszerotangent(ΔV) + n == size(ΔV, 1) || throw(DimensionMismatch()) + length(indV) == size(ΔV, 2) || throw(DimensionMismatch()) + if indV == 1:p + ΔV₁ = copy(ΔV) + else + ΔV₁ = zero(V) + for (j, i) in enumerate(indV) + ΔV₁[:, i] .= view(ΔV, :, j) + end + end + VᴴΔV₁ = V' * ΔV₁ + if p == n + ΔV₊ = zero!(ΔV₁) + else + ΔV₊ = mul!(ΔV₁, V, VᴴΔV₁, -1, 1) + end + aVᴴΔV₁ = project_antihermitian!(VᴴΔV₁) + else + ΔV₊ = nothing + aVᴴΔV₁ = zero!(similar(V, (p, p))) + end + bc = Base.broadcasted(transpose(D), D, aVᴴΔV₁) do d₁, d₂, v + return abs(d₁ - d₂) < degeneracy_atol ? v : zero(v) + end + Δgauge = norm(bc, Inf) + Δgauge ≤ gauge_atol || @warn "`eigh` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" - return + + aVᴴΔV₁ .*= inv_safe.(D' .- D, degeneracy_atol) + VᴴAΔV = aVᴴΔV₁ + + if !iszerotangent(ΔDmat) + ΔD = diagview(ΔDmat) + length(indD) == length(ΔD) || throw(DimensionMismatch()) + view(diagview(VᴴAΔV), indD) .+= real.(ΔD) + else + ΔD = nothing + end + + return VᴴAΔV, ΔV₊ end """ @@ -39,42 +80,17 @@ function eigh_pullback!( # Basic size checks and determination Dmat, V = DV - D = diagview(Dmat) - ΔDmat, ΔV = ΔDV n = LinearAlgebra.checksquare(V) + D = diagview(Dmat) n == length(D) || throw(DimensionMismatch()) (n, n) == size(ΔA) || throw(DimensionMismatch()) - if !iszerotangent(ΔV) - n == size(ΔV, 1) || throw(DimensionMismatch()) - pV = size(ΔV, 2) - VᴴΔV = fill!(similar(V), 0) - indV = axes(V, 2)[ind] - length(indV) == pV || throw(DimensionMismatch()) - mul!(view(VᴴΔV, :, indV), V', ΔV) - aVᴴΔV = project_antihermitian(VᴴΔV) # can't use in-place or recycling doesn't work - - check_eigh_cotangents(D, aVᴴΔV; degeneracy_atol, gauge_atol) - - aVᴴΔV .*= inv_safe.(D' .- D, degeneracy_atol) - - if !iszerotangent(ΔDmat) - ΔDvec = diagview(ΔDmat) - pD = length(ΔDvec) - indD = axes(D, 1)[ind] - length(indD) == pD || throw(DimensionMismatch()) - view(diagview(aVᴴΔV), indD) .+= real.(ΔDvec) - end - # recycle VdΔV space - ΔA = mul!(ΔA, mul!(VᴴΔV, V, aVᴴΔV), V', 1, 1) - elseif !iszerotangent(ΔDmat) - ΔDvec = diagview(ΔDmat) - pD = length(ΔDvec) - indD = axes(D, 1)[ind] - length(indD) == pD || throw(DimensionMismatch()) - Vp = view(V, :, indD) - ΔA = mul!(ΔA, Vp * Diagonal(real(ΔDvec)), Vp', 1, 1) - end + ΔDmat, ΔV = ΔDV + VᴴΔAV, = check_and_prepare_eigh_cotangents( + D, V, ΔDmat, ΔV, ind; degeneracy_atol, gauge_atol + ) + + ΔA = mul!(ΔA, V * VᴴΔAV, V', 1, 1) return ΔA end function eigh_pullback!( @@ -113,47 +129,58 @@ not small compared to `gauge_atol`. function eigh_trunc_pullback!( ΔA::AbstractMatrix, A, DV, ΔDV; degeneracy_atol::Real = default_pullback_rank_atol(DV[1]), - gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2]) + gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2]), + maxiter::Int = 100 # TODO: better default, depending on expected number of steps using quadratic convergence? ) # Basic size checks and determination Dmat, V = DV - D = diagview(Dmat) - ΔDmat, ΔV = ΔDV (n, p) = size(V) + D = diagview(Dmat) p == length(D) || throw(DimensionMismatch()) (n, n) == size(ΔA) || throw(DimensionMismatch()) - if !iszerotangent(ΔV) - (n, p) == size(ΔV) || throw(DimensionMismatch()) - VᴴΔV = V' * ΔV - aVᴴΔV = project_antihermitian!(VᴴΔV) - - check_eigh_cotangents(D, aVᴴΔV; degeneracy_atol, gauge_atol) - - aVᴴΔV .*= inv_safe.(D' .- D, degeneracy_atol) - - if !iszerotangent(ΔDmat) - ΔDvec = diagview(ΔDmat) - p == length(ΔDvec) || throw(DimensionMismatch()) - diagview(aVᴴΔV) .+= real.(ΔDvec) + ΔDmat, ΔV = ΔDV + VᴴΔAV, ΔV₊ = check_and_prepare_eigh_cotangents( + D, V, ΔDmat, ΔV; degeneracy_atol, gauge_atol + ) + Z = V * VᴴΔAV + if !iszerotangent(ΔV₊) + X₀ = rdiv!(ΔV₊, Diagonal(D)) + AP = mul!(copy(A), V * Dmat, V', -1, 1) + dabsmax = maximum(abs, D) + AP ./= dabsmax + D⁻¹ = dabsmax ./ D + X₁ = rmul!(AP * X₀, Diagonal(D⁻¹)) + X₁ .+= X₀ + Xₖ, Xₖ₊₁ = X₁, X₀ + APₖ, APₖ₊₁ = AP * AP, AP + D⁻¹ₖ, D⁻¹ₖ₊₁ = D⁻¹ .^ 2, D⁻¹ + for k in 1:maxiter + Xₖ₊₁ = rmul!(mul!(Xₖ₊₁, APₖ, Xₖ), Diagonal(D⁻¹ₖ)) + if norm(Xₖ₊₁, Inf) < degeneracy_atol + break + end + Xₖ₊₁ .+= Xₖ + if k == maxiter + @warn "Sylvester iteration did not converge after $k iterations, final norm of X: $(norm(Xₖ₊₁, Inf)))" + break + end + D⁻¹ₖ₊₁ .= D⁻¹ₖ .^ 2 + APₖ₊₁ = mul!(APₖ₊₁, APₖ, APₖ) + Xₖ, Xₖ₊₁ = Xₖ₊₁, Xₖ + APₖ, APₖ₊₁ = APₖ₊₁, APₖ + D⁻¹ₖ, D⁻¹ₖ₊₁ = D⁻¹ₖ₊₁, D⁻¹ₖ end - - Z = V * aVᴴΔV - - # add contribution from orthogonal complement - W = qr_null(V) - WᴴΔV = W' * ΔV - X = _sylvester(W' * A * W, -Dmat, WᴴΔV) - Z = mul!(Z, W, X, 1, 1) - - # put everything together: symmetrize for hermitian case - ΔA = mul!(ΔA, Z, V', 1 // 2, 1) - ΔA = mul!(ΔA, V, Z', 1 // 2, 1) - elseif !iszerotangent(ΔDmat) - ΔDvec = diagview(ΔDmat) - p == length(ΔDvec) || throw(DimensionMismatch()) - ΔA = mul!(ΔA, V * Diagonal(real(ΔDvec)), V', 1, 1) + Z .+= Xₖ + # we cannot directly multiply Z * V' into ΔA, because we have to + # take the Hermitian part, and cannot apply project_hermitian! to + # the current contents of ΔA + ΔA′ = project_hermitian!(mul!(AP, Z, V', 1, 1)) # recycle AP + ΔA .+= ΔA′ + else + # in this case, Z * V' is automatically Hermitian, so we can directly add it to ΔA + ΔA = mul!(ΔA, Z, V', 1, 1) end return ΔA end @@ -201,15 +228,12 @@ across eigenvectors associated with degenerate eigenvalues), so the correspondin components of `V' * ΔV` are projected out. """ function remove_eigh_gauge_dependence!( - ΔV, D, V, ind = axes(ΔV, 2); + ΔV, D, V; degeneracy_atol = MatrixAlgebraKit.default_pullback_gauge_atol(D) ) - length(ind) == size(ΔV, 2) || throw(DimensionMismatch()) - indV = axes(V, 2)[ind] - Vp = view(V, :, indV) - Ddiag = view(diagview(D), indV) - gaugepart = project_antihermitian!(Vp' * ΔV) + Ddiag = diagview(D) + gaugepart = project_antihermitian!(V' * ΔV) gaugepart[abs.(transpose(Ddiag) .- Ddiag) .>= degeneracy_atol] .= 0 - mul!(ΔV, Vp, gaugepart, -1, 1) + mul!(ΔV, V, gaugepart, -1, 1) return ΔV end diff --git a/src/pullbacks/svd.jl b/src/pullbacks/svd.jl old mode 100644 new mode 100755 index 4b352416..832d04a1 --- a/src/pullbacks/svd.jl +++ b/src/pullbacks/svd.jl @@ -82,7 +82,7 @@ function check_and_prepare_svd_cotangents( aVᴴΔV₁ = zero!(similar(V₁ᴴ, (r, r))) end bc = Base.broadcasted(S₁', S₁, aUᴴΔU₁, aVᴴΔV₁) do s₁, s₂, u, v - return abs(s₁ - s₂) < degeneracy_atol ? zero(u) + zero(v) : u + v + return abs(s₁ - s₂) < degeneracy_atol ? u + v : zero(u) + zero(v) end Δgauge = max(Δgauge, norm(bc, Inf)) @@ -104,13 +104,13 @@ function check_and_prepare_svd_cotangents( Δgauge ≤ gauge_atol || @warn "`svd` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" - UdΔAV = (aUᴴΔU₁ .+ aVᴴΔV₁) .* inv_safe.(S₁' .- S₁, degeneracy_atol) .+ + UᴴΔAV = (aUᴴΔU₁ .+ aVᴴΔV₁) .* inv_safe.(S₁' .- S₁, degeneracy_atol) .+ (aUᴴΔU₁ .- aVᴴΔV₁) .* inv_safe.(S₁' .+ S₁, degeneracy_atol) if !iszerotangent(ΔS₁) - diagview(UdΔAV) .+= real.(ΔS₁) + diagview(UᴴΔAV) .+= real.(ΔS₁) end - return UdΔAV, ΔU₊, ΔV₊ᴴ + return UᴴΔAV, ΔU₊, ΔV₊ᴴ end """ @@ -155,10 +155,10 @@ function svd_pullback!( S₁ = view(S, 1:r) ΔU, ΔSmat, ΔVᴴ = ΔUSVᴴ - UdΔAV, ΔU₊, ΔV₊ᴴ = check_and_prepare_svd_cotangents( + UᴴΔAV, ΔU₊, ΔV₊ᴴ = check_and_prepare_svd_cotangents( U, S, Vᴴ, ΔU, ΔSmat, ΔVᴴ, r, ind; degeneracy_atol, gauge_atol ) - ΔA = mul!(ΔA, U₁, UdΔAV * V₁ᴴ, 1, 1) # add the contribution to ΔA + ΔA = mul!(ΔA, U₁, UᴴΔAV * V₁ᴴ, 1, 1) # add the contribution to ΔA # Add the remaining contributions if m > r && !iszerotangent(ΔU₊) # ΔU₁ is already orthogonal to U₁ @@ -210,7 +210,7 @@ function svd_trunc_pullback!( rank_atol::Real = 0, degeneracy_atol::Real = default_pullback_rank_atol(USVᴴ[2]), gauge_atol::Real = default_pullback_gauge_atol(ΔUSVᴴ...), - maxiter::Int = 1000, + maxiter::Int = 100 # TODO: better default, depending on expected number of steps using quadratic convergence? ) # Extract the SVD components U, Smat, Vᴴ = USVᴴ @@ -223,17 +223,19 @@ function svd_trunc_pullback!( # Extract and check the cotangents ΔU, ΔSmat, ΔVᴴ = ΔUSVᴴ - UdΔAV, ΔU₊, ΔV₊ᴴ = check_and_prepare_svd_cotangents( + UᴴΔAV, ΔU₊, ΔV₊ᴴ = check_and_prepare_svd_cotangents( U, S, Vᴴ, ΔU, ΔSmat, ΔVᴴ, p; degeneracy_atol, gauge_atol ) - ΔA = mul!(ΔA, U, UdΔAV * Vᴴ, 1, 1) # add the contribution to ΔA + ΔAV = U * UᴴΔAV + ΔA = mul!(ΔA, ΔAV, Vᴴ, 1, 1) # add the contribution to ΔA # The contribtutions from the orthogonal complement need to be treated differently # ΔU and ΔVᴴ are already orthogonal to U and Vᴴ if !(iszerotangent(ΔU₊) && iszerotangent(ΔV₊ᴴ)) X₀ = iszerotangent(ΔU₊) ? zero(U) : rdiv!(ΔU₊, Diagonal(S)) Y₀ᴴ = iszerotangent(ΔV₊ᴴ) ? zero(Vᴴ) : ldiv!(Diagonal(S), ΔV₊ᴴ) - AP = mul!(copy(A), U, Smat * Vᴴ, -1, 1) + US = mul!(ΔAV, U, Smat) # recycle ΔAV + AP = mul!(copy(A), US, Vᴴ, -1, 1) AP ./= S[end] S⁻¹ = S[end] ./ S X₁ = rmul!(AP * Y₀ᴴ', Diagonal(S⁻¹)) @@ -254,7 +256,7 @@ function svd_trunc_pullback!( Xₖ₊₁ .+= Xₖ Yₖ₊₁ᴴ .+= Yₖᴴ if k == maxiter - @warn "Sylvester iteration did not converge after $k iterations, final norms: (X: $(norm(Xₖ₊₁, Inf)), Yᴴ: $(norm(Yₖ₊₁ᴴ, Inf)))" + @warn "Sylvester iteration did not converge after $k iterations, final norms of X: $(norm(Xₖ₊₁, Inf)), Yᴴ: $(norm(Yₖ₊₁ᴴ, Inf)))" break end S⁻¹ₖ₊₁ .= S⁻¹ₖ .^ 2 diff --git a/test/testsuite/chainrules.jl b/test/testsuite/chainrules.jl old mode 100644 new mode 100755 index 558afc83..25b49841 --- a/test/testsuite/chainrules.jl +++ b/test/testsuite/chainrules.jl @@ -258,9 +258,13 @@ function test_chainrules_eig( output_tangent = ΔDVtrunc, atol = atol, rtol = rtol ) ind = MatrixAlgebraKit.findtruncated(diagview(DV[1]), truncalg.trunc) - dA1 = MatrixAlgebraKit.eig_pullback!(zero(A), A, DV, ΔDVtrunc, ind) - dA2 = MatrixAlgebraKit.eig_trunc_pullback!(zero(A), A, DVtrunc, ΔDVtrunc) - @test isapprox(dA1, dA2; atol = atol, rtol = rtol) + Ddiag = diagview(DV[1]) + p = sortperm(Ddiag, by = abs, rev = true) + if abs(Ddiag[p[r + 1]]) < abs(Ddiag[p[r]]) + dA1 = MatrixAlgebraKit.eig_pullback!(zero(A), A, DV, ΔDVtrunc, ind) + dA2 = MatrixAlgebraKit.eig_trunc_pullback!(zero(A), A, DVtrunc, ΔDVtrunc) + @test isapprox(dA1, dA2; atol = atol, rtol = rtol) + end end truncalg = TruncatedAlgorithm(alg, truncrank(5; by = real)) DV, DVtrunc, ΔDV, ΔDVtrunc = ad_eig_trunc_setup(A, truncalg) @@ -273,10 +277,6 @@ function test_chainrules_eig( cr_copy_eig_trunc_no_error, A, truncalg ⊢ NoTangent(); output_tangent = ΔDVtrunc, atol = atol, rtol = rtol ) - ind = MatrixAlgebraKit.findtruncated(diagview(DV[1]), truncalg.trunc) - dA1 = MatrixAlgebraKit.eig_pullback!(zero(A), A, DV, ΔDVtrunc, ind) - dA2 = MatrixAlgebraKit.eig_trunc_pullback!(zero(A), A, DVtrunc, ΔDVtrunc) - @test isapprox(dA1, dA2; atol = atol, rtol = rtol) end end end