@@ -15,22 +15,23 @@ The truncation algorithm can be constructed from the following keyword arguments
1515
1616* `trunc::TruncationStrategy`: SVD truncation strategy when initilizing the truncated tensors connected by the bond.
1717* `maxiter::Int=50` : Maximal number of ALS iterations.
18- * `tol::Float64=1e-15 ` : ALS converges when fidelity change between two iterations is smaller than `tol`.
18+ * `tol::Float64=1e-9 ` : ALS converges when the relative change in bond SVD spectrum between two iterations is smaller than `tol`.
1919* `check_interval::Int=0` : Set number of iterations to print information. Output is suppressed when `check_interval <= 0`.
2020"""
2121@kwdef struct ALSTruncation
2222 trunc:: TruncationStrategy
2323 maxiter:: Int = 50
24- tol:: Float64 = 1.0e-15
24+ tol:: Float64 = 1.0e-9
2525 check_interval:: Int = 0
2626end
2727
2828function _als_message (
29- iter:: Int , cost:: Float64 , fid:: Float64 , Δcost:: Float64 , Δfid:: Float64 , time_elapsed:: Float64 ,
29+ iter:: Int , cost:: Float64 , fid:: Float64 , Δcost:: Float64 ,
30+ Δfid:: Float64 , Δs:: Float64 , time_elapsed:: Float64 ,
3031 )
3132 return @sprintf (
3233 " %5d, fid = %.8e, Δfid = %.8e, time = %.4f s\n " , iter, fid, Δfid, time_elapsed
33- ) * @sprintf (" cost = %.3e, Δcost/cost0 = %.3e" , cost, Δcost)
34+ ) * @sprintf (" cost = %.3e, Δcost/cost0 = %.3e, |Δs| = %.4e. " , cost, Δcost, Δs )
3435end
3536
3637"""
@@ -65,14 +66,14 @@ function bond_truncate(
6566 @assert ! isdual (space (a, 2 ))
6667 @assert ! isdual (space (b, 2 ))
6768 @assert codomain (benv) == domain (benv)
69+ need_flip = isdual (space (b, 1 ))
6870 time00 = time ()
6971 verbose = (alg. check_interval > 0 )
7072 a2b2 = _combine_ab (a, b)
7173 # initialize truncated a, b
7274 perm_ab = ((1 , 3 ), (4 , 2 ))
73- a, s, b = svd_trunc (permute (a2b2, perm_ab); trunc = alg. trunc)
74- s /= norm (s, Inf )
75- a, b = absorb_s (a, s, b)
75+ a, s0, b = svd_trunc (permute (a2b2, perm_ab); trunc = alg. trunc)
76+ a, b = absorb_s (a, s0, b)
7677 #= temporarily reorder axes of a and b to
7778 1 -a/b- 2
7879 ↓
@@ -84,8 +85,8 @@ function bond_truncate(
8485 # cost function will be normalized by initial value
8586 cost00 = cost_function_als (benv, ab, a2b2)
8687 fid = fidelity (benv, ab, a2b2)
87- cost0, fid0, Δfid = cost00, fid, 0.0
88- verbose && @info " ALS init" * _als_message (0 , cost0, fid, NaN , NaN , 0.0 )
88+ cost0, fid0, Δcost, Δfid, Δs = cost00, fid, NaN , NaN , NaN
89+ verbose && @info " ALS init" * _als_message (0 , cost0, fid, Δcost, Δfid, Δs , 0.0 )
8990 for iter in 1 : (alg. maxiter)
9091 time0 = time ()
9192 #=
@@ -103,20 +104,27 @@ function bond_truncate(
103104 Rb = _tensor_Rb (benv, a)
104105 Sb = _tensor_Sb (benv, a, a2b2)
105106 b, info_b = _solve_ab (Rb, Sb, b)
107+ @debug " Bond truncation info" info_a info_b
106108 ab = _combine_ab (a, b)
107109 cost = cost_function_als (benv, ab, a2b2)
108110 fid = fidelity (benv, ab, a2b2)
111+ # TODO : replace with truncated svdvals (without calculating u, vh)
112+ _, s, _ = svd_trunc! (permute (ab, perm_ab); trunc = alg. trunc)
113+ # fidelity, cost and normalized bond-s change
114+ s_nrm = norm (s0, Inf )
115+ Δs = ((space (s) == space (s0)) ? _singular_value_distance ((s, s0)) : NaN ) / s_nrm
109116 Δcost = abs (cost - cost0) / cost00
110117 Δfid = abs (fid - fid0)
111- cost0, fid0 = cost, fid
118+ cost0, fid0, s0 = cost, fid, s
112119 time1 = time ()
113- converge = (Δfid < alg. tol)
120+ converge = (Δs < alg. tol)
114121 cancel = (iter == alg. maxiter)
115122 showinfo =
116123 cancel || (verbose && (converge || iter == 1 || iter % alg. check_interval == 0 ))
117124 if showinfo
118125 message = _als_message (
119- iter, cost, fid, Δcost, Δfid, time1 - ((cancel || converge) ? time00 : time0),
126+ iter, cost, fid, Δcost, Δfid, Δs,
127+ time1 - ((cancel || converge) ? time00 : time0),
120128 )
121129 if converge
122130 @info " ALS conv" * message
@@ -129,9 +137,11 @@ function bond_truncate(
129137 converge && break
130138 end
131139 a, s, b = svd_trunc! (permute (_combine_ab (a, b), perm_ab); trunc = alg. trunc)
132- # normalize singular value spectrum
133- s /= norm (s, Inf )
134- return a, s, b, (; fid, Δfid)
140+ a, b = absorb_s (a, s, b)
141+ if need_flip
142+ a, s, b = flip_svd (a, s, b)
143+ end
144+ return a, s, b, (; fid, Δfid, Δs)
135145end
136146
137147function bond_truncate (
@@ -144,18 +154,15 @@ function bond_truncate(
144154 @assert ! isdual (space (a, 2 ))
145155 @assert ! isdual (space (b, 2 ))
146156 @assert codomain (benv) == domain (benv)
157+ need_flip = isdual (space (b, 1 ))
147158 #= initialize bond matrix using QR as `Ra Lb`
148159
149- --- a == b --- ==> - Qa - Ra == Rb - Qb -
160+ --- a == b --- ==> - Qa ← Ra == Rb ← Qb -
150161 ↓ ↓ ↓ ↓
151162 =#
152163 Qa, Ra = left_orth (a)
153164 Rb, Qb = right_orth (b)
154- # if Qa → Ra, a twist is needed to express a as
155- # contraction of Rb, Qb instead of Qa * Ra
156- isdual (space (Ra, 1 )) && twist! (Ra, 1 )
157- # similarly if Rb → Qb
158- isdual (space (Qb, 1 )) && twist! (Rb, 2 )
165+ @assert ! isdual (space (Ra, 1 )) && ! isdual (space (Qb, 1 ))
159166 @tensor b0[- 1 ; - 2 ] := Ra[- 1 1 ] * Rb[1 - 2 ]
160167 #= initialize bond environment around `Ra Lb`
161168
@@ -174,8 +181,12 @@ function bond_truncate(
174181 )
175182 # optimize bond matrix
176183 u, s, vh, info = fullenv_truncate (b0, benv2, alg)
184+ u, vh = absorb_s (u, s, vh)
177185 # truncate a, b tensors with u, s, vh
178186 @tensor a[- 1 - 2 ; - 3 ] := Qa[- 1 - 2 3 ] * u[3 - 3 ]
179187 @tensor b[- 1 ; - 2 - 3 ] := vh[- 1 1 ] * Qb[1 - 2 - 3 ]
188+ if need_flip
189+ a, s, b = flip_svd (a, s, b)
190+ end
180191 return a, s, b, info
181192end
0 commit comments