Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 18 additions & 9 deletions src/common/defaults.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,22 +10,31 @@ quantity needs to be computed.
function defaulttol(x::Any)
    # Unit of the element type, converted to floating point and reduced to its real part.
    unit = real(float(one(eltype(x))))
    # Machine epsilon at that unit value, raised to the power 2/3.
    return eps(unit)^(2 / 3)
end

"""
default_pullback_gaugetol(a)
default_pullback_gauge_atol(ΔA...)

Default tolerance for deciding to warn if incoming adjoints of a pullback rule
have components that are not gauge-invariant.
"""
function default_pullback_gaugetol(a)
n = norm(a, Inf)
return eps(eltype(n))^(3 / 4) * max(n, one(n))
end
default_pullback_gauge_atol(A) = eps(norm(A, Inf))^(3 / 4)
default_pullback_gauge_atol(A, As...) = maximum(default_pullback_gauge_atol, (A, As...))
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this multi-argument definition is necessary (I realised this yesterday after closing my computer and going to bed). Simply doing default_pullback_gauge_atol((ΔU, ΔVᴴ)) should be ok; norm((a,b), Inf) is automatically compiled to max(norm(a, Inf), norm(b, Inf)).

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok. If I am very nit-picky, I would say this leads to a type instability in the case of one of the two adjoints of the svd being zero, but I assume this is ok, as it will also not propagate to the output of the functions where this gauge_tol is being used, as it is just in a comparison.

julia> @code_warntype default_pullback_gauge_atol(ZeroTangent(), randn(5,5))
MethodInstance for default_pullback_gauge_atol(::ZeroTangent, ::Matrix{Float64})
  from default_pullback_gauge_atol(A, As...) @ Main REPL[5]:1
Arguments
  #self#::Core.Const(Main.default_pullback_gauge_atol)
  A::Core.Const(ZeroTangent())
  As::Tuple{Matrix{Float64}}
Body::Any
1 ─ %1 = Main.maximum::Core.Const(maximum)
│   %2 = Main.default_pullback_gauge_atol::Core.Const(Main.default_pullback_gauge_atol)
│   %3 = Core.tuple(A)::Core.Const((ZeroTangent(),))
│   %4 = Core._apply_iterate(Base.iterate, Core.tuple, %3, As)::Tuple{ZeroTangent, Matrix{Float64}}
│   %5 = (%1)(%2, %4)::Any
└──      return %5


"""
    default_pullback_degeneracy_atol(A)

Return the default absolute tolerance used to decide when values should be considered
degenerate, computed as `eps(norm(A, Inf))^(3 / 4)`.
"""
function default_pullback_degeneracy_atol(A)
    # Infinity norm of `A`; `eps` is evaluated at this magnitude.
    scale = norm(A, Inf)
    return eps(scale)^(3 / 4)
end

"""
    default_pullback_rank_atol(A)

Return the default absolute tolerance used to decide which values should be considered
equal to 0, computed as `eps(norm(A, Inf))^(3 / 4)`.
"""
function default_pullback_rank_atol(A)
    # Infinity norm of `A`; `eps` is evaluated at this magnitude.
    scale = norm(A, Inf)
    return eps(scale)^(3 / 4)
end

"""
default_hermitian_tol(A)

Default tolerance for deciding to warn if the provided `A` is not hermitian.
"""
function default_hermitian_tol(A)
n = norm(A, Inf)
return eps(eltype(n))^(3 / 4) * max(n, one(n))
end
default_hermitian_tol(A) = eps(norm(A, Inf))^(3 / 4)
20 changes: 8 additions & 12 deletions src/pullbacks/eig.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
"""
eig_pullback!(
ΔA::AbstractMatrix, A, DV, ΔDV, [ind];
tol = default_pullback_gaugetol(DV[1]),
degeneracy_atol = tol,
gauge_atol = tol
degeneracy_atol::Real = default_pullback_rank_atol(DV[1]),
gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2])
)

Adds the pullback from the full eigenvalue decomposition of `A` to `ΔA`, given the output
Expand All @@ -22,9 +21,8 @@ not small compared to `gauge_atol`.
"""
function eig_pullback!(
ΔA::AbstractMatrix, A, DV, ΔDV, ind = Colon();
tol::Real = default_pullback_gaugetol(DV[1]),
degeneracy_atol::Real = tol,
gauge_atol::Real = tol
degeneracy_atol::Real = default_pullback_rank_atol(DV[1]),
gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2])
)

