Skip to content

Commit 34b27d8

Browse files
Improve efficiency of bond truncation (#366)
* Type stability of bond truncation algorithms * More realistic bond truncation test * Reduce repeated contraction in FET * Reduce repeated contraction in ALS * Reduce test pressure * Rename functions for generalizability * Make ALS cost always real * Drop Val{i} dispatch
1 parent e50c3eb commit 34b27d8

4 files changed

Lines changed: 223 additions & 151 deletions

File tree

src/algorithms/contractions/bondenv/als_solve.jl

Lines changed: 143 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -2,139 +2,168 @@
22
In the following, the names `Ra`, `Sa` etc comes from
33
the fast full update article Physical Review B 92, 035142 (2015)
44
=#
5-
65
"""
7-
$(SIGNATURES)
8-
9-
Construct the tensor
6+
Contract the virtual legs between
107
```
11-
┌-----------------------------------┐
12-
| ┌----┐ |
13-
└---| |- DX0 Db0 - b -- DY0 -┘
14-
| | ↓
15-
|benv| db
16-
| | ↓
17-
┌---| |- DX1 Db1 - b† - DY1 -┐
18-
| └----┘ |
19-
└-----------------------------------┘
8+
-- DX --a-- D --b-- DY --
9+
↓ ↓
10+
da db
2011
```
2112
"""
22-
function _tensor_Ra(benv::BondEnv, b::MPSTensor)
23-
return @autoopt @tensor Ra[DX1 Db1; DX0 Db0] := (
24-
benv[DX1 DY1; DX0 DY0] * b[Db0 db; DY0] * conj(b[Db1 db; DY1])
25-
)
13+
function _combine_ket(a::MPSTensor, b::AbstractTensorMap{T, S, 1, 2}) where {T, S}
14+
return @tensor ket[DX DY; da db] := a[DX da; D] * b[D; db DY]
15+
end
16+
function _combine_ket(a::MPSTensor, b::MPSTensor)
17+
return @tensor ket[DX DY; da db] := a[DX da; D] * b[D db; DY]
2618
end
2719

28-
"""
29-
$(SIGNATURES)
20+
function _combine_ket_for_svd(a::MPSTensor, b::MPSTensor)
21+
return @tensor ket[DX da; db DY] := a[DX da; D] * b[D db; DY]
22+
end
3023

31-
Construct the tensor
24+
"""
25+
Construct the norm with bra bond tensors removed
3226
```
33-
┌-----------------------------------┐
34-
| ┌----┐ |
35-
└---| |- DX0 -- (a2 b2) -- DY0 --┘
36-
| | ↓ ↓
37-
|benv| da db
38-
| | ↓
39-
┌---| |- DX1 Db1 -- b† - DY1 --┐
40-
| └----┘ |
41-
└-----------------------------------┘
27+
┌benv-------┐
28+
├---a---b---┤
29+
| ↓ ↓ |
30+
├-- --┤
31+
└-----------┘
4232
```
4333
"""
44-
function _tensor_Sa(
45-
benv::BondEnv, b::MPSTensor, a2b2::AbstractTensorMap{T, S, 2, 2}
46-
) where {T <: Number, S <: ElementarySpace}
47-
return @autoopt @tensor Sa[DX1 da; Db1] := (
48-
benv[DX1 DY1; DX0 DY0] * conj(b[Db1 db; DY1]) * a2b2[DX0 DY0; da db]
49-
)
34+
function _benv_ket(benv::BondEnv, ket::AbstractTensorMap{T, S, 2, 2}) where {T, S}
35+
return benv * twistdual(ket, 1:2)
5036
end
5137

5238
"""
53-
$(SIGNATURES)
39+
_als_tensor_R(benv::BondEnv, xs::Vector{<:MPSTensor}, i::Int)
5440
55-
Construct the tensor
41+
Construct the bond environment around the `i`th bond tensor
42+
in two-site ALS optimization.
5643
```
57-
┌-----------------------------------┐
58-
| ┌----┐ |
59-
└---| |- DX0 - a -- Da0 DY0 -┘
60-
| | ↓
61-
|benv| da
62-
| | ↓
63-
┌---| |- DX1 - a† - Da1 DY1 -┐
64-
| └----┘ |
65-
└-----------------------------------┘
44+
i = 1 i = 2
45+
┌benv-------┐ ┌benv-------┐
46+
├-- --b---┤ ├---a-- --┤
47+
| ↓ | | ↓ |
48+
├-- --b̄---┤ ├---ā-- --┤
49+
└-----------┘ └-----------┘
6650
```
6751
"""
68-
function _tensor_Rb(benv::BondEnv, a::MPSTensor)
69-
return @autoopt @tensor Rb[Da1 DY1; Da0 DY0] := (
70-
benv[DX1 DY1; DX0 DY0] * a[DX0 da; Da0] * conj(a[DX1 da; Da1])
71-
)
52+
function _als_tensor_R(benv::BondEnv, xs::Vector{<:MPSTensor}, i::Int)
53+
@assert 1 <= i <= 2
54+
return if i == 1
55+
_als_tensor_Ra(benv, xs[2])
56+
else
57+
_als_tensor_Rb(benv, xs[1])
58+
end
7259
end
7360

74-
"""
75-
$(SIGNATURES)
61+
function _als_tensor_Ra(benv::BondEnv, b::MPSTensor)
62+
return @tensor Ra[DX1 D1; DX0 D0] :=
63+
benv[DX1 DY1; DX0 DY0] * b[D0 db; DY0] * conj(b[D1 db; DY1])
64+
end
65+
function _als_tensor_Rb(benv::BondEnv, a::MPSTensor)
66+
return @tensor Rb[D1 DY1; D0 DY0] :=
67+
benv[DX1 DY1; DX0 DY0] * a[DX0 da; D0] * conj(a[DX1 da; D1])
68+
end
7669

77-
Construct the tensor
70+
"""
71+
Calculate the 2-site norm
7872
```
79-
┌-----------------------------------┐
80-
| ┌----┐ |
81-
└---| |- DX0 -- (a2 b2) -- DY0 --┘
82-
| | ↓ ↓
83-
|benv| da db
84-
| | ↓
85-
┌---| |- DX1 -- a† - Da1 DY1 --┐
86-
| └----┘ |
87-
└-----------------------------------┘
73+
┌benv-------┐
74+
├---a---b---┤
75+
| ↓ ↓ |
76+
├---ā---b̄---┤
77+
└-----------┘
8878
```
79+
using pre-calcuated partial contraction results.
8980
"""
90-
function _tensor_Sb(
91-
benv::BondEnv, a::MPSTensor, a2b2::AbstractTensorMap{T, S, 2, 2}
92-
) where {T <: Number, S <: ElementarySpace}
93-
return @autoopt @tensor Sb[Da1 db; DY1] := (
94-
benv[DX1 DY1; DX0 DY0] * conj(a[DX1 da; Da1]) * a2b2[DX0 DY0; da db]
95-
)
81+
function _als_norm(
82+
ket::AbstractTensorMap{T, S, 2, 2}, benv_ket::AbstractTensorMap{T, S, 2, 2}
83+
) where {T, S}
84+
return @tensor benv_ket[DX1 DY1; da db] * conj(ket[DX1 DY1; da db])
85+
end
86+
function _als_norm(a::MPSTensor, Ra::BondEnv)
87+
return @tensor Ra[DX1 D1; DX0 D0] * a[DX0 da; D0] * conj(a[DX1 da; D1])
9688
end
9789

9890
"""
99-
$(SIGNATURES)
91+
_als_tensor_S(
92+
benv_ket::AbstractTensorMap{T, S, 2, 2},
93+
xs::Vector{<:MPSTensor}, i::Int
94+
) where {T <: Number, S <: ElementarySpace}
10095
101-
Calculate the inner product <a1,b1|a2,b2>
96+
Construct the overlap but with one of the bra bond tensor removed.
10297
```
103-
┌--------------------------------┐
104-
| ┌----┐ |
105-
└---| |- DX0 - (a2 b2) - DY0 -┘
106-
| | ↓ ↓
107-
|benv| da db
108-
| | ↓ ↓
109-
┌---| |- DX1 - (a1 b1)†- DY1 -┐
110-
| └----┘ |
111-
└--------------------------------┘
98+
i = 1 i = 2
99+
┌benv-------┐ ┌benv-------┐
100+
├---a₂==b₂--┤ ├---a₂==b₂--┤
101+
| ↓ ↓ | | ↓ ↓ |
102+
├-- --b̄---┤ ├---ā-- --┤
103+
└-----------┘ └-----------┘
112104
```
105+
The ket part is provided by the partial contraction `benv_ket`.
113106
"""
114-
function inner_prod(
115-
benv::BondEnv, a1b1::AbstractTensorMap{T, S, 2, 2}, a2b2::AbstractTensorMap{T, S, 2, 2}
107+
function _als_tensor_S(
108+
benv_ket::AbstractTensorMap{T, S, 2, 2},
109+
xs::Vector{<:MPSTensor}, i::Int
116110
) where {T <: Number, S <: ElementarySpace}
117-
return @autoopt @tensor benv[DX1 DY1; DX0 DY0] *
118-
conj(a1b1[DX1 DY1; da db]) * a2b2[DX0 DY0; da db]
111+
@assert 1 <= i <= 2
112+
return if i == 1
113+
_als_tensor_Sa(benv_ket, xs[2])
114+
else
115+
_als_tensor_Sb(benv_ket, xs[1])
116+
end
117+
end
118+
119+
function _als_tensor_Sa(
120+
benv_ket::AbstractTensorMap{T, S, 2, 2}, b::MPSTensor
121+
) where {T <: Number, S <: ElementarySpace}
122+
return @tensor Sa[DX1 da; D1] :=
123+
benv_ket[DX1 DY1; da db] * conj(b[D1 db; DY1])
124+
end
125+
function _als_tensor_Sb(
126+
benv_ket::AbstractTensorMap{T, S, 2, 2}, a::MPSTensor
127+
) where {T <: Number, S <: ElementarySpace}
128+
return @tensor Sb[D1 db; DY1] :=
129+
benv_ket[DX1 DY1; da db] * conj(a[DX1 da; D1])
119130
end
120131

121132
"""
122-
$(SIGNATURES)
133+
Calculate the inner product (overlap)
134+
```
135+
┌benv-------┐
136+
├---a₂--b₂--┤
137+
| ↓ ↓ |
138+
├---ā---b̄---┤
139+
└-----------┘
140+
```
141+
using pre-calculated partial contraction results.
142+
"""
143+
function _als_overlap(a::MPSTensor, Sa::MPSTensor)
144+
# applies to b, Sb as well
145+
# @tensor Sb[D1 db; DY1] * conj(b[D1 db; DY1])
146+
return @tensor Sa[DX1 da; D1] * conj(a[DX1 da; D1])
147+
end
123148

124-
Contract the axis between `a` and `b` tensors
149+
"""
150+
Calculate the 2-site ALS inner product ⟨a₁,b₁|a₂,b₂⟩
125151
```
126-
-- DX - a - D - b - DY --
127-
↓ ↓
128-
da db
152+
┌benv-------┐
153+
├---a₂--b₂--┤
154+
| ↓ ↓ |
155+
├---ā₁--b̄₁--┤
156+
└-----------┘
129157
```
158+
where `|bra⟩ = |a₁,b₁⟩` and `|ket⟩ = |a₂,b₂⟩`,
159+
with virtual leg between a, b contracted.
130160
"""
131-
function _combine_ab(
132-
a::MPSTensor, b::AbstractTensorMap{T, S, 1, 2}
161+
function inner_prod(
162+
benv::BondEnv, bra::AbstractTensorMap{T, S, 2, 2},
163+
ket::AbstractTensorMap{T, S, 2, 2}
133164
) where {T <: Number, S <: ElementarySpace}
134-
return @tensor ab[DX DY; da db] := a[DX da; D] * b[D; db DY]
135-
end
136-
function _combine_ab(a::MPSTensor, b::MPSTensor)
137-
return @tensor ab[DX DY; da db] := a[DX da; D] * b[D db; DY]
165+
return @autoopt @tensor benv[DX1 DY1; DX0 DY0] *
166+
conj(bra[DX1 DY1; da db]) * ket[DX0 DY0; da db]
138167
end
139168

140169
"""
@@ -161,10 +190,30 @@ function cost_function_als(benv, ψ1, ψ2)
161190
return cost, fid
162191
end
163192

193+
# applies to Rb, Sb, b as well
194+
# b22 is the pre-calculated untruncated norm
195+
function cost_function_als(Ra::BondEnv, Sa::MPSTensor, a::MPSTensor, b22::Real)
196+
b11 = real(_als_norm(a, Ra))
197+
b12 = _als_overlap(a, Sa)
198+
cost = b11 + b22 - 2 * real(b12)
199+
fid = abs2(b12) / abs(b11 * b22)
200+
return cost, fid
201+
end
202+
164203
"""
165204
$(SIGNATURES)
166205
167206
Solve the equations `Rx x = Sx` with initial guess `x0`.
207+
208+
In ALS over `a`, `b`, if we fix `b`, the cost function can
209+
be expressed in the `Ra`, `Sa` tensors as
210+
```
211+
f(a†,a) = a† Ra a - a† Sa - Sa† a + const
212+
```
213+
Therefore `f` is minimized when
214+
```
215+
∂f/∂ā = Ra a - Sa = 0
216+
```
168217
"""
169218
function _solve_als(
170219
Rx::AbstractTensorMap{T, S, N, N},

0 commit comments

Comments
 (0)