Skip to content

Commit b6a7932

Browse files
Simple update refactoring (#346)
1 parent ba4848b commit b6a7932

3 files changed

Lines changed: 45 additions & 50 deletions

File tree

src/algorithms/time_evolution/apply_gate.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ When `A`, `B` are PEPOTensors,
4848
5 1 4 1 4 1
4949
```
5050
"""
51-
function _qr_bond(A::PT, B::PT; gate_ax::Int = 1) where {PT <: Union{PEPSTensor, PEPOTensor}}
51+
function _qr_bond(A::PT, B::PT; gate_ax::Int = 1, kwargs...) where {PT <: Union{PEPSTensor, PEPOTensor}}
5252
@assert 1 <= gate_ax <= numout(A)
5353
permA, permB, permX, permY = if A isa PEPSTensor
5454
((2, 4, 5), (1, 3)), ((2, 3, 4), (1, 5)), (1, 4, 2, 3), Tuple(1:4)
@@ -59,8 +59,8 @@ function _qr_bond(A::PT, B::PT; gate_ax::Int = 1) where {PT <: Union{PEPSTensor,
5959
((1, 3, 5, 6), (2, 4)), ((1, 3, 4, 5), (2, 6)), (1, 2, 5, 3, 4), Tuple(1:5)
6060
end
6161
end
62-
X, a = left_orth!(permute(A, permA; copy = true); positive = true)
63-
Y, b = left_orth!(permute(B, permB; copy = true); positive = true)
62+
X, a = left_orth!(permute(A, permA; copy = true); kwargs...)
63+
Y, b = left_orth!(permute(B, permB; copy = true); kwargs...)
6464
X, Y = permute(X, permX), permute(Y, permY)
6565
b = permute(b, ((3, 2), (1,)))
6666
return X, a, b, Y

src/algorithms/time_evolution/simpleupdate.jl

Lines changed: 35 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -66,50 +66,48 @@ function TimeEvolver(
6666
return TimeEvolver(alg, dt, nstep, gate, state)
6767
end
6868

69-
"""
70-
Optimized simple update of nearest neighbor bonds utilizing
71-
reduced bond tensors without decomposing the gate into a 2-site MPO.
69+
function _bond_rotation(x, bonddir::Int, rev::Bool; inv::Bool = false)
70+
return if bonddir == 1 # x-bond
71+
rev ? rot180(x) : x
72+
elseif bonddir == 2 # y-bond
73+
if rev
74+
inv ? rotr90(x) : rotl90(x)
75+
else
76+
inv ? rotl90(x) : rotr90(x)
77+
end
78+
else
79+
error("`bonddir` must be 1 (for x-bonds) or 2 (for y-bonds).")
80+
end
81+
end
7282

73-
When `purified = true`, `gate` acts on the codomain physical legs of `state`.
74-
Otherwise, `gate` acts on both the codomain and the domain physical legs of `state`.
83+
"""
84+
Simple update optimized for nearest neighbor gates
85+
utilizing reduced bond tensors with the physical leg.
7586
"""
7687
function _su_iter!(
7788
state::InfiniteState, gate::NNGate, env::SUWeight,
78-
sites::Vector{CartesianIndex{2}}, truncs::Vector{E};
79-
purified::Bool = true
80-
) where {E <: TruncationStrategy}
89+
sites::Vector{CartesianIndex{2}}, alg::SimpleUpdate
90+
)
8191
Nr, Nc = size(state)
92+
truncs = _get_cluster_trunc(alg.trunc, sites, (Nr, Nc))
8293
@assert length(sites) == 2 && length(truncs) == 1
8394
Ms, open_vaxs, = _get_cluster(state, sites, env; permute = false)
8495
normalize!.(Ms, Inf)
8596
# rotate
8697
bond, rev = _nn_bondrev(sites..., (Nr, Nc))
87-
A, B = if bond[1] == 1 # x-bond
88-
rev ? map(rot180, Ms) : Ms
89-
else # y-bond
90-
rev ? map(rotl90, Ms) : map(rotr90, Ms)
91-
end
98+
A, B = _bond_rotation.(Ms, bond[1], rev; inv = false)
9299
# apply gate
93100
ϵ, s = 0.0, nothing
94-
gate_axs = purified ? (1:1) : (1:2)
101+
gate_axs = alg.purified ? (1:1) : (1:2)
95102
for gate_ax in gate_axs
96-
X, a, b, Y = _qr_bond(A, B; gate_ax)
103+
X, a, b, Y = _qr_bond(A, B; gate_ax, positive = true)
97104
a, s, b, ϵ′ = _apply_gate(a, b, gate, truncs[1])
98105
ϵ = max(ϵ, ϵ′)
99106
A, B = _qr_bond_undo(X, a, b, Y)
100107
end
101108
# rotate back
102-
if bond[1] == 1 # x-bond
103-
if rev
104-
A, B = rot180(A), rot180(B)
105-
end
106-
else # y-bond
107-
if rev
108-
A, B = rotr90(A), rotr90(B)
109-
else
110-
A, B = rotl90(A), rotl90(B)
111-
end
112-
end
109+
A = _bond_rotation(A, bond[1], rev; inv = true)
110+
B = _bond_rotation(B, bond[1], rev; inv = true)
113111
# remove environment weights
114112
siteA, siteB = map(sites) do site
115113
return CartesianIndex(mod1(site[1], Nr), mod1(site[2], Nc))
@@ -120,8 +118,8 @@ function _su_iter!(
120118
normalize!(A, Inf)
121119
normalize!(B, Inf)
122120
normalize!(s, Inf)
123-
state.A[siteA], state.A[siteB] = A, B
124-
env.data[bond...] = s
121+
state[siteA], state[siteB] = A, B
122+
env[bond...] = s
125123
return ϵ
126124
end
127125

@@ -134,40 +132,37 @@ function su_iter(
134132
)
135133
Nr, Nc, = size(state)
136134
state2, env2, ϵ = deepcopy(state), deepcopy(env), 0.0
137-
purified = alg.purified
138135
for (sites, gate) in gates.terms
139136
if length(sites) == 1
140137
# 1-site gate
141138
# TODO: special treatment for bipartite state
142139
site = sites[1]
143140
r, c = mod1(site[1], Nr), mod1(site[2], Nc)
144-
state2.A[r, c] = _apply_sitegate(state2.A[r, c], gate; purified)
141+
state2[r, c] = _apply_sitegate(state2[r, c], gate; alg.purified)
145142
elseif length(sites) == 2
146143
(d, r, c), = _nn_bondrev(sites..., (Nr, Nc))
147144
if alg.bipartite
148145
length(sites) > 2 && error("Multi-site MPO gates are not compatible with bipartite states.")
149146
r > 1 && continue
150147
end
151-
truncs = _get_cluster_trunc(alg.trunc, sites, size(state)[1:2])
152-
ϵ′ = _su_iter!(state2, gate, env2, sites, truncs; purified)
148+
ϵ′ = _su_iter!(state2, gate, env2, sites, alg)
153149
ϵ = max(ϵ, ϵ′)
154150
(!alg.bipartite) && continue
155151
if d == 1
156152
rp1, cp1 = _next(r, Nr), _next(c, Nc)
157-
state2.A[rp1, cp1] = deepcopy(state2.A[r, c])
158-
state2.A[rp1, c] = deepcopy(state2.A[r, cp1])
159-
env2.data[1, rp1, cp1] = deepcopy(env2.data[1, r, c])
153+
state2[rp1, cp1] = deepcopy(state2[r, c])
154+
state2[rp1, c] = deepcopy(state2[r, cp1])
155+
env2[1, rp1, cp1] = deepcopy(env2[1, r, c])
160156
else
161157
rm1, cm1 = _prev(r, Nr), _prev(c, Nc)
162-
state2.A[rm1, cm1] = deepcopy(state2.A[r, c])
163-
state2.A[r, cm1] = deepcopy(state2.A[rm1, c])
164-
env2.data[2, rm1, cm1] = deepcopy(env2.data[2, r, c])
158+
state2[rm1, cm1] = deepcopy(state2[r, c])
159+
state2[r, cm1] = deepcopy(state2[rm1, c])
160+
env2[2, rm1, cm1] = deepcopy(env2[2, r, c])
165161
end
166162
else
167163
# N-site MPO gate (N ≥ 2)
168164
alg.bipartite && error("Multi-site MPO gates are not compatible with bipartite states.")
169-
truncs = _get_cluster_trunc(alg.trunc, sites, size(state)[1:2])
170-
ϵ′ = _su_iter!(state2, gate, env2, sites, truncs; purified)
165+
ϵ′ = _su_iter!(state2, gate, env2, sites, alg)
171166
ϵ = max(ϵ, ϵ′)
172167
end
173168
end

src/algorithms/time_evolution/simpleupdate3site.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -150,9 +150,9 @@ function _get_cluster(
150150
Ms = map(zip(sites, open_vaxs, perms)) do (site, vaxs, perm)
151151
s = CartesianIndex(mod1(site[1], Nr), mod1(site[2], Nc))
152152
M = if env === nothing
153-
state.A[s]
153+
state[s]
154154
else
155-
absorb_weight(state.A[s], env, s[1], s[2], vaxs)
155+
absorb_weight(state[s], env, s[1], s[2], vaxs)
156156
end
157157
return permute ? TensorKit.permute(M, perm) : M
158158
end
@@ -164,18 +164,18 @@ Simple update with an N-site MPO `gate` (N ≥ 2).
164164
"""
165165
function _su_iter!(
166166
state::InfiniteState, gate::Vector{T}, env::SUWeight,
167-
sites::Vector{CartesianIndex{2}}, truncs::Vector{E};
168-
purified::Bool = true
169-
) where {T <: AbstractTensorMap, E <: TruncationStrategy}
167+
sites::Vector{CartesianIndex{2}}, alg::SimpleUpdate
168+
) where {T <: AbstractTensorMap}
170169
Nr, Nc = size(state)
170+
truncs = _get_cluster_trunc(alg.trunc, sites, (Nr, Nc))
171171
Ms, open_vaxs, invperms = _get_cluster(state, sites, env)
172172
flips = [isdual(space(M, 1)) for M in Ms[2:end]]
173173
Vphys = [codomain(M, 2) for M in Ms]
174174
normalize!.(Ms, Inf)
175175
# flip virtual arrows in `Ms` to ←
176176
_flip_virtuals!(Ms, flips)
177177
# apply gate MPOs and truncate
178-
gate_axs = purified ? (1:1) : (1:2)
178+
gate_axs = alg.purified ? (1:1) : (1:2)
179179
wts, ϵs = nothing, nothing
180180
for gate_ax in gate_axs
181181
_apply_gatempo!(Ms, gate; gate_ax)
@@ -206,7 +206,7 @@ function _su_iter!(
206206
# remove weights on open axes of the cluster
207207
M = absorb_weight(M, env, s′[1], s′[2], vaxs; inv = true)
208208
# update state tensors
209-
state.A[s′] = normalize(M, Inf)
209+
state[s′] = normalize(M, Inf)
210210
end
211211
return maximum(ϵs)
212212
end

0 commit comments

Comments
 (0)