# Basic size checks and determination
Expand Down Expand Up @@ -84,9 +82,8 @@ end
"""
eig_trunc_pullback!(
ΔA::AbstractMatrix, A, DV, ΔDV;
tol = default_pullback_gaugetol(DV[1]),
degeneracy_atol = tol,
gauge_atol = tol
degeneracy_atol::Real = default_pullback_rank_atol(DV[1]),
gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2])
)

Adds the pullback from the truncated eigenvalue decomposition of `A` to `ΔA`, given the
Expand All @@ -106,9 +103,8 @@ not small compared to `gauge_atol`.
"""
function eig_trunc_pullback!(
ΔA::AbstractMatrix, A, DV, ΔDV;
tol::Real = default_pullback_gaugetol(DV[1]),
degeneracy_atol::Real = tol,
gauge_atol::Real = tol
degeneracy_atol::Real = default_pullback_rank_atol(DV[1]),
gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2])
)

# Basic size checks and determination
Expand Down
24 changes: 10 additions & 14 deletions src/pullbacks/eigh.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
"""
eigh_pullback!(
ΔA::AbstractMatrix, A, DV, ΔDV, [ind];
tol = default_pullback_gaugetol(DV[1]),
degeneracy_atol = tol,
gauge_atol = tol
degeneracy_atol::Real = default_pullback_rank_atol(DV[1]),
gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2])
)

Adds the pullback from the Hermitian eigenvalue decomposition of `A` to `ΔA`, given the
Expand All @@ -22,9 +21,8 @@ anti-hermitian part of `V' * ΔV`, restricted to rows `i` and columns `j` for wh
"""
function eigh_pullback!(
ΔA::AbstractMatrix, A, DV, ΔDV, ind = Colon();
tol::Real = default_pullback_gaugetol(DV[1]),
degeneracy_atol::Real = tol,
gauge_atol::Real = tol
degeneracy_atol::Real = default_pullback_rank_atol(DV[1]),
gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2])
)

