Skip to content

Commit 3a8c969

Browse files
committed
Refactor bond direction helpers to use lattice direction constants
1 parent 9976630 commit 3a8c969

5 files changed

Lines changed: 77 additions & 54 deletions

File tree

src/algorithms/time_evolution/get_cluster.jl

Lines changed: 48 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -51,46 +51,66 @@ end
5151

5252
"""
5353
Given `site1`, `site2` connected by a nearest neighbor bond,
54-
return the bond index and whether it is reversed from the
55-
standard orientation (`site1` on the west/south of `site2`).
54+
return `(dir, r, c)` where `dir` ∈ {NORTH, EAST, SOUTH, WEST}
55+
is the direction from `site1` to `site2`.
56+
`(r, c)` is the position of the source endpoint:
57+
the west site for x-bonds (`dir == EAST` or `dir == WEST`),
58+
the south site for y-bonds (`dir == NORTH` or `dir == SOUTH`).
5659
"""
57-
function _nn_bondrev(site1::CartesianIndex{2}, site2::CartesianIndex{2})
60+
function _nn_bonddir(site1::CartesianIndex{2}, site2::CartesianIndex{2})
5861
diff = site1 - site2
59-
if diff == CartesianIndex(0, -1)
62+
if diff == CartesianIndex(0, -1) # site1 west of site2 → site2 to the EAST
6063
r, c = site1[1], site1[2]
61-
return (1, r, c), false
62-
elseif diff == CartesianIndex(0, 1)
63-
r, c = site2[1], site2[2]
64-
return (1, r, c), true
65-
elseif diff == CartesianIndex(1, 0)
64+
return (EAST, r, c)
65+
elseif diff == CartesianIndex(1, 0) # site1 south of site2 → site2 to the NORTH
6666
r, c = site1[1], site1[2]
67-
return (2, r, c), false
68-
elseif diff == CartesianIndex(-1, 0)
67+
return (NORTH, r, c)
68+
elseif diff == CartesianIndex(-1, 0) # site1 north of site2 → site2 to the SOUTH
69+
r, c = site2[1], site2[2]
70+
return (SOUTH, r, c)
71+
elseif diff == CartesianIndex(0, 1) # site1 east of site2 → site2 to the WEST
6972
r, c = site2[1], site2[2]
70-
return (2, r, c), true
73+
return (WEST, r, c)
7174
else
7275
error("`site1` and `site2` are not nearest neighbors.")
7376
end
7477
end
7578

76-
function _bond_rotation(x, bonddir::Int, rev::Bool; inv::Bool = false)
77-
return if bonddir == 1 # x-bond
78-
rev ? rot180(x) : x
79-
elseif bonddir == 2 # y-bond
80-
if rev
81-
inv ? rotr90(x) : rotl90(x)
82-
else
83-
inv ? rotl90(x) : rotr90(x)
84-
end
85-
else
86-
error("`bonddir` must be 1 (for x-bonds) or 2 (for y-bonds).")
79+
"""
80+
Apply the tensor rotation that maps a bond from orientation `dir_from`
81+
to orientation `dir_to`, where the directions are one of
82+
`NORTH`, `EAST`, `SOUTH`, `WEST`.
83+
"""
84+
function _rotate_by_dir(x, dir_from::Int, dir_to::Int)
85+
delta = mod(dir_to - dir_from, 4)
86+
return if delta == 0
87+
x
88+
elseif delta == 1
89+
rotr90(x)
90+
elseif delta == 2
91+
rot180(x)
92+
else # delta == 3
93+
rotl90(x)
8794
end
8895
end
89-
function _bond_rotation(x::CartesianIndex{2}, bonddir::Int, rev::Bool, unitcell::NTuple{2, Int})
90-
return if bonddir == 1
91-
rev ? siterot180(x, unitcell) : x
92-
else
93-
rev ? siterotl90(x, unitcell) : siterotr90(x, unitcell)
96+
97+
"""
98+
Map a site coordinate from the lattice frame where the bond has orientation
99+
`dir_from` to the frame where the bond has orientation `dir_to`.
100+
"""
101+
function _siterot_by_dir(
102+
site::CartesianIndex{2}, dir_from::Int, dir_to::Int,
103+
unitcell::NTuple{2, Int}
104+
)
105+
delta = mod(dir_to - dir_from, 4)
106+
return if delta == 0
107+
site
108+
elseif delta == 1
109+
siterotr90(site, unitcell)
110+
elseif delta == 2
111+
siterot180(site, unitcell)
112+
else # delta == 3
113+
siterotl90(site, unitcell)
94114
end
95115
end
96116

