Skip to content

Commit c3f1429

Browse files
committed
Add FullReverse (:full) algorithm for modified TensorKit rrule
1 parent ec53668 commit c3f1429

1 file changed

Lines changed: 46 additions & 24 deletions

File tree

src/utility/svd.jl

Lines changed: 46 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,29 @@ using TensorKit:
1111
_compute_truncerr
1212
const KrylovKitCRCExt = Base.get_extension(KrylovKit, :KrylovKitChainRulesCoreExt)
1313

14+
"""
15+
struct FullReverse
16+
FullReverse(; kwargs...)
17+
18+
SVD reverse-rule algorithm which uses a modified version of TensorKit's `tsvd!` reverse-rule
19+
allowing for Lorentzian broadening and output verbosity control.
20+
21+
## Keyword arguments
22+
23+
* `broadening::Float64=$(Defaults.svd_rrule_broadening)`: Lorentzian broadening amplitude for smoothing divergent term in SVD derivative in case of (pseudo) degenerate singular values.
24+
* `verbosity::Int=0`: Suppresses all output if `≤0`, prints gauge dependency warnings if `1`, and always prints the gauge dependency if `≥2`.
25+
"""
26+
@kwdef struct FullReverse
27+
broadening::Float64 = Defaults.svd_rrule_broadening
28+
verbosity::Int = 0
29+
end
30+
1431
"""
1532
struct SVDAdjoint
1633
SVDAdjoint(; kwargs...)
1734
1835
Wrapper for a SVD algorithm `fwd_alg` with a defined reverse rule `rrule_alg`.
1936
If `isnothing(rrule_alg)`, Zygote differentiates the forward call automatically.
20-
In case of degenerate singular values, one might need a `broadening` scheme which
21-
removes the divergences from the adjoint.
2237
2338
## Keyword arguments
2439
@@ -27,16 +42,14 @@ removes the divergences from the adjoint.
2742
- `:svd`: TensorKit's wrapper for LAPACK's `_gesvd`
2843
- `:iterative`: Iterative SVD only computing the specified number of singular values and vectors, see [`IterSVD`](@ref)
2944
* `rrule_alg::Union{Algorithm,NamedTuple}=(; alg::Symbol=$(Defaults.svd_rrule_alg))`: Reverse-rule algorithm for differentiating the SVD. Can be supplied by an `Algorithm` instance directly or as a `NamedTuple` where `alg` is one of the following:
30-
- `:tsvd`: Uses a modified version of TensorKit's reverse-rule for `tsvd` which doesn't solve any linear problem and instead requires access to the full SVD.
45+
- `:full`: Uses a modified version of TensorKit's reverse-rule for `tsvd` which doesn't solve any linear problem and instead requires access to the full SVD.
3146
- `:gmres`: GMRES iterative linear solver, see the [KrylovKit docs](https://jutho.github.io/KrylovKit.jl/stable/man/algorithms/#KrylovKit.GMRES) for details
3247
- `:bicgstab`: BiCGStab iterative linear solver, see the [KrylovKit docs](https://jutho.github.io/KrylovKit.jl/stable/man/algorithms/#KrylovKit.BiCGStab) for details
3348
- `:arnoldi`: Arnoldi Krylov algorithm, see the [KrylovKit docs](https://jutho.github.io/KrylovKit.jl/stable/man/algorithms/#KrylovKit.Arnoldi) for details
34-
* `broadening=$(Defaults.svd_rrule_broadening)`: Lorentzian broadening of singular value differences to stabilize the SVD gradient in case of degeneracies. Currently only implemented for `rrule_alg=:tsvd`.
3549
"""
36-
struct SVDAdjoint{F,R,B}
50+
struct SVDAdjoint{F,R}
3751
fwd_alg::F
3852
rrule_alg::R
39-
broadening::B
4053
end # Keep truncation algorithm separate to be able to specify CTMRG dependent information
4154

4255
const SVD_FWD_SYMBOLS = IdDict{Symbol,Any}(
@@ -47,10 +60,10 @@ const SVD_FWD_SYMBOLS = IdDict{Symbol,Any}(
4760
IterSVD(; alg=GKL(; tol, krylovdim), kwargs...),
4861
)
4962
const SVD_RRULE_SYMBOLS = IdDict{Symbol,Type{<:Any}}(
50-
:tsvd => Nothing, :gmres => GMRES, :bicgstab => BiCGStab, :arnoldi => Arnoldi
63+
:full => FullReverse, :gmres => GMRES, :bicgstab => BiCGStab, :arnoldi => Arnoldi
5164
)
5265

53-
function SVDAdjoint(; fwd_alg=(;), rrule_alg=(;), broadening=nothing)
66+
function SVDAdjoint(; fwd_alg=(;), rrule_alg=(;))
5467
# parse forward SVD algorithm
5568
fwd_algorithm = if fwd_alg isa NamedTuple
5669
fwd_kwargs = (; alg=Defaults.svd_fwd_alg, fwd_alg...) # overwrite with specified kwargs
@@ -69,6 +82,7 @@ function SVDAdjoint(; fwd_alg=(;), rrule_alg=(;), broadening=nothing)
6982
alg=Defaults.svd_rrule_alg,
7083
tol=Defaults.svd_rrule_tol,
7184
krylovdim=Defaults.svd_rrule_min_krylovdim,
85+
broadening=Defaults.svd_rrule_broadening,
7286
verbosity=Defaults.svd_rrule_verbosity,
7387
rrule_alg...,
7488
) # overwrite with specified kwargs
@@ -78,24 +92,23 @@ function SVDAdjoint(; fwd_alg=(;), rrule_alg=(;), broadening=nothing)
7892
rrule_type = SVD_RRULE_SYMBOLS[rrule_kwargs.alg]
7993