# Basic size checks and determination
Expand All @@ -49,7 +47,7 @@ function eigh_pullback!(
Δgauge < gauge_atol ||
@warn "`eigh` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"

aVᴴΔV .*= inv_safe.(D' .- D, tol)
aVᴴΔV .*= inv_safe.(D' .- D, degeneracy_atol)

if !iszerotangent(ΔDmat)
ΔDvec = diagview(ΔDmat)
Expand All @@ -74,9 +72,8 @@ end
"""
eigh_trunc_pullback!(
ΔA::AbstractMatrix, A, DV, ΔDV;
tol=default_pullback_gaugetol(DV[1]),
degeneracy_atol=tol,
gauge_atol=tol
degeneracy_atol::Real = default_pullback_rank_atol(DV[1]),
gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2])
)

Adds the pullback from the truncated Hermitian eigenvalue decomposition of `A` to `ΔA`,
Expand All @@ -96,9 +93,8 @@ not small compared to `gauge_atol`.
"""
function eigh_trunc_pullback!(
ΔA::AbstractMatrix, A, DV, ΔDV;
tol::Real = default_pullback_gaugetol(DV[1]),
degeneracy_atol::Real = tol,
gauge_atol::Real = tol
degeneracy_atol::Real = default_pullback_rank_atol(DV[1]),
gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2])
)

# Basic size checks and determination
Expand All @@ -119,7 +115,7 @@ function eigh_trunc_pullback!(
Δgauge < gauge_atol ||
@warn "`eigh` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"

aVᴴΔV .*= inv_safe.(D' .- D, tol)
aVᴴΔV .*= inv_safe.(D' .- D, degeneracy_atol)

if !iszerotangent(ΔDmat)
ΔDvec = diagview(ΔDmat)
Expand Down
24 changes: 12 additions & 12 deletions src/pullbacks/lq.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
"""
lq_pullback!(
ΔA, A, LQ, ΔLQ;
tol::Real = default_pullback_gaugetol(LQ[1]),
rank_atol::Real = tol,
gauge_atol::Real = tol
rank_atol::Real = default_pullback_rank_atol(LQ[1]),
gauge_atol::Real = default_pullback_gauge_atol(ΔLQ[2])
)

Adds the pullback from the LQ decomposition of `A` to `ΔA` given the output `LQ` and
Expand All @@ -18,17 +17,16 @@ or rows exceed `gauge_atol`, a warning will be printed.
"""
function lq_pullback!(
ΔA::AbstractMatrix, A, LQ, ΔLQ;
tol::Real = default_pullback_gaugetol(LQ[1]),
rank_atol::Real = tol,
gauge_atol::Real = tol
rank_atol::Real = default_pullback_rank_atol(LQ[1]),
gauge_atol::Real = default_pullback_gauge_atol(ΔLQ[2])
)
# process
L, Q = LQ
m = size(L, 1)
n = size(Q, 2)
minmn = min(m, n)
Ld = diagview(L)
p = findlast(>=(rank_atol) ∘ abs, Ld)
p = @something findlast(>=(rank_atol) ∘ abs, Ld) 0

ΔL, ΔQ = ΔLQ

Expand Down Expand Up @@ -72,7 +70,7 @@ function lq_pullback!(
# Q2' * ΔQ2 as a gauge dependent quantity.
ΔQ2Q1ᴴ = ΔQ2 * Q1'
Δgauge = norm(mul!(copy(ΔQ2), ΔQ2Q1ᴴ, Q1, -1, 1), Inf)
Δgauge < tol ||
Δgauge < gauge_atol ||
@warn "`lq` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
ΔQ̃ = mul!(ΔQ̃, ΔQ2Q1ᴴ', Q2, -1, 1)
end
Expand Down Expand Up @@ -105,7 +103,10 @@ function lq_pullback!(
end

"""
lq_null_pullback(ΔA, A, Nᴴ, ΔNᴴ)
lq_null_pullback!(
ΔA::AbstractMatrix, A, Nᴴ, ΔNᴴ;
gauge_atol::Real = default_pullback_gauge_atol(ΔNᴴ)
)

Adds the pullback from the left nullspace of `A` to `ΔA`, given the nullspace basis
`Nᴴ` and its cotangent `ΔNᴴ` of `lq_null(A)`.
Expand All @@ -114,13 +115,12 @@ See also [`lq_pullback!`](@ref).
"""
function lq_null_pullback!(
ΔA::AbstractMatrix, A, Nᴴ, ΔNᴴ;
tol::Real = default_pullback_gaugetol(A),
gauge_atol::Real = tol
gauge_atol::Real = default_pullback_gauge_atol(ΔNᴴ)
)
if !iszerotangent(ΔNᴴ) && size(Nᴴ, 1) > 0
aNᴴΔN = project_antihermitian!(Nᴴ * ΔNᴴ')
Δgauge = norm(aNᴴΔN)
Δgauge < tol ||
Δgauge < gauge_atol ||
@warn "`lq_null` cotangent sensitive to gauge choice: (|Δgauge| = $Δgauge)"
L, Q = lq_compact(A; positive = true) # should we be able to provide algorithm here?
X = ldiv!(LowerTriangular(L)', Q * ΔNᴴ')
Expand Down
23 changes: 12 additions & 11 deletions src/pullbacks/qr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
qr_pullback!(
ΔA, A, QR, ΔQR;
rank_atol::Real = tol,
gauge_atol::Real = tol
rank_atol::Real = default_pullback_rank_atol(QR[2]),
gauge_atol::Real = default_pullback_gauge_atol(ΔQR[1])
)

Adds the pullback from the QR decomposition of `A` to `ΔA` given the output `QR` and
Expand All @@ -18,17 +18,16 @@ and also the adjoint variables `ΔQ` and `ΔR` should have nonzero values only i
"""
function qr_pullback!(
ΔA::AbstractMatrix, A, QR, ΔQR;
tol::Real = default_pullback_gaugetol(QR[2]),
rank_atol::Real = tol,
gauge_atol::Real = tol
rank_atol::Real = default_pullback_rank_atol(QR[2]),
gauge_atol::Real = default_pullback_gauge_atol(ΔQR[1])
)
# process
Q, R = QR
m = size(Q, 1)
n = size(R, 2)
minmn = min(m, n)
Rd = diagview(R)
p = findlast(>=(rank_atol) ∘ abs, Rd)
p = @something findlast(>=(rank_atol) ∘ abs, Rd) 0

ΔQ, ΔR = ΔQR

Expand Down Expand Up @@ -71,7 +70,7 @@ function qr_pullback!(
# Q2' * ΔQ2 as a gauge dependent quantity.
Q1dΔQ2 = Q1' * ΔQ2
Δgauge = norm(mul!(copy(ΔQ2), Q1, Q1dΔQ2, -1, 1), Inf)
Δgauge < tol ||
Δgauge < gauge_atol ||
@warn "`qr` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
ΔQ̃ = mul!(ΔQ̃, Q2, Q1dΔQ2', -1, 1)
end
Expand Down Expand Up @@ -104,7 +103,10 @@ function qr_pullback!(
end

"""
qr_null_pullback(ΔA, A, N, ΔN)
qr_null_pullback!(
ΔA::AbstractMatrix, A, N, ΔN;
gauge_atol::Real = default_pullback_gauge_atol(ΔN)
)

Adds the pullback from the right nullspace of `A` to `ΔA`, given the nullspace basis
`N` and its cotangent `ΔN` of `qr_null(A)`.
Expand All @@ -113,13 +115,12 @@ See also [`qr_pullback!`](@ref).
"""
function qr_null_pullback!(
ΔA::AbstractMatrix, A, N, ΔN;
tol::Real = default_pullback_gaugetol(A),
gauge_atol::Real = tol
gauge_atol::Real = default_pullback_gauge_atol(ΔN)
)
if !iszerotangent(ΔN) && size(N, 2) > 0
aNᴴΔN = project_antihermitian!(N' * ΔN)
Δgauge = norm(aNᴴΔN)
Δgauge < tol ||
Δgauge < gauge_atol ||
@warn "`qr_null` cotangent sensitive to gauge choice: (|Δgauge| = $Δgauge)"

Q, R = qr_compact(A; positive = true)
Expand Down
28 changes: 12 additions & 16 deletions src/pullbacks/svd.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
"""
svd_pullback!(
ΔA, A, USVᴴ, ΔUSVᴴ, [ind];
tol::Real=default_pullback_gaugetol(USVᴴ[2]),
rank_atol::Real = tol,
degeneracy_atol::Real = tol,
gauge_atol::Real = tol
rank_atol::Real = default_pullback_rank_atol(USVᴴ[2]),
degeneracy_atol::Real = default_pullback_rank_atol(USVᴴ[2]),
gauge_atol::Real = default_pullback_gauge_atol(ΔUSVᴴ[1], ΔUSVᴴ[3])
)

Adds the pullback from the SVD of `A` to `ΔA` given the output USVᴴ of `svd_compact` or
Expand All @@ -23,10 +22,9 @@ which `abs(S[i] - S[j]) < degeneracy_atol`, is not small compared to `gauge_atol
"""
function svd_pullback!(
ΔA::AbstractMatrix, A, USVᴴ, ΔUSVᴴ, ind = Colon();
tol::Real = default_pullback_gaugetol(USVᴴ[2]),
rank_atol::Real = tol,
degeneracy_atol::Real = tol,
gauge_atol::Real = tol
rank_atol::Real = default_pullback_rank_atol(USVᴴ[2]),
degeneracy_atol::Real = default_pullback_rank_atol(USVᴴ[2]),
gauge_atol::Real = default_pullback_gauge_atol(ΔUSVᴴ[1], ΔUSVᴴ[3])
)

# Extract the SVD components
Expand Down Expand Up @@ -106,10 +104,9 @@ end
"""
svd_trunc_pullback!(
ΔA, A, USVᴴ, ΔUSVᴴ;
tol::Real=default_pullback_gaugetol(S),
rank_atol::Real = tol,
degeneracy_atol::Real = tol,
gauge_atol::Real = tol
rank_atol::Real = 0,
degeneracy_atol::Real = default_pullback_rank_atol(USVᴴ[2]),
gauge_atol::Real = default_pullback_gauge_atol(ΔUSVᴴ[1], ΔUSVᴴ[3])
)

Adds the pullback from the truncated SVD of `A` to `ΔA`, given the output `USVᴴ` and the
Expand All @@ -128,10 +125,9 @@ which `abs(S[i] - S[j]) < degeneracy_atol`, is not small compared to `gauge_atol
"""
function svd_trunc_pullback!(
ΔA::AbstractMatrix, A, USVᴴ, ΔUSVᴴ;
tol::Real = default_pullback_gaugetol(USVᴴ[2]),
rank_atol::Real = tol,
degeneracy_atol::Real = tol,
gauge_atol::Real = tol
rank_atol::Real = 0,
degeneracy_atol::Real = default_pullback_rank_atol(USVᴴ[2]),
gauge_atol::Real = default_pullback_gauge_atol(ΔUSVᴴ[1], ΔUSVᴴ[3])
)

# Extract the SVD components
Expand Down