src/algorithms/time_evolution/ntupdate.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,8 @@ function ntu_iter(
9898
state2[site] = _apply_sitegate(state2[site], gate)
9999
info′ = (; fid = 1.0)
100100
elseif length(sites) == 2
101-
(d, r, c), = _nn_bondrev(sites...)
101+
(dir, r, c) = _nn_bonddir(sites...)
102+
d = dir in (EAST, WEST) ? 1 : 2
102103
alg.bipartite && iseven(r) && continue
103104
state2, wts, info′ = _ntu_iter(state2, gate, wts, sites, alg)
104105
(!alg.bipartite) && continue

src/algorithms/time_evolution/ntupdate3site.jl

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -64,15 +64,14 @@ function _bond_truncate(
6464
(stype1, stype2)::NTuple{2, Symbol},
6565
alg::NeighbourUpdate; gate::Union{NNGate, Nothing} = nothing
6666
)
67-
# rotate bond to standard x direction `A ← B`
67+
# get bond direction and rotate state to standard x-direction `A ← B`
68+
(dir, _, _) = _nn_bonddir(site1, site2)
6869
ucell = size(state)[1:2]
69-
bond, rev = _nn_bondrev(site1, site2)
70-
dir = first(bond)
71-
state2 = _bond_rotation(state, dir, rev; inv = false)
72-
wts2 = _bond_rotation(wts, dir, rev; inv = false)
70+
state2 = _rotate_by_dir(state, dir, EAST)
71+
wts2 = _rotate_by_dir(wts, dir, EAST)
7372

74-
# rotated bond tensors
75-
siteA = _bond_rotation(site1, dir, rev, ucell)
73+
# site1 position in the rotated frame
74+
siteA = _siterot_by_dir(site1, dir, EAST, ucell)
7675
row, col = siteA[1], siteA[2]
7776
A, B = state2[row, col], state2[row, col + 1]
7877

@@ -123,8 +122,8 @@ function _bond_truncate(
123122
state2[row, col + 1] = normalize!(B, Inf)
124123
wts2[1, row, col] = normalize!(s, Inf)
125124

126-
# rotate back tensors and bond weight
127-
state2 = _bond_rotation(state2, dir, rev; inv = true)
128-
wts2 = _bond_rotation(wts2, dir, rev; inv = true)
125+
# rotate back state and weights
126+
state2 = _rotate_by_dir(state2, EAST, dir)
127+
wts2 = _rotate_by_dir(wts2, EAST, dir)
129128
return state2, wts2, info
130129
end

src/algorithms/time_evolution/simpleupdate.jl

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -86,9 +86,8 @@ function _su_iter!(
8686
Ms, open_vaxs, = _get_cluster(state, sites)
8787
_absorb_weight!(Ms, sites, open_vaxs, env)
8888
# rotate
89-
bond, rev = _nn_bondrev(sites...)
90-
dir = first(bond)
91-
A, B = _bond_rotation.(Ms, dir, rev; inv = false)
89+
(dir, r, c) = _nn_bonddir(sites...)
90+
A, B = _rotate_by_dir.(Ms, dir, EAST)
9291
# apply gate
9392
ϵ = 0.0
9493
local s
@@ -102,17 +101,18 @@ function _su_iter!(
102101
alg.purified && break # only apply gate to 1st physical leg
103102
end
104103
# rotate back
105-
A = _bond_rotation(A, dir, rev; inv = true)
106-
B = _bond_rotation(B, dir, rev; inv = true)
107-
rev && (s = transpose(s))
104+
A = _rotate_by_dir(A, EAST, dir)
105+
B = _rotate_by_dir(B, EAST, dir)
106+
(dir in (WEST, SOUTH)) && (s = transpose(s))
108107
# remove environment weights
109108
siteA, siteB = sites
110109
A = absorb_weight(A, env, siteA[1], siteA[2], open_vaxs[1]; inv = true)
111110
B = absorb_weight(B, env, siteB[1], siteB[2], open_vaxs[2]; inv = true)
112111
# update tensor dict and weight on current bond
113112
state[siteA] = normalize!(A, Inf)
114113
state[siteB] = normalize!(B, Inf)
115-
env[bond...] = normalize!(s, Inf)
114+
d = dir in (EAST, WEST) ? 1 : 2
115+
env[d, r, c] = normalize!(s, Inf)
116116
return ϵ
117117
end
118118

@@ -131,7 +131,8 @@ function su_iter(
131131
site = only(sites)
132132
state2[site] = _apply_sitegate(state2[site], gate; alg.purified)
133133
elseif length(sites) == 2
134-
(d, r, c), = _nn_bondrev(sites...)
134+
(dir, r, c) = _nn_bonddir(sites...)
135+
d = dir in (EAST, WEST) ? 1 : 2
135136
alg.bipartite && iseven(r) && continue
136137
ϵ′ = _su_iter!(state2, gate, env2, sites, alg)
137138
ϵ = max(ϵ, ϵ′)

src/algorithms/time_evolution/simpleupdate3site.jl

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -47,13 +47,14 @@ function _su_iter!(
4747
# restore virtual arrows in `Ms`
4848
_flip_virtuals!(Ms, flips)
4949
# update env weights
50-
bond_revs = map(sites, Iterators.drop(sites, 1)) do site1, site2
51-
_nn_bondrev(site1, site2)
50+
bond_dirs = map(sites, Iterators.drop(sites, 1)) do site1, site2
51+
_nn_bonddir(site1, site2)
5252
end
53-
for (wt, (bond, rev), flip) in zip(wts, bond_revs, flips)
53+
for (wt, (dir, r, c), flip) in zip(wts, bond_dirs, flips)
5454
wt_new = flip ? _fliptwist_s(wt) : wt
55-
wt_new = rev ? transpose(wt_new) : wt_new
56-
env[CartesianIndex(bond)] = normalize!(wt_new, Inf)
55+
wt_new = (dir in (WEST, SOUTH)) ? transpose(wt_new) : wt_new
56+
d = dir in (EAST, WEST) ? 1 : 2
57+
env[d, r, c] = normalize!(wt_new, Inf)
5758
end
5859
# update state tensors
5960
for (M, s, invperm, vaxs) in zip(Ms, sites, invperms, open_vaxs)
@@ -72,9 +73,10 @@ function _get_cluster_trunc(
7273
trunc::TruncationStrategy, sites::Vector{CartesianIndex{2}}
7374
)
7475
return map(sites, Iterators.drop(sites, 1)) do site1, site2
75-
(d, r, c), rev = _nn_bondrev(site1, site2)
76+
(dir, r, c) = _nn_bonddir(site1, site2)
77+
d = dir in (EAST, WEST) ? 1 : 2
7678
t = truncation_strategy(trunc, d, r, c)
77-
if rev && isa(t, TruncationSpace)
79+
if (dir in (WEST, SOUTH)) && isa(t, TruncationSpace)
7880
t = truncspace(flip(t.space)')
7981
end
8082
return t

0 commit comments

Comments
 (0)