Skip to content

Commit 91eed29

Browse files
Truncate MPS cluster with standard virtual arrows (#309)
* Truncate clusters with standard virtual arrows * Update tests accordingly * Rename `revs` to `flips` * Minor fixes
1 parent e175131 commit 91eed29

4 files changed

Lines changed: 101 additions & 122 deletions

File tree

src/algorithms/time_evolution/simpleupdate3site.jl

Lines changed: 78 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -1,67 +1,55 @@
11
#=
22
# Mixed canonical form of an open boundary MPS
33
```
4-
|ψ⟩ = M[1]---...---M[N]
4+
|ψ⟩ = M[1]-←-...--M[N]
55
↓ ↓
66
```
7-
The bond between `M[n]` and `M[n+1]` is called
8-
the n-th (internal) bond (n = 1, ..., N - 1).
7+
For convenience, assume all virtual arrows are ←.
98
109
We perform QR and LQ decompositions: starting from
1110
```
12-
M[1]--- = Qa[1]-*-R[1]---
11+
M[1]-←- = Qa[1]--R[1]--
1312
↓ ↓
1413
15-
---M[N] = --L[N-1]-*-Qb[N]
14+
-←-M[N] = --L[N-1]--Qb[N]
1615
↓ ↓
1716
```
1817
we successively calculate
1918
```
20-
---R[n-1]---M[n]--- = ---Qa[n]-*-R[n]---- (n = 2, ..., N - 1)
21-
↓ ↓
19+
-←-R[n-1]-←-M[n]-←- = -←-Qa[n]--R[n]---- (n = 2, ..., N - 1)
20+
2221
23-
--M[n+1]-*-L[n+1]-- = ---L[n]-*-Qb[n+1]-- (n = N - 2, ..., 1)
22+
--M[n+1]--L[n+1]-- = -←-L[n]--Qb[n+1]-- (n = N - 2, ..., 1)
2423
↓ ↓
2524
```
26-
Here `-*-` on the bond means a twist should be applied if
27-
the codomain of R[n], Qb[n+1], L[n+1] is a dual space.
28-
29-
Here we make the `isdual` of the domain and codomain
30-
of `R[n]` and `L[n]` for a given `n` the same.
3125
3226
For each bond (n = 1, ..., N - 1), we perform SVD
3327
```
34-
R[n] L[n] = U[n]-←-s[n]-←-V†[n] (n = 1, ..., N - 1)
28+
R[n] L[n] = U[n] s[n] V†[n] (n = 1, ..., N - 1)
3529
```
3630
Then we define the projectors together with the Schmidt weight
3731
```
38-
---Pa[n]-←- = L[n] V[n]-←-(1/√s[n])-←-
39-
-←-Pb[n]--- = -←-(1/√s[n])-←-U†[n] R[n]
32+
--Pa[n]-←- = L[n] V[n]-←-(1/√s[n])-←-
33+
-←-Pb[n]-- = -←-(1/√s[n])-←-U†[n] R[n]
4034
```
41-
Since the domain and the codomain of R[n] and L[n] has the same `isdual`,
42-
the product `Pa Pb` is the identity operator:
35+
The product `Pa Pb` is the identity operator:
4336
```
4437
Pa[n]-←-Pb[n] = L[n] (R[n] L[n])⁻¹ R[n] = 1
4538
```
46-
The `isdual` for the domain and codomain of `Pa[n] Pb[n]` are also the same.
47-
48-
Note that when `Pa[n] Pb[n]` is identity on a dual space,
49-
a twist should be applied to put it to the bond.
5039
5140
The canonical form is then defined by
5241
```
53-
-←-M̃[n]-←- = -←-Pb[n-1]---M[n]-*-Pa[n]-←-
42+
-←-M̃[n]-←- = -←-Pb[n-1]-←-M[n]--Pa[n]-←-
5443
↓ ↓
5544
```
56-
`-*-` means a twist should be applied if the codomain of `Pa[n]` is a dual space.
5745
5846
Note that
5947
```
6048
M̃[n]
61-
= 1/√s[n-1]←-U†[n-1](R[n-1]--M[n])-*-L[n] V[n]←-1/√s[n]
62-
= 1/√s[n-1]←-U†[n-1] Qa[n] (R[n]-*-L[n]) V[n]←-1/√s[n]
63-
= 1/√s[n-1]←-U†[n-1] Qa[n] U[n]←-s[n]←-(V†[n] V[n])←-1/√s[n]
64-
= 1/√s[n-1]←-U†[n-1] Qa[n] U[n]←-√s[n]
49+
= 1/√s[n-1] U†[n-1] (R[n-1] M[n]) L[n] V[n] 1/√s[n]
50+
= 1/√s[n-1] U†[n-1] Qa[n] (R[n] L[n]) V[n] 1/√s[n]
51+
= 1/√s[n-1] U†[n-1] Qa[n] U[n] s[n] (V†[n] V[n]) 1/√s[n]
52+
= 1/√s[n-1] U†[n-1] Qa[n] U[n] √s[n]
6553
```
6654
Then `M̃[n]` (n = 1, ..., N - 1) satisfies the (generalized) left-orthogonal condition
6755
```
@@ -74,13 +62,12 @@ Then `M̃[n]` (n = 1, ..., N - 1) satisfies the (generalized) left-orthogonal co
7462
Similarly, we can express M̃ using Qb
7563
```
7664
M̃[n]
77-
= 1/√s[n-1]←-U†[n-1] R[n-1]--(M[n]-*-L[n]) V[n]←-1/√s[n]
78-
= 1/√s[n-1]←-U†[n-1] (R[n-1]--L[n-1]) Qb[n] V[n]←-1/√s[n]
79-
= -*-1/√s[n-1]←-U†[n-1] (R[n-1]-*-L[n-1]) Qb[n] V[n]←-1/√s[n]
80-
= -*-1/√s[n-1]←-(U†[n-1] U[n-1])←-s[n-1]←-V†[n-1] * Qb[n] V[n]←-1/√s[n]
81-
= -*-√s[n-1]←-V†[n-1] Qb[n] V[n]←-1/√s[n]
65+
= 1/√s[n-1] U†[n-1] R[n-1] (M[n] L[n]) V[n] 1/√s[n]
66+
= 1/√s[n-1] U†[n-1] (R[n-1] L[n-1]) Qb[n] V[n] 1/√s[n]
67+
= 1/√s[n-1] U†[n-1] (R[n-1] L[n-1]) Qb[n] V[n] 1/√s[n]
68+
= 1/√s[n-1] (U†[n-1] U[n-1]) s[n-1] V†[n-1] Qb[n] V[n] 1/√s[n]
69+
= √s[n-1] V†[n-1] Qb[n] V[n] 1/√s[n]
8270
```
83-
Here `-*-` is a twist to be applied when the codomain of `L[n-1]` is a dual space.
8471
Then `M̃[n]` (n = 2, ..., N) satisfies the (generalized) right-orthogonal condition
8572
```
8673
-←-M̃[n]-←┐ 1 -←-┐
@@ -96,7 +83,7 @@ Here `-*-` is the twist on the physical axis.
9683
Suppose we want to truncate the bond between
9784
the n-th and the (n+1)-th sites such that the truncated state
9885
```
99-
|ψ̃⟩ = M[1]---...---M̃[n]---M̃[n+1]---...---M[N]
86+
|ψ̃⟩ = M[1]-←-...-←-M̃[n]-←-M̃[n+1]-←-...--M[N]
10087
↓ ↓ ↓ ↓
10188
```
10289
maximizes the fidelity
@@ -127,61 +114,56 @@ Then the fidelity is just
127114
"""
128115
Perform QR decomposition through a PEPS tensor
129116
```
130-
╱ ╱
131-
--R0----M--- → ---Q--*-R1--
132-
╱ | ╱ |
117+
118+
--R0-←-M-←- => ---Q-←-R1--
119+
╱ | ╱ |
133120
```
134121
"""
135122
function qr_through(
136123
R0::MPSBondTensor, M::GenericMPSTensor{S, 4}; normalize::Bool = true
137124
) where {S <: ElementarySpace}
125+
@assert !isdual(codomain(R0, 1))
126+
@assert !isdual(domain(M, 1)) && !isdual(codomain(M, 1))
138127
@tensor A[-1 -2 -3 -4; -5] := R0[-1; 1] * M[1 -2 -3 -4; -5]
139128
_, r = left_orth!(A)
140-
if isdual(domain(r, 1)) != isdual(codomain(r, 1))
141-
r = flip(r, 1)
142-
end
143129
normalize && normalize!(r, Inf)
144130
return r
145131
end
132+
# for `M` at the left end of the MPS
146133
function qr_through(
147134
::Nothing, M::GenericMPSTensor{S, 4}; normalize::Bool = true
148135
) where {S <: ElementarySpace}
136+
@assert !isdual(domain(M, 1))
149137
_, r = left_orth(M)
150-
if isdual(domain(r, 1)) != isdual(codomain(r, 1))
151-
r = flip(r, 1)
152-
end
153138
normalize && normalize!(r, Inf)
154139
return r
155140
end
156141

157142
"""
158143
Perform LQ decomposition through a tensor
159144
```
160-
╱ ╱
161-
--L0-*--Q--- ← ---M--*-L1--
162-
╱ | ╱ |
145+
146+
--L0-←-Q-←- <= -←-M-←-L1--
147+
╱ | ╱ |
163148
```
164149
"""
165150
function lq_through(
166151
M::GenericMPSTensor{S, 4}, L1::MPSBondTensor; normalize::Bool = true
167152
) where {S <: ElementarySpace}
168-
@plansor A[-1 -2 -3 -4; -5] := M[-1 -2 -3 -4; 1] * L1[1; -5]
169-
A = permute(A, ((1,), (2, 3, 4, 5)))
153+
@assert !isdual(domain(L1, 1))
154+
@assert !isdual(codomain(M, 1)) && !isdual(domain(M, 1))
155+
@tensor A[-1; -2 -3 -4 -5] := M[-1 -2 -3 -4; 1] * L1[1; -5]
170156
l, _ = right_orth!(A)
171-
if isdual(domain(l, 1)) != isdual(codomain(l, 1))
172-
l = flip(l, 2)
173-
end
174157
normalize && normalize!(l, Inf)
175158
return l
176159
end
160+
# for `M` at the right end of the MPS
177161
function lq_through(
178162
M::GenericMPSTensor{S, 4}, ::Nothing; normalize::Bool = true
179163
) where {S <: ElementarySpace}
164+
@assert !isdual(codomain(M, 1))
180165
A = permute(M, ((1,), (2, 3, 4, 5)))
181166
l, _ = right_orth!(A)
182-
if isdual(domain(l, 1)) != isdual(codomain(l, 1))
183-
l = flip(l, 2)
184-
end
185167
normalize && normalize!(l, Inf)
186168
return l
187169
end
@@ -214,39 +196,32 @@ such that the contraction of `Pa`, `s`, `Pb` is identity when `trunc = notrunc`,
214196
215197
The arrows between `Pa`, `s`, `Pb` are
216198
```
217-
rev = false: - Pa --←-- Pb -
218-
1 ← s ← 2
219-
220-
rev = true: - Pa --→-- Pb -
221-
2 → s → 1
199+
- Pa --←-- Pb -
200+
1 ← s ← 2
222201
```
223202
"""
224203
function _proj_from_RL(
225204
r::MPSBondTensor, l::MPSBondTensor;
226-
trunc::TruncationStrategy = notrunc(), rev::Bool = false,
205+
trunc::TruncationStrategy = notrunc()
227206
)
207+
@assert isdual(domain(r, 1)) == isdual(codomain(r, 1)) == false
208+
@assert isdual(domain(l, 1)) == isdual(codomain(l, 1)) == false
228209
rl = r * l
229-
@assert isdual(domain(rl, 1)) == isdual(codomain(rl, 1))
230210

231211
# TODO: replace this with actual truncation error once TensorKit is updated
232212
uc, sc, vhc = svd_compact!(rl)
233213
u, s, vh, ϵ = _truncate_compact((uc, sc, vhc), trunc)
234214

235-
sinv = PEPSKit.sdiag_pow(s, -1 / 2)
215+
sinv = sdiag_pow(s, -1 / 2)
236216
Pa, Pb = l * vh' * sinv, sinv * u' * r
237-
if rev
238-
Pa, s, Pb = flip_svd(Pa, s, Pb)
239-
end
240217
return Pa, s, Pb, ϵ
241218
end
242219

243220
"""
244221
Given a cluster `Ms` and the pre-calculated `R`, `L` bond matrices,
245222
find all projectors `Pa`, `Pb` and Schmidt weights `wts` on internal bonds.
246223
"""
247-
function _get_allprojs(
248-
Ms, Rs, Ls, truncs::Vector{E}, revs::Vector{Bool}
249-
) where {E <: TruncationStrategy}
224+
function _get_allprojs(Ms, Rs, Ls, truncs::Vector{E}) where {E <: TruncationStrategy}
250225
N = length(Ms)
251226
@assert length(truncs) == N - 1
252227
projs_errs = map(1:(N - 1)) do i
@@ -256,7 +231,7 @@ function _get_allprojs(
256231
else
257232
truncs[i]
258233
end
259-
return _proj_from_RL(Rs[i], Ls[i]; trunc, rev = revs[i])
234+
return _proj_from_RL(Rs[i], Ls[i]; trunc)
260235
end
261236
Pas = map(Base.Fix2(getindex, 1), projs_errs)
262237
wts = map(Base.Fix2(getindex, 2), projs_errs)
@@ -266,18 +241,34 @@ function _get_allprojs(
266241
return Pas, Pbs, wts, ϵs
267242
end
268243

244+
"""
245+
Flip the virtual arrows in the MPS `Ms`
246+
"""
247+
function _flip_virtuals!(
248+
Ms::Vector{T}, flips::Vector{Bool}; inv::Bool = false
249+
) where {T <: GenericMPSTensor}
250+
@assert length(flips) == length(Ms) - 1
251+
for (n, flip) in enumerate(flips)
252+
!flip && continue
253+
M1, M2 = Ms[n], Ms[n + 1]
254+
Ms[n] = TensorKit.flip(M1, numind(M1); inv)
255+
Ms[n + 1] = TensorKit.flip(M2, 1; inv)
256+
end
257+
return Ms
258+
end
259+
269260
"""
270261
Find projectors to truncate internal bonds of the cluster `Ms`.
271262
"""
272263
function _cluster_truncate!(
273-
Ms::Vector{T}, truncs::Vector{E}, revs::Vector{Bool}
264+
Ms::Vector{T}, truncs::Vector{E}
274265
) where {T <: GenericMPSTensor{<:ElementarySpace, 4}, E <: TruncationStrategy}
275266
Rs, Ls = _get_allRLs(Ms)
276-
Pas, Pbs, wts, ϵs = _get_allprojs(Ms, Rs, Ls, truncs, revs)
267+
Pas, Pbs, wts, ϵs = _get_allprojs(Ms, Rs, Ls, truncs)
277268
# apply projectors
278269
# M1 -- (Pa1,wt1,Pb1) -- M2 -- (Pa2,wt2,Pb2) -- M3
279270
for (i, (Pa, Pb)) in enumerate(zip(Pas, Pbs))
280-
@plansor (Ms[i])[-1 -2 -3 -4; -5] := (Ms[i])[-1 -2 -3 -4; 1] * Pa[1; -5]
271+
@tensor (Ms[i])[-1 -2 -3 -4; -5] := (Ms[i])[-1 -2 -3 -4; 1] * Pa[1; -5]
281272
@tensor (Ms[i + 1])[-1 -2 -3 -4; -5] := Pb[-1; 1] * (Ms[i + 1])[1 -2 -3 -4; -5]
282273
end
283274
return wts, ϵs, Pas, Pbs
@@ -290,7 +281,7 @@ When `gate_ax` is 1 or 2, the gate acts from the physical codomain or domain sid
290281
e.g. Cluster in PEPS with `gate_ax = 1`:
291282
```
292283
╱ ╱ ╱
293-
--- M1 ---- M2 ---- M3 ---
284+
--- M1 -←-- M2 --- M3 ---
294285
╱ | ╱ | ╱ |
295286
↓ ↓ ↓
296287
g1 -←-- g2 -←-- g3
@@ -314,15 +305,13 @@ function _apply_gatempo!(
314305
@assert length(Ms) == length(gs)
315306
@assert gate_ax == 1
316307
@assert all(!isdual(space(g, 1)) for g in gs[2:end])
308+
@assert all(!isdual(space(M, 1)) for M in Ms[2:end])
317309
# fusers to merge axes on bonds in the gate-cluster product
318310
# M1 == f1† -- f1 == M2 == f2† -- f2 == M3
319311
fusers = map(Ms[2:end], gs[2:end]) do M, g
320312
V1, V2 = space(M, 1), space(g, 1)
321313
return isomorphism(fuse(V1, V2) V1 V2)
322314
end
323-
for (i, M) in enumerate(Ms[2:end])
324-
isdual(space(M, 1)) && twist!(Ms[i + 1], 1)
325-
end
326315
#= gate on codomain of PEPS
327316
-3 -3 -3
328317
╱ ┌-┐ ┌-┐ ╱ ┌-┐ ┌-┐ ╱
@@ -354,15 +343,13 @@ function _apply_gatempo!(
354343
@assert length(Ms) == length(gs)
355344
@assert gate_ax == 1 || gate_ax == 2
356345
@assert all(!isdual(space(g, 1)) for g in gs[2:end])
346+
@assert all(!isdual(space(M, 1)) for M in Ms[2:end])
357347
# fusers to merge axes on bonds in the gate-cluster product
358348
# M1 == f1† -- f1 == M2 == f2† -- f2 == M3
359349
fusers = map(Ms[2:end], gs[2:end]) do M, g
360350
V1, V2 = space(M, 1), space(g, 1)
361351
return isomorphism(fuse(V1, V2) V1 V2)
362352
end
363-
for (i, M) in enumerate(Ms[2:end])
364-
isdual(space(M, 1)) && twist!(Ms[i + 1], 1)
365-
end
366353
#= gate on codomain of PEPO (gate_ax = 1)
367354
368355
-3 -4 -3 -4 -3 -4
@@ -454,6 +441,7 @@ function get_3site_se(state::InfiniteState, env::SUWeight, row::Int, col::Int)
454441
perms_se = isa(state, InfinitePEPS) ? perms_se_peps : perms_se_pepo
455442
Ms = map(zip(coords_se, perms_se, openaxs_se)) do (coord, perm, openaxs)
456443
M = absorb_weight(state.A[CartesianIndex(coord)], env, coord[1], coord[2], openaxs)
444+
# permute to MPS axes order
457445
return permute(M, perm)
458446
end
459447
return Ms
@@ -469,29 +457,34 @@ function _su3site_se!(
469457
rm1, cp1 = _prev(row, Nr), _next(col, Nc)
470458
# southwest 3-site cluster and arrow direction within it
471459
Ms = get_3site_se(state, env, row, col)
472-
revs = [isdual(space(M, 1)) for M in Ms[2:end]]
460+
flips = [isdual(space(M, 1)) for M in Ms[2:end]]
473461
Vphys = [codomain(M, 2) for M in Ms]
474462
normalize!.(Ms, Inf)
463+
# flip virtual arrows in `Ms` to ←
464+
_flip_virtuals!(Ms, flips)
475465
# sites in the cluster
476466
coords = ((row, col), (row, cp1), (rm1, cp1))
477467
# weights in the cluster
478468
wt_idxs = ((1, row, col), (2, row, cp1))
479-
# apply gate MPOs
469+
# apply gate MPOs and truncate
480470
gate_axs = purified ? (1:1) : (1:2)
481471
ϵs = nothing
482472
for gate_ax in gate_axs
483473
_apply_gatempo!(Ms, gs; gate_ax)
484474
if isa(state, InfinitePEPO)
485475
Ms = [first(_fuse_physicalspaces(M)) for M in Ms]
486476
end
487-
wts, ϵs, = _cluster_truncate!(Ms, truncs, revs)
477+
wts, ϵs, = _cluster_truncate!(Ms, truncs)
488478
if isa(state, InfinitePEPO)
489479
Ms = [first(_unfuse_physicalspace(M, Vphy)) for (M, Vphy) in zip(Ms, Vphys)]
490480
end
491-
for (wt, wt_idx) in zip(wts, wt_idxs)
492-
env[CartesianIndex(wt_idx)] = normalize(wt, Inf)
481+
for (wt, wt_idx, flip) in zip(wts, wt_idxs, flips)
482+
env[CartesianIndex(wt_idx)] = normalize(flip ? _flip_s(wt) : wt, Inf)
493483
end
494484
end
485+
# restore virtual arrows in `Ms`
486+
_flip_virtuals!(Ms, flips)
487+
# update `state` from `Ms`
495488
invperms_se = isa(state, InfinitePEPS) ? invperms_se_peps : invperms_se_pepo
496489
for (M, coord, invperm, openaxs, Vphy) in zip(Ms, coords, invperms_se, openaxs_se, Vphys)
497490
# restore original axes order

src/utility/util.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -114,10 +114,9 @@ The axis orders for `s`, `s2` are
114114
```
115115
"""
116116
function flip_svd(u::AbstractTensorMap, s::DiagonalTensorMap, vh::AbstractTensorMap)
117-
return flip(u, numind(u)),
118-
permute(DiagonalTensorMap(flip(s, (1, 2))), ((2,), (1,))),
119-
flip(vh, 1)
117+
return flip(u, numind(u)), _flip_s(s), flip(vh, 1)
120118
end
119+
_flip_s(s::DiagonalTensorMap) = permute(DiagonalTensorMap(flip(s, (1, 2))), ((2,), (1,)))
121120

122121
"""
123122
twistdual(t::AbstractTensorMap, i)

0 commit comments

Comments
 (0)