Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions src/Defaults.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,11 @@ Module containing default algorithm parameter values and arguments.
* `svd_rrule_min_krylovdim=$(Defaults.svd_rrule_min_krylovdim)` : Minimal Krylov dimension of the reverse-rule algorithm (if it is a Krylov algorithm).
* `svd_rrule_verbosity=$(Defaults.svd_rrule_verbosity)` : SVD gradient output verbosity.
* `svd_rrule_alg=:$(Defaults.svd_rrule_alg)` : Reverse-rule algorithm for the SVD gradient.
- `:tsvd`: Uses TensorKit's reverse-rule for `tsvd` which doesn't solve any linear problem and instead requires access to the full SVD, see [TensorKit](https://github.com/Jutho/TensorKit.jl/blob/f9cddcf97f8d001888a26f4dce7408d5c6e2228f/ext/TensorKitChainRulesCoreExt/factorizations.jl#L3)
- `:full`: Uses a modified version of TensorKit's reverse-rule for `tsvd` which doesn't solve any linear problem and instead requires access to the full SVD, see [`FullSVDReverseRule`](@ref).
- `:gmres`: GMRES iterative linear solver, see the [KrylovKit docs](https://jutho.github.io/KrylovKit.jl/stable/man/algorithms/#KrylovKit.GMRES) for details
- `:bicgstab`: BiCGStab iterative linear solver, see the [KrylovKit docs](https://jutho.github.io/KrylovKit.jl/stable/man/algorithms/#KrylovKit.BiCGStab) for details
- `:arnoldi`: Arnoldi Krylov algorithm, see the [KrylovKit docs](https://jutho.github.io/KrylovKit.jl/stable/man/algorithms/#KrylovKit.Arnoldi) for details
* `svd_rrule_broadening=$(Defaults.svd_rrule_broadening)` : Lorentzian broadening amplitude which smoothens the divergent term in the SVD adjoint in case of (pseudo) degenerate singular values

## Projectors

Expand Down Expand Up @@ -96,7 +97,8 @@ const svd_fwd_alg = :sdd # ∈ {:sdd, :svd, :iterative}
const svd_rrule_tol = ctmrg_tol
const svd_rrule_min_krylovdim = 48
const svd_rrule_verbosity = -1
const svd_rrule_alg = :tsvd # ∈ {:tsvd, :gmres, :bicgstab, :arnoldi}
const svd_rrule_alg = :full # ∈ {:full, :gmres, :bicgstab, :arnoldi}
const svd_rrule_broadening = 1e-13
const krylovdim_factor = 1.4

# Projectors
Expand Down
2 changes: 1 addition & 1 deletion src/PEPSKit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ include("algorithms/select_algorithm.jl")

using .Defaults: set_scheduler!
export set_scheduler!
export SVDAdjoint, IterSVD
export SVDAdjoint, FullSVDReverseRule, IterSVD
export CTMRGEnv, SequentialCTMRG, SimultaneousCTMRG
export FixedSpaceTruncation, HalfInfiniteProjector, FullInfiniteProjector
export LocalOperator
Expand Down
6 changes: 1 addition & 5 deletions src/algorithms/ctmrg/projectors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,7 @@ function svd_algorithm(alg::ProjectorAlgorithm, (dir, r, c))
nothing,
)
end
return SVDAdjoint(;
fwd_alg=fix_svd,
rrule_alg=alg.svd_alg.rrule_alg,
broadening=alg.svd_alg.broadening,
)
return SVDAdjoint(; fwd_alg=fix_svd, rrule_alg=alg.svd_alg.rrule_alg)
else
return alg.svd_alg
end
Expand Down
2 changes: 0 additions & 2 deletions src/algorithms/optimization/fixed_point_differentiation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,6 @@ function _fix_svd_algorithm(alg::SVDAdjoint, signs, info)
return SVDAdjoint(;
fwd_alg=FixedSVD(U_fixed, info.S, V_fixed, U_full_fixed, info.S_full, V_full_fixed),
rrule_alg=alg.rrule_alg,
broadening=alg.broadening,
)
end
function _fix_svd_algorithm(alg::SVDAdjoint{F}, signs, info) where {F<:IterSVD}
Expand All @@ -272,7 +271,6 @@ function _fix_svd_algorithm(alg::SVDAdjoint{F}, signs, info) where {F<:IterSVD}
return SVDAdjoint(;
fwd_alg=FixedSVD(U_fixed, info.S, V_fixed, nothing, nothing, nothing),
rrule_alg=alg.rrule_alg,
broadening=alg.broadening,
)
end

Expand Down
240 changes: 218 additions & 22 deletions src/utility/svd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,31 @@ using TensorKit:
_create_svdtensors,
_compute_truncdim,
_compute_truncerr
const TensorKitCRCExt = Base.get_extension(TensorKit, :TensorKitChainRulesCoreExt)
const KrylovKitCRCExt = Base.get_extension(KrylovKit, :KrylovKitChainRulesCoreExt)

"""
    struct FullSVDReverseRule
    FullSVDReverseRule(; kwargs...)

SVD reverse-rule algorithm which uses a modified version of TensorKit's `tsvd!` reverse-rule,
allowing for Lorentzian broadening and output verbosity control.

## Keyword arguments

* `broadening::Float64=$(Defaults.svd_rrule_broadening)`: Lorentzian broadening amplitude used to smooth the divergent term in the SVD derivative in case of (pseudo) degenerate singular values.
* `verbosity::Int=0`: Suppresses all output if `≤0`, prints a warning when the cotangents are found to be gauge dependent if `1`, and always prints the gauge dependency if `≥2`.
"""
@kwdef struct FullSVDReverseRule
broadening::Float64 = Defaults.svd_rrule_broadening # forwarded to `svd_pullback!` as `broadening`
verbosity::Int = 0 # forwarded to `svd_pullback!` as `verbosity`
end

"""
struct SVDAdjoint
SVDAdjoint(; kwargs...)

Wrapper for a SVD algorithm `fwd_alg` with a defined reverse rule `rrule_alg`.
If `isnothing(rrule_alg)`, Zygote differentiates the forward call automatically.
In case of degenerate singular values, one might need a `broadening` scheme which
removes the divergences from the adjoint.

## Keyword arguments

Expand All @@ -28,16 +42,14 @@ removes the divergences from the adjoint.
- `:svd`: TensorKit's wrapper for LAPACK's `_gesvd`
- `:iterative`: Iterative SVD only computing the specified number of singular values and vectors, see [`IterSVD`](@ref)
* `rrule_alg::Union{Algorithm,NamedTuple}=(; alg::Symbol=$(Defaults.svd_rrule_alg))`: Reverse-rule algorithm for differentiating the SVD. Can be supplied by an `Algorithm` instance directly or as a `NamedTuple` where `alg` is one of the following:
- `:tsvd`: Uses TensorKit's reverse-rule for `tsvd` which doesn't solve any linear problem and instead requires access to the full SVD, see [TensorKit](https://github.com/Jutho/TensorKit.jl/blob/f9cddcf97f8d001888a26f4dce7408d5c6e2228f/ext/TensorKitChainRulesCoreExt/factorizations.jl#L3)
- `:full`: Uses a modified version of TensorKit's reverse-rule for `tsvd` which doesn't solve any linear problem and instead requires access to the full SVD, see [`FullSVDReverseRule`](@ref).
- `:gmres`: GMRES iterative linear solver, see the [KrylovKit docs](https://jutho.github.io/KrylovKit.jl/stable/man/algorithms/#KrylovKit.GMRES) for details
- `:bicgstab`: BiCGStab iterative linear solver, see the [KrylovKit docs](https://jutho.github.io/KrylovKit.jl/stable/man/algorithms/#KrylovKit.BiCGStab) for details
- `:arnoldi`: Arnoldi Krylov algorithm, see the [KrylovKit docs](https://jutho.github.io/KrylovKit.jl/stable/man/algorithms/#KrylovKit.Arnoldi) for details
* `broadening=nothing`: Broadening of singular value differences to stabilize the SVD gradient. Currently not implemented.
"""
struct SVDAdjoint{F,R,B}
struct SVDAdjoint{F,R}
fwd_alg::F
rrule_alg::R
broadening::B
end # Keep truncation algorithm separate to be able to specify CTMRG dependent information

const SVD_FWD_SYMBOLS = IdDict{Symbol,Any}(
Expand All @@ -48,10 +60,10 @@ const SVD_FWD_SYMBOLS = IdDict{Symbol,Any}(
IterSVD(; alg=GKL(; tol, krylovdim), kwargs...),
)
const SVD_RRULE_SYMBOLS = IdDict{Symbol,Type{<:Any}}(
:tsvd => Nothing, :gmres => GMRES, :bicgstab => BiCGStab, :arnoldi => Arnoldi
:full => FullSVDReverseRule, :gmres => GMRES, :bicgstab => BiCGStab, :arnoldi => Arnoldi
)

function SVDAdjoint(; fwd_alg=(;), rrule_alg=(;), broadening=nothing)
function SVDAdjoint(; fwd_alg=(;), rrule_alg=(;))
# parse forward SVD algorithm
fwd_algorithm = if fwd_alg isa NamedTuple
fwd_kwargs = (; alg=Defaults.svd_fwd_alg, fwd_alg...) # overwrite with specified kwargs
Expand All @@ -70,6 +82,7 @@ function SVDAdjoint(; fwd_alg=(;), rrule_alg=(;), broadening=nothing)
alg=Defaults.svd_rrule_alg,
tol=Defaults.svd_rrule_tol,
krylovdim=Defaults.svd_rrule_min_krylovdim,
broadening=Defaults.svd_rrule_broadening,
verbosity=Defaults.svd_rrule_verbosity,
rrule_alg...,
) # overwrite with specified kwargs
Expand All @@ -79,23 +92,23 @@ function SVDAdjoint(; fwd_alg=(;), rrule_alg=(;), broadening=nothing)
rrule_type = SVD_RRULE_SYMBOLS[rrule_kwargs.alg]

# IterSVD is incompatible with the full-SVD reverse-rule (`FullSVDReverseRule`) -> default to Arnoldi
if rrule_type <: Nothing && fwd_algorithm isa IterSVD
if rrule_type <: FullSVDReverseRule && fwd_algorithm isa IterSVD
rrule_type = Arnoldi
end

if rrule_type <: Nothing
nothing
if rrule_type <: FullSVDReverseRule
rrule_kwargs = Base.structdiff(rrule_kwargs, (; alg=nothing, tol=0.0, krylovdim=0)) # remove `alg`, `tol` and `krylovdim` keyword arguments
else
rrule_kwargs = Base.structdiff(rrule_kwargs, (; alg=nothing)) # remove `alg` keyword argument
rrule_kwargs = Base.structdiff(rrule_kwargs, (; alg=nothing, broadening=0.0)) # remove `alg` and `broadening` keyword arguments
rrule_type <: BiCGStab &&
(rrule_kwargs = Base.structdiff(rrule_kwargs, (; krylovdim=nothing))) # BiCGStab doens't take `krylovdim`
rrule_type(; rrule_kwargs...)
end
rrule_type(; rrule_kwargs...)
else
rrule_alg
end

return SVDAdjoint(fwd_algorithm, rrule_algorithm, broadening)
return SVDAdjoint(fwd_algorithm, rrule_algorithm)
end

"""
Expand Down Expand Up @@ -245,7 +258,7 @@ end
function TensorKit._compute_svddata!(
f, alg::IterSVD, trunc::Union{NoTruncation,TruncationSpace}
)
InnerProductStyle(f) === EuclideanInnerProduct() || throw_invalid_innerproduct(:tsvd!)
InnerProductStyle(f) === EuclideanInnerProduct() || throw_invalid_innerproduct(:full!)
I = sectortype(f)
dims = SectorDict{I,Int}()

Expand Down Expand Up @@ -285,10 +298,10 @@ end
function ChainRulesCore.rrule(
::typeof(PEPSKit.tsvd!),
t::AbstractTensorMap,
alg::SVDAdjoint{F,R,B};
alg::SVDAdjoint{F,R};
trunc::TruncationScheme=TensorKit.NoTruncation(),
p::Real=2,
) where {F,R<:Nothing,B}
) where {F,R<:FullSVDReverseRule}
@assert !(alg.fwd_alg isa IterSVD) "IterSVD is not compatible with tsvd reverse-rule"
Ũ, S̃, Ṽ⁺, info = tsvd(t, alg; trunc, p)
U, S, V⁺ = info.U_full, info.S_full, info.V_full # untruncated SVD decomposition
Expand All @@ -306,8 +319,17 @@ function ChainRulesCore.rrule(
ΔUc, ΔSc, ΔV⁺c = block(ΔU, c), block(ΔS, c), block(ΔV⁺, c)
Sdc = view(Sc, diagind(Sc))
ΔSdc = (ΔSc isa AbstractZero) ? ΔSc : view(ΔSc, diagind(ΔSc))
TensorKitCRCExt.svd_pullback!(
b, Uc, Sdc, V⁺c, ΔUc, ΔSdc, ΔV⁺c; tol=pullback_tol
svd_pullback!(
b,
Uc,
Sdc,
V⁺c,
ΔUc,
ΔSdc,
ΔV⁺c;
tol=pullback_tol,
broadening=alg.rrule_alg.broadening,
verbosity=alg.rrule_alg.verbosity,
)
end
return NoTangent(), Δt, NoTangent()
Expand All @@ -323,10 +345,10 @@ end
function ChainRulesCore.rrule(
::typeof(PEPSKit.tsvd!),
f,
alg::SVDAdjoint{F,R,B};
alg::SVDAdjoint{F,R};
trunc::TruncationScheme=notrunc(),
p::Real=2,
) where {F,R<:Union{GMRES,BiCGStab,Arnoldi},B}
) where {F,R<:Union{GMRES,BiCGStab,Arnoldi}}
U, S, V, info = tsvd(f, alg; trunc, p)

# update rrule_alg tolerance to be compatible with smallest singular value
Expand Down Expand Up @@ -389,3 +411,177 @@ function ChainRulesCore.rrule(

return (U, S, V, info), tsvd!_itersvd_pullback
end

# Scalar inverse with a cutoff tolerance and optional Lorentzian broadening:
#   * |x| < tol  -> zero of the same type as `x` (the inverse is discarded)
#   * ε == 0     -> plain inverse `inv(x)`
#   * otherwise  -> broadened inverse `_lorentz_broaden(x, ε)`
function _safe_inv(x, tol, ε=0)
    abs(x) < tol && return zero(x)
    iszero(ε) && return inv(x)
    return _lorentz_broaden(x, ε)
end

# Lorentzian-broadened replacement for `1 / x`, used for the divergent term in the SVD
# reverse-rule, see
# https://journals.aps.org/prresearch/abstract/10.1103/PhysRevResearch.7.013237
# Returns `x / (x^2 + ε)`, which stays finite as `x -> 0` for `ε > 0`.
function _lorentz_broaden(x, ε=eps(real(scalartype(x)))^(3 / 4))
    denominator = x^2 + ε
    return x / denominator
end

# Default tolerance used by the SVD pullback: `eps^(3/4)` of the scalar type of the
# infinity norm of `x`, scaled by that norm whenever it exceeds one.
function _default_pullback_gaugetol(x)
    maxabs = norm(x, Inf)
    scale = max(maxabs, one(maxabs))
    return eps(eltype(maxabs))^(3 / 4) * scale
end

# svd_pullback!: pullback implementation for general (possibly truncated) SVD
#
# This is a modified version of TensorKit's pullback
# https://github.com/Jutho/TensorKit.jl/blob/fa1551472ac74d7f2a61bdb2135cf418c8c53378/ext/TensorKitChainRulesCoreExt/factorizations.jl#L190
# with support for Lorentzian broadening and improved verbosity control
#
# Arguments are U, S and Vd of full (non-truncated, but still thin) SVD, as well as
# cotangent ΔU, ΔS, ΔVd variables of truncated SVD. The result is written into ΔA
# (overwriting its contents) and ΔA is returned.
#
# Keyword arguments:
#   tol        : cutoff used for the numerical rank of S, for detecting (pseudo) degenerate
#                singular values and for discarding inverses in `_safe_inv`
#   broadening : Lorentzian broadening passed to `_safe_inv` for the singular value
#                *difference* terms only (the sum terms are never broadened)
#   verbosity  : 1 -> warn if the cotangents couple to gauge-dependent degrees of freedom,
#                ≥2 -> always log the gauge dependency, otherwise silent
#
# Throws `DimensionMismatch` if the sizes of U, S, Vd and the cotangents are inconsistent.
#
# Checks whether the cotangent variables are such that they would couple to gauge-dependent
# degrees of freedom (phases of singular vectors), and reports this according to `verbosity`
#
# An implementation that only uses U, S, and Vd from truncated SVD is also possible, but
# requires solving a Sylvester equation, which does not seem to be supported on GPUs.
#
# Other implementation considerations for GPU compatibility:
# no scalar indexing, lots of broadcasting and views
#
function svd_pullback!(
    ΔA::AbstractMatrix,
    U::AbstractMatrix,
    S::AbstractVector,
    Vd::AbstractMatrix,
    ΔU,
    ΔS,
    ΔVd;
    tol::Real=_default_pullback_gaugetol(S),
    broadening::Real=0,
    verbosity=1,
)

    # Basic size checks and determination
    m, n = size(U, 1), size(Vd, 2)
    size(U, 2) == size(Vd, 1) == length(S) == min(m, n) || throw(DimensionMismatch())
    p = -1 # number of retained singular values, inferred from the first nonzero cotangent
    if !(ΔU isa AbstractZero)
        m == size(ΔU, 1) || throw(DimensionMismatch())
        p = size(ΔU, 2)
    end
    if !(ΔVd isa AbstractZero)
        n == size(ΔVd, 2) || throw(DimensionMismatch())
        if p == -1
            p = size(ΔVd, 1)
        else
            p == size(ΔVd, 1) || throw(DimensionMismatch())
        end
    end
    if !(ΔS isa AbstractZero)
        if p == -1
            p = length(ΔS)
        else
            p == length(ΔS) || throw(DimensionMismatch())
        end
    end

    # All cotangents vanish: the pullback is identically zero. Return early, since `p` is
    # still -1 here and allocating `(p, p)` work arrays below would throw.
    if p == -1
        return fill!(ΔA, 0)
    end

    Up = view(U, :, 1:p)
    Vp = view(Vd, 1:p, :)'
    Sp = view(S, 1:p)

    # rank (the reverse-order search requires S to be sorted in descending order)
    r = searchsortedlast(S, tol; rev=true)

    # compute antihermitian part of projection of ΔU and ΔV onto U and V
    # also already subtract this projection from ΔU and ΔV
    if !(ΔU isa AbstractZero)
        UΔU = Up' * ΔU
        aUΔU = rmul!(UΔU - UΔU', 1 / 2)
        if m > p
            ΔU -= Up * UΔU
        end
    else
        aUΔU = fill!(similar(U, (p, p)), 0)
    end
    if !(ΔVd isa AbstractZero)
        VΔV = Vp' * ΔVd'
        aVΔV = rmul!(VΔV - VΔV', 1 / 2)
        if n > p
            ΔVd -= VΔV' * Vp'
        end
    else
        aVΔV = fill!(similar(Vd, (p, p)), 0)
    end

    # check whether cotangents arise from gauge-invariance objective function
    mask = abs.(Sp' .- Sp) .< tol
    Δgauge = norm(view(aUΔU, mask) + view(aVΔV, mask), Inf)
    if p > r
        rprange = (r + 1):p
        Δgauge = max(Δgauge, norm(view(aUΔU, rprange, rprange), Inf))
        Δgauge = max(Δgauge, norm(view(aVΔV, rprange, rprange), Inf))
    end
    if verbosity == 1 && Δgauge > tol # warn if verbosity is 1
        @warn "`svd` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
    elseif verbosity ≥ 2 # always info for debugging purposes
        @info "`svd` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
    end

    # only the singular value differences can (nearly) vanish, so only they are broadened
    UdΔAV =
        (aUΔU .+ aVΔV) .* _safe_inv.(Sp' .- Sp, tol, broadening) .+
        (aUΔU .- aVΔV) .* _safe_inv.(Sp' .+ Sp, tol)
    if !(ΔS isa AbstractZero) # same zero-cotangent test as used for ΔS above
        UdΔAV[diagind(UdΔAV)] .+= real.(ΔS)
        # in principle, ΔS is real, but maybe not if coming from an anyonic tensor
    end
    mul!(ΔA, Up, UdΔAV * Vp') # overwrites ΔA; all later contributions are accumulated

    if r > p # contribution from truncation
        Ur = view(U, :, (p + 1):r)
        Vr = view(Vd, (p + 1):r, :)'
        Sr = view(S, (p + 1):r)

        if !(ΔU isa AbstractZero)
            UrΔU = Ur' * ΔU
            if m > r
                ΔU -= Ur * UrΔU # subtract this part from ΔU
            end
        else
            UrΔU = fill!(similar(U, (r - p, p)), 0)
        end
        if !(ΔVd isa AbstractZero)
            VrΔV = Vr' * ΔVd'
            if n > r
                ΔVd -= VrΔV' * Vr' # subtract this part from ΔV
            end
        else
            VrΔV = fill!(similar(Vd, (r - p, p)), 0)
        end

        X =
            (1//2) .* (
                (UrΔU .+ VrΔV) .* _safe_inv.(Sp' .- Sr, tol, broadening) .+
                (UrΔU .- VrΔV) .* _safe_inv.(Sp' .+ Sr, tol)
            )
        Y =
            (1//2) .* (
                (UrΔU .+ VrΔV) .* _safe_inv.(Sp' .- Sr, tol, broadening) .-
                (UrΔU .- VrΔV) .* _safe_inv.(Sp' .+ Sr, tol)
            )

        # ΔA += Ur * X * Vp' + Up * Y' * Vr'
        mul!(ΔA, Ur, X * Vp', 1, 1)
        mul!(ΔA, Up * Y', Vr', 1, 1)
    end

    if m > max(r, p) && !(ΔU isa AbstractZero) # remaining ΔU is already orthogonal to U[:,1:max(p,r)]
        # ΔA += (ΔU .* _safe_inv.(Sp', tol)) * Vp'
        mul!(ΔA, ΔU .* _safe_inv.(Sp', tol), Vp', 1, 1)
    end
    if n > max(r, p) && !(ΔVd isa AbstractZero) # remaining ΔV is already orthogonal to V[:,1:max(p,r)]
        # ΔA += U * (_safe_inv.(Sp, tol) .* ΔVd)
        mul!(ΔA, Up, _safe_inv.(Sp, tol) .* ΔVd, 1, 1)
    end
    return ΔA
end
Loading