8094
# IterSVD is incompatible with tsvd rrule -> default to Arnoldi
81-
if rrule_type <: Nothing && fwd_algorithm isa IterSVD
95+
if rrule_type <: FullReverse && fwd_algorithm isa IterSVD
8296
rrule_type = Arnoldi
8397
end
8498

85-
if rrule_type <: Nothing
86-
broadening = isnothing(broadening) ? Defaults.svd_rrule_broadening : broadening
87-
nothing
99+
if rrule_type <: FullReverse
100+
rrule_kwargs = Base.structdiff(rrule_kwargs, (; alg=nothing, tol=0.0, krylovdim=0)) # remove `alg`, `tol` and `krylovdim` keyword arguments
88101
else
89-
rrule_kwargs = Base.structdiff(rrule_kwargs, (; alg=nothing)) # remove `alg` keyword argument
102+
rrule_kwargs = Base.structdiff(rrule_kwargs, (; alg=nothing, broadening=0.0)) # remove `alg` and `broadening` keyword arguments
90103
rrule_type <: BiCGStab &&
91104
(rrule_kwargs = Base.structdiff(rrule_kwargs, (; krylovdim=nothing))) # BiCGStab doesn't take `krylovdim`
92-
rrule_type(; rrule_kwargs...)
93105
end
106+
rrule_type(; rrule_kwargs...)
94107
else
95108
rrule_alg
96109
end
97110

98-
return SVDAdjoint(fwd_algorithm, rrule_algorithm, broadening)
111+
return SVDAdjoint(fwd_algorithm, rrule_algorithm)
99112
end
100113

101114
"""
@@ -245,7 +258,7 @@ end
245258
function TensorKit._compute_svddata!(
246259
f, alg::IterSVD, trunc::Union{NoTruncation,TruncationSpace}
247260
)
248-
InnerProductStyle(f) === EuclideanInnerProduct() || throw_invalid_innerproduct(:tsvd!)
261+
InnerProductStyle(f) === EuclideanInnerProduct() || throw_invalid_innerproduct(:full!)
249262
I = sectortype(f)
250263
dims = SectorDict{I,Int}()
251264

@@ -285,10 +298,10 @@ end
285298
function ChainRulesCore.rrule(
286299
::typeof(PEPSKit.tsvd!),
287300
t::AbstractTensorMap,
288-
alg::SVDAdjoint{F,R,B};
301+
alg::SVDAdjoint{F,R};
289302
trunc::TruncationScheme=TensorKit.NoTruncation(),
290303
p::Real=2,
291-
) where {F,R<:Nothing,B}
304+
) where {F,R<:FullReverse}
292305
@assert !(alg.fwd_alg isa IterSVD) "IterSVD is not compatible with tsvd reverse-rule"
293306
Ũ, S̃, Ṽ⁺, info = tsvd(t, alg; trunc, p)
294307
U, S, V⁺ = info.U_full, info.S_full, info.V_full # untruncated SVD decomposition
@@ -315,7 +328,8 @@ function ChainRulesCore.rrule(
315328
ΔSdc,
316329
ΔV⁺c;
317330
tol=pullback_tol,
318-
broadening=alg.broadening,
331+
broadening=alg.rrule_alg.broadening,
332+
verbosity=alg.rrule_alg.verbosity,
319333
)
320334
end
321335
return NoTangent(), Δt, NoTangent()
@@ -331,10 +345,10 @@ end
331345
function ChainRulesCore.rrule(
332346
::typeof(PEPSKit.tsvd!),
333347
f,
334-
alg::SVDAdjoint{F,R,B};
348+
alg::SVDAdjoint{F,R};
335349
trunc::TruncationScheme=notrunc(),
336350
p::Real=2,
337-
) where {F,R<:Union{GMRES,BiCGStab,Arnoldi},B}
351+
) where {F,R<:Union{GMRES,BiCGStab,Arnoldi}}
338352
U, S, V, info = tsvd(f, alg; trunc, p)
339353

340354
# update rrule_alg tolerance to be compatible with smallest singular value
@@ -413,6 +427,11 @@ function _lorentz_broaden(x, ε=1e-13)
413427
return x / (x^2 + ε)
414428
end
415429

430+
function _default_pullback_gaugetol(x)
431+
n = norm(x, Inf)
432+
return eps(eltype(n))^(3 / 4) * max(n, one(n))
433+
end
434+
416435
# SVD_pullback: pullback implementation for general (possibly truncated) SVD
417436
#
418437
# This is a modified version of TensorKit's pullback
@@ -439,9 +458,9 @@ function svd_pullback!(
439458
ΔU,
440459
ΔS,
441460
ΔVd;
442-
tol::Real=default_pullback_gaugetol(S),
461+
tol::Real=_default_pullback_gaugetol(S),
443462
broadening::Real=0,
444-
suppress_gauge_warning=false,
463+
verbosity=1,
445464
)
446465

447466
# Basic size checks and determination
@@ -503,8 +522,11 @@ function svd_pullback!(
503522
Δgauge = max(Δgauge, norm(view(aUΔU, rprange, rprange), Inf))
504523
Δgauge = max(Δgauge, norm(view(aVΔV, rprange, rprange), Inf))
505524
end
506-
(!suppress_gauge_warning && Δgauge < tol) ||
525+
if verbosity == 1 && Δgauge < tol # warn if verbosity is 1
507526
@warn "`svd` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
527+
elseif verbosity ≥ 2 # always info for debugging purposes
528+
@info "`svd` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
529+
end
508530

509531
UdΔAV =
510532
(aUΔU .+ aVΔV) .* _safe_inv.(Sp' .- Sp, tol, broadening) .+

0 commit comments

Comments
 (0)