Skip to content

Commit 9d9f163

Browse files
Improve bond truncation algorithms (#303)
* Change bond truncation `tol` to normalized bond SVD diff * Remove bond s normalization in FET * Fix formatting * Reduce repeated normalization of `s` when calculating `Δs` * Update docstring
1 parent e3f06e1 commit 9d9f163

3 files changed

Lines changed: 40 additions & 32 deletions

File tree

src/algorithms/truncation/bond_truncation.jl

Lines changed: 32 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -15,22 +15,23 @@ The truncation algorithm can be constructed from the following keyword arguments
1515
1616
* `trunc::TruncationStrategy`: SVD truncation strategy when initilizing the truncated tensors connected by the bond.
1717
* `maxiter::Int=50` : Maximal number of ALS iterations.
18-
* `tol::Float64=1e-15` : ALS converges when fidelity change between two iterations is smaller than `tol`.
18+
* `tol::Float64=1e-9` : ALS converges when the relative change in bond SVD spectrum between two iterations is smaller than `tol`.
1919
* `check_interval::Int=0` : Set number of iterations to print information. Output is suppressed when `check_interval <= 0`.
2020
"""
2121
@kwdef struct ALSTruncation
2222
trunc::TruncationStrategy
2323
maxiter::Int = 50
24-
tol::Float64 = 1.0e-15
24+
tol::Float64 = 1.0e-9
2525
check_interval::Int = 0
2626
end
2727

2828
function _als_message(
29-
iter::Int, cost::Float64, fid::Float64, Δcost::Float64, Δfid::Float64, time_elapsed::Float64,
29+
iter::Int, cost::Float64, fid::Float64, Δcost::Float64,
30+
Δfid::Float64, Δs::Float64, time_elapsed::Float64,
3031
)
3132
return @sprintf(
3233
"%5d, fid = %.8e, Δfid = %.8e, time = %.4f s\n", iter, fid, Δfid, time_elapsed
33-
) * @sprintf(" cost = %.3e, Δcost/cost0 = %.3e", cost, Δcost)
34+
) * @sprintf(" cost = %.3e, Δcost/cost0 = %.3e, |Δs| = %.4e.", cost, Δcost, Δs)
3435
end
3536

3637
"""
@@ -65,14 +66,14 @@ function bond_truncate(
6566
@assert !isdual(space(a, 2))
6667
@assert !isdual(space(b, 2))
6768
@assert codomain(benv) == domain(benv)
69+
need_flip = isdual(space(b, 1))
6870
time00 = time()
6971
verbose = (alg.check_interval > 0)
7072
a2b2 = _combine_ab(a, b)
7173
# initialize truncated a, b
7274
perm_ab = ((1, 3), (4, 2))
73-
a, s, b = svd_trunc(permute(a2b2, perm_ab); trunc = alg.trunc)
74-
s /= norm(s, Inf)
75-
a, b = absorb_s(a, s, b)
75+
a, s0, b = svd_trunc(permute(a2b2, perm_ab); trunc = alg.trunc)
76+
a, b = absorb_s(a, s0, b)
7677
#= temporarily reorder axes of a and b to
7778
1 -a/b- 2
7879
@@ -84,8 +85,8 @@ function bond_truncate(
8485
# cost function will be normalized by initial value
8586
cost00 = cost_function_als(benv, ab, a2b2)
8687
fid = fidelity(benv, ab, a2b2)
87-
cost0, fid0, Δfid = cost00, fid, 0.0
88-
verbose && @info "ALS init" * _als_message(0, cost0, fid, NaN, NaN, 0.0)
88+
cost0, fid0, Δcost, Δfid, Δs = cost00, fid, NaN, NaN, NaN
89+
verbose && @info "ALS init" * _als_message(0, cost0, fid, Δcost, Δfid, Δs, 0.0)
8990
for iter in 1:(alg.maxiter)
9091
time0 = time()
9192
#=
@@ -103,20 +104,27 @@ function bond_truncate(
103104
Rb = _tensor_Rb(benv, a)
104105
Sb = _tensor_Sb(benv, a, a2b2)
105106
b, info_b = _solve_ab(Rb, Sb, b)
107+
@debug "Bond truncation info" info_a info_b
106108
ab = _combine_ab(a, b)
107109
cost = cost_function_als(benv, ab, a2b2)
108110
fid = fidelity(benv, ab, a2b2)
111+
# TODO: replace with truncated svdvals (without calculating u, vh)
112+
_, s, _ = svd_trunc!(permute(ab, perm_ab); trunc = alg.trunc)
113+
# fidelity, cost and normalized bond-s change
114+
s_nrm = norm(s0, Inf)
115+
Δs = ((space(s) == space(s0)) ? _singular_value_distance((s, s0)) : NaN) / s_nrm
109116
Δcost = abs(cost - cost0) / cost00
110117
Δfid = abs(fid - fid0)
111-
cost0, fid0 = cost, fid
118+
cost0, fid0, s0 = cost, fid, s
112119
time1 = time()
113-
converge = (Δfid < alg.tol)
120+
converge = (Δs < alg.tol)
114121
cancel = (iter == alg.maxiter)
115122
showinfo =
116123
cancel || (verbose && (converge || iter == 1 || iter % alg.check_interval == 0))
117124
if showinfo
118125
message = _als_message(
119-
iter, cost, fid, Δcost, Δfid, time1 - ((cancel || converge) ? time00 : time0),
126+
iter, cost, fid, Δcost, Δfid, Δs,
127+
time1 - ((cancel || converge) ? time00 : time0),
120128
)
121129
if converge
122130
@info "ALS conv" * message
@@ -129,9 +137,11 @@ function bond_truncate(
129137
converge && break
130138
end
131139
a, s, b = svd_trunc!(permute(_combine_ab(a, b), perm_ab); trunc = alg.trunc)
132-
# normalize singular value spectrum
133-
s /= norm(s, Inf)
134-
return a, s, b, (; fid, Δfid)
140+
a, b = absorb_s(a, s, b)
141+
if need_flip
142+
a, s, b = flip_svd(a, s, b)
143+
end
144+
return a, s, b, (; fid, Δfid, Δs)
135145
end
136146

137147
function bond_truncate(
@@ -144,18 +154,15 @@ function bond_truncate(
144154
@assert !isdual(space(a, 2))
145155
@assert !isdual(space(b, 2))
146156
@assert codomain(benv) == domain(benv)
157+
need_flip = isdual(space(b, 1))
147158
#= initialize bond matrix using QR as `Ra Lb`
148159
149-
--- a == b --- ==> - Qa - Ra == Rb - Qb -
160+
--- a == b --- ==> - Qa Ra == Rb Qb -
150161
↓ ↓ ↓ ↓
151162
=#
152163
Qa, Ra = left_orth(a)
153164
Rb, Qb = right_orth(b)
154-
# if Qa → Ra, a twist is needed to express a as
155-
# contraction of Rb, Qb instead of Qa * Ra
156-
isdual(space(Ra, 1)) && twist!(Ra, 1)
157-
# similarly if Rb → Qb
158-
isdual(space(Qb, 1)) && twist!(Rb, 2)
165+
@assert !isdual(space(Ra, 1)) && !isdual(space(Qb, 1))
159166
@tensor b0[-1; -2] := Ra[-1 1] * Rb[1 -2]
160167
#= initialize bond environment around `Ra Lb`
161168
@@ -174,8 +181,12 @@ function bond_truncate(
174181
)
175182
# optimize bond matrix
176183
u, s, vh, info = fullenv_truncate(b0, benv2, alg)
184+
u, vh = absorb_s(u, s, vh)
177185
# truncate a, b tensors with u, s, vh
178186
@tensor a[-1 -2; -3] := Qa[-1 -2 3] * u[3 -3]
179187
@tensor b[-1; -2 -3] := vh[-1 1] * Qb[1 -2 -3]
188+
if need_flip
189+
a, s, b = flip_svd(a, s, b)
190+
end
180191
return a, s, b, info
181192
end

src/algorithms/truncation/fullenv_truncation.jl

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ The truncation algorithm can be constructed from the following keyword arguments
1515
1616
* `trunc::TruncationStrategy` : SVD truncation strategy when optimizing the new bond matrix.
1717
* `maxiter::Int=50` : Maximal number of FET iterations.
18-
* `tol::Float64=1e-15` : FET converges when fidelity change between two FET iterations is smaller than `tol`.
18+
* `tol::Float64=1e-9` : FET converges when the relative change in bond SVD spectrum between two FET iterations is smaller than `tol`.
1919
* `trunc_init::Bool=true` : Controls whether the initialization of the new bond matrix is obtained from truncated SVD of the old bond matrix.
2020
* `check_interval::Int=0` : Set number of iterations to print information. Output is suppressed when `check_interval <= 0`.
2121
@@ -26,7 +26,7 @@ The truncation algorithm can be constructed from the following keyword arguments
2626
@kwdef struct FullEnvTruncation
2727
trunc::TruncationStrategy
2828
maxiter::Int = 50
29-
tol::Float64 = 1.0e-15
29+
tol::Float64 = 1.0e-9
3030
trunc_init::Bool = true
3131
check_interval::Int = 0
3232
end
@@ -230,8 +230,6 @@ function fullenv_truncate(
230230
# initialize u, s, vh with truncated or untruncated SVD
231231
u, s, vh = svd_trunc(b0; trunc = (alg.trunc_init ? alg.trunc : notrunc()))
232232
b1 = similar(b0)
233-
# normalize `s` (bond matrices can always be normalized)
234-
s /= norm(s, Inf)
235233
s0 = deepcopy(s)
236234
Δfid, Δs, fid, fid0 = NaN, NaN, 0.0, 0.0
237235
for iter in 1:(alg.maxiter)
@@ -245,25 +243,25 @@ function fullenv_truncate(
245243
r, info_r = linsolve(Base.Fix1(*, B), p, r, 0, 1)
246244
@tensor b1[-1; -2] = u[-1; 1] * r[1 -2]
247245
u, s, vh = svd_trunc(b1; trunc = alg.trunc)
248-
s /= norm(s, Inf)
249246
# update `- l ← = - u ← s ←`
250247
@tensor l[-1 -2] := u[-1; 1] * s[1; -2]
251248
@tensor p[-1 -2] := conj(vh[-2; 2]) * benv[-1 2; 3 4] * b0[3; 4]
252249
@tensor B[-1 -2; -3 -4] := conj(vh[-2; 2]) * benv[-1 2; -3 4] * vh[-4; 4]
253250
_linearmap_twist!(p)
254251
_linearmap_twist!(B)
255252
l, info_l = linsolve(Base.Fix1(*, B), p, l, 0, 1)
253+
@debug "Bond truncation info" info_l info_r
256254
@tensor b1[-1; -2] = l[-1 1] * vh[1; -2]
257255
fid = fidelity(benv, b0, b1)
258256
u, s, vh = svd_trunc!(b1; trunc = alg.trunc)
259-
s /= norm(s, Inf)
260257
# determine convergence
261-
Δs = (space(s) == space(s0)) ? _singular_value_distance((s, s0)) : NaN
258+
s_nrm = norm(s0, Inf)
259+
Δs = ((space(s) == space(s0)) ? _singular_value_distance((s, s0)) : NaN) / s_nrm
262260
Δfid = fid - fid0
263261
s0 = deepcopy(s)
264262
fid0 = fid
265263
time1 = time()
266-
converge = (Δfid < alg.tol)
264+
converge = (Δs < alg.tol)
267265
cancel = (iter == alg.maxiter)
268266
showinfo =
269267
cancel || (verbose && (converge || iter == 1 || iter % alg.check_interval == 0))

test/bondenv/bond_truncate.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ using LinearAlgebra
77
using KrylovKit
88

99
Random.seed!(0)
10-
maxiter = 500
10+
maxiter = 600
1111
check_interval = 20
1212
trunc = truncerror(; atol = 1.0e-10) & truncrank(8)
1313
Vext = Vect[FermionParity](0 => 100, 1 => 100)
@@ -35,8 +35,7 @@ for Vbondl in (Vint, Vint'), Vbondr in (Vint, Vint')
3535
)
3636
a1, ss[label], b1, info = PEPSKit.bond_truncate(a2, b2, benv, alg)
3737
@info "$label improved fidelity = $(info.fid)."
38-
display(ss[label])
39-
a1, b1 = PEPSKit.absorb_s(a1, ss[label], b1)
38+
# display(ss[label])
4039
@test info.fid PEPSKit.fidelity(benv, PEPSKit._combine_ab(a1, b1), a2b2)
4140
@test info.fid > fid0
4241
end

0 commit comments

Comments
 (0)