Skip to content

Commit a94889f

Browse files
committed
truncation and expansion
1 parent 42e9ce1 commit a94889f

6 files changed

Lines changed: 171 additions & 40 deletions

File tree

src/MPSKit.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,9 @@ include("algorithms/algorithm.jl")
9494
include("utility/dynamictols.jl")
9595
using .DynamicTols
9696

97+
include("utility/dynamictruncation.jl")
98+
using .DynamicTruncation
99+
97100
include("utility/defaults.jl")
98101
using .Defaults: VERBOSE_NONE, VERBOSE_WARN, VERBOSE_CONV, VERBOSE_ITER, VERBOSE_ALL
99102
include("utility/logging.jl")
@@ -107,6 +110,7 @@ include("utility/multiline.jl")
107110
include("utility/utility.jl") # random utility functions
108111
include("utility/plotting.jl")
109112
include("utility/linearcombination.jl")
113+
include("utility/dynamictruncation.jl")
110114

111115
# maybe we should introduce an abstract state type
112116
include("states/abstractmps.jl")
@@ -156,6 +160,7 @@ include("algorithms/changebonds/vumpssvd.jl")
156160
include("algorithms/changebonds/svdcut.jl")
157161
include("algorithms/changebonds/randexpand.jl")
158162
include("algorithms/changebonds/localexpand.jl")
163+
include("algorithms/changebonds/truncation.jl")
159164

160165
include("algorithms/timestep/tdvp.jl")
161166
include("algorithms/timestep/taylorcluster.jl")
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
function noexpand()
2+
return truncrank(0)
3+
end

src/algorithms/groundstate/dmrg.jl

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -22,19 +22,20 @@ struct DMRG{A, F} <: Algorithm
2222
"algorithm used for the eigenvalue solvers"
2323
alg_eigsolve::A
2424

25-
expscheme::Algorithm
25+
trscheme::TruncationStrategy
26+
expscheme::TruncationStrategy
2627

2728
"callback function applied after each iteration, of signature `finalize(iter, ψ, H, envs) -> ψ, envs`"
2829
finalize::F
2930
end
3031
function DMRG(;
3132
tol = Defaults.tol, maxiter = Defaults.maxiter, alg_eigsolve = (;),
3233
verbosity = Defaults.verbosity, finalize = Defaults._finalize,
33-
miniter = 0, expscheme = NoExpand()
34+
miniter = 0, trscheme = notrunc(), expscheme = noexpand()
3435
)
3536
alg_eigsolve′ = alg_eigsolve isa NamedTuple ? Defaults.alg_eigsolve(; alg_eigsolve...) :
3637
alg_eigsolve
37-
return DMRG(tol, maxiter, miniter, verbosity, alg_eigsolve′, expscheme, finalize)
38+
return DMRG(tol, maxiter, miniter, verbosity, alg_eigsolve′, trscheme, expscheme, finalize)
3839
end
3940

4041
function find_groundstate!(::FiniteChainStyle, ψ, H, alg::DMRG, envs = environments(ψ, H))
@@ -46,6 +47,8 @@ function find_groundstate!(::FiniteChainStyle, ψ, H, alg::DMRG, envs = environm
4647
@infov 2 loginit!(log, ϵ, expectation_value(ψ, H, envs))
4748
for iter in 1:(alg.maxiter)
4849
alg_eigsolve = updatetol(alg.alg_eigsolve, iter, ϵ)
50+
expscheme = updatetruncation(alg.expscheme; iter = iter, current_rank = maximum(map(left_virtualspace, ψ)))
51+
trscheme = updatetruncation(alg.trscheme; iter = iter)
4952

5053
zerovector!(ϵs)
5154
dir = 1
@@ -57,12 +60,12 @@ function find_groundstate!(::FiniteChainStyle, ψ, H, alg::DMRG, envs = environm
5760
if alg.expscheme isa NoExpand
5861
ψ.AC[pos] = vec
5962
elseif dir == 1
60-
AL, C = left_orth!(vec; positive = true)
61-
AL, C, ψ.AC[pos + 1] = changebonds_left(AL, C, ψ.AC[pos + 1], alg.expscheme)
63+
AL, C = left_orth!(vec; trunc = trscheme)
64+
AL, C, ψ.AC[pos + 1] = changebonds_left(AL, C, ψ.AC[pos + 1], expscheme)
6265
ψ.AC[pos] = (AL, C)
6366
elseif dir == -1
64-
C, temp = right_orth!(_transpose_tail.AC[pos]; copy = true); positive = true)
65-
C, ψ.AC[pos - 1], temp = changebonds_right(C, ψ.AC[pos - 1], temp, alg.expscheme)
67+
C, temp = right_orth!(_transpose_tail.AC[pos]); trunc = trscheme)
68+
C, ψ.AC[pos - 1], temp = changebonds_right(C, ψ.AC[pos - 1], temp, expscheme)
6669
ψ.AC[pos] = (C, _transpose_front(temp))
6770
end
6871
end
@@ -113,17 +116,16 @@ struct DMRG2{A, S, F} <: Algorithm
113116

114117
"algorithm used for [truncation](@extref MatrixAlgebraKit.TruncationStrategy) of the two-site update"
115118
trscheme::TruncationStrategy
116-
117-
expscheme::Algorithm
119+
expscheme::TruncationStrategy
118120

119121
"callback function applied after each iteration, of signature `finalize(iter, ψ, H, envs) -> ψ, envs`"
120122
finalize::F
121123
end
122124
# TODO: find better default truncation
123125
function DMRG2(;
124126
tol = Defaults.tol, maxiter = Defaults.maxiter, verbosity = Defaults.verbosity,
125-
miniter = 0, alg_eigsolve = (;), alg_svd = Defaults.alg_svd(), trscheme,
126-
expscheme = NoExpand(), finalize = Defaults._finalize
127+
miniter = 0, alg_eigsolve = (;), alg_svd = Defaults.alg_svd(), trscheme = notrunc(),
128+
expscheme = noexpand(), finalize = Defaults._finalize
127129
)
128130
alg_eigsolve′ = alg_eigsolve isa NamedTuple ? Defaults.alg_eigsolve(; alg_eigsolve...) :
129131
alg_eigsolve
@@ -138,6 +140,8 @@ function find_groundstate!(::FiniteChainStyle, ψ, H, alg::DMRG2, envs = environ
138140
LoggingExtras.withlevel(; alg.verbosity) do
139141
for iter in 1:(alg.maxiter)
140142
alg_eigsolve = updatetol(alg.alg_eigsolve, iter, ϵ)
143+
trscheme = updatetruncation(alg.trscheme; iter=iter)
144+
expscheme = updatetruncation(alg.expscheme; iter = iter, current_rank = maximum(map(left_virtualspace, ψ)))
141145
zerovector!(ϵs)
142146

143147
# left to right sweep
@@ -146,8 +150,8 @@ function find_groundstate!(::FiniteChainStyle, ψ, H, alg::DMRG2, envs = environ
146150
Hac2 = AC2_hamiltonian(pos, ψ, H, ψ, envs)
147151
_, newA2center = fixedpoint(Hac2, ac2, :SR, alg_eigsolve)
148152

149-
al, c, ar = svd_trunc!(newA2center; trunc = alg.trscheme, alg = alg.alg_svd)
150-
al, c = changebonds_left(al, c, alg.expscheme)
153+
al, c, ar = svd_trunc!(newA2center; trunc = trscheme, alg = alg.alg_svd)
154+
al, c = changebonds_left(al, c, expscheme)
151155
normalize!(c)
152156
v = @plansor ac2[1 2; 3 4] * conj(al[1 2; 5]) * conj(c[5; 6]) * conj(ar[6; 3 4])
153157
ϵs[pos] = max(ϵs[pos], abs(1 - abs(v)))
@@ -162,8 +166,8 @@ function find_groundstate!(::FiniteChainStyle, ψ, H, alg::DMRG2, envs = environ
162166
Hac2 = AC2_hamiltonian(pos, ψ, H, ψ, envs)
163167
_, newA2center = fixedpoint(Hac2, ac2, :SR, alg_eigsolve)
164168

165-
al, c, ar = svd_trunc!(newA2center; trunc = alg.trscheme, alg = alg.alg_svd)
166-
c, ar = changebonds_right(c, ar, alg.expscheme)
169+
al, c, ar = svd_trunc!(newA2center; trunc = trscheme, alg = alg.alg_svd)
170+
c, ar = changebonds_right(c, ar, expscheme)
167171
normalize!(c)
168172
v = @plansor ac2[1 2; 3 4] * conj(al[1 2; 5]) * conj(c[5; 6]) * conj(ar[6; 3 4])
169173
ϵs[pos] = max(ϵs[pos], abs(1 - abs(v)))

src/algorithms/groundstate/idmrg.jl

Lines changed: 31 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,13 @@ $(TYPEDFIELDS)
2222
"algorithm used for gauging the MPS"
2323
alg_gauge = Defaults.alg_gauge()
2424

25+
2526
"algorithm used for the eigenvalue solvers"
2627
alg_eigsolve::A = Defaults.alg_eigsolve()
2728

28-
expscheme::Algorithm = NoExpand()
29+
expscheme::TruncationStrategy = noexpand()
30+
31+
trscheme::TruncationStrategy = notrunc()
2932
end
3033

3134
"""
@@ -58,10 +61,10 @@ $(TYPEDFIELDS)
5861
"algorithm used for the singular value decomposition"
5962
alg_svd::S = Defaults.alg_svd()
6063

61-
expscheme::Algorithm = NoExpand()
64+
expscheme::TruncationStrategy = noexpand()
6265

6366
"algorithm used for [truncation](@extref MatrixAlgebraKit.TruncationStrategy) of the two-site update"
64-
trscheme::TruncationStrategy
67+
trscheme::TruncationStrategy = notrunc()
6568
end
6669

6770

@@ -79,9 +82,9 @@ function IDMRGState{T}(mps::S, operator::O, envs::E, iter::Int, ϵ::Float64, ene
7982
end
8083

8184
function find_groundstate!(
82-
::InfiniteChainStyle, mps, operator, alg::alg_type,
85+
::InfiniteChainStyle, mps::S, operator, alg::alg_type,
8386
envs = environments(mps, operator)
84-
) where {alg_type <: Union{<:IDMRG, <:IDMRG2}}
87+
) where {alg_type <: Union{<:IDMRG, <:IDMRG2}, S}
8588
(length(mps) 1 && alg isa IDMRG2) && throw(ArgumentError("unit cell should be >= 2"))
8689
log = alg isa IDMRG ? IterLog("IDMRG") : IterLog("IDMRG2")
8790
iter = 0
@@ -111,8 +114,9 @@ function find_groundstate!(
111114
@infov 3 logiter!(log, it.iter, ϵ, ΔE)
112115
end
113116

114-
alg_gauge = updatetol(alg.alg_gauge, it.state.iter, it.state.ϵ)
115-
ψ′ = InfiniteMPS(it.state.mps.AR; alg_gauge.tol, alg_gauge.maxiter)
117+
alg_gauge = updatetol(alg.alg_gauge, it.iter, it.ϵ)
118+
ψ′ = S.name.wrapper(it.state.mps.AR; alg_gauge.tol, alg_gauge.maxiter)
119+
116120
envs = recalculate!(it.state.envs, ψ′, it.state.operator, ψ′)
117121
return ψ′, envs, it.state.ϵ
118122
end
@@ -124,8 +128,7 @@ function Base.iterate(
124128
mps, envs, C_old, E_new = localupdate_step!(it, state)
125129

126130
# error criterion
127-
C = mps.C[0]
128-
ϵ = bond_error(C_old, C)
131+
ϵ = bond_error(C_old, mps.C[0])
129132

130133
# New energy
131134
ΔE = (E_new - state.energy) / 2
@@ -141,24 +144,28 @@ function localupdate_step!(
141144
it::IterativeSolver{<:IDMRG}, state
142145
)
143146
alg_eigsolve = updatetol(it.alg_eigsolve, state.iter, state.ϵ)
144-
return _localupdate_sweep_idmrg!(state.mps, state.operator, state.envs, alg_eigsolve, it.alg_expscheme)
147+
expscheme = updatetruncation(it.expscheme; iter = state.iter, current_rank = maximum(map(left_virtualspace, state.mps)))
148+
trscheme = updatetruncation(it.trscheme; iter = state.iter)
149+
return _localupdate_sweep_idmrg!(state.mps, state.operator, state.envs, alg_eigsolve, trscheme, expscheme)
145150
end
146151

147152
function localupdate_step!(
148153
it::IterativeSolver{<:IDMRG2}, state
149154
)
150155
alg_eigsolve = updatetol(it.alg_eigsolve, state.iter, state.ϵ)
151-
return _localupdate_sweep_idmrg2!(state.mps, state.operator, state.envs, alg_eigsolve, it.trscheme, it.alg_svd, it.expscheme)
156+
expscheme = updatetruncation(it.expscheme; iter = state.iter, current_rank = maximum(map(left_virtualspace, state.mps)))
157+
trscheme = updatetruncation(it.trscheme; iter = state.iter)
158+
return _localupdate_sweep_idmrg2!(state.mps, state.operator, state.envs, alg_eigsolve, trscheme, it.alg_svd, expscheme)
152159
end
153160

154-
function _localupdate_sweep_idmrg!(ψ, H, envs, alg_eigsolve, expscheme)
161+
function _localupdate_sweep_idmrg!(ψ, H, envs, alg_eigsolve, alg_trscheme, expscheme)
155162
local E
156163
C_old = ψ.C[0]
157164
# left to right sweep
158165
for pos in 1:length(ψ)
159166
h = AC_hamiltonian(pos, ψ, H, ψ, envs)
160167
_, ψ.AC[pos] = fixedpoint(h, ψ.AC[pos], :SR, alg_eigsolve)
161-
ψ.AL[pos], ψ.C[pos] = left_orth!.AC[pos]; positive = true)
168+
ψ.AL[pos], ψ.C[pos] = left_orth!.AC[pos]; trunc = alg_trscheme)
162169
ψ.AL[pos], ψ.C[pos], ψ.AC[pos + 1] = changebonds_left.AL[pos], ψ.C[pos], ψ.AC[pos + 1], expscheme)
163170
if pos == length(ψ) # AC needed in next sweep
164171
ψ.AC[pos] = _mul_tail.AL[pos], ψ.C[pos])
@@ -171,7 +178,7 @@ function _localupdate_sweep_idmrg!(ψ, H, envs, alg_eigsolve, expscheme)
171178
h = AC_hamiltonian(pos, ψ, H, ψ, envs)
172179
E, ψ.AC[pos] = fixedpoint(h, ψ.AC[pos], :SR, alg_eigsolve)
173180

174-
C, temp = right_orth!(_transpose_tail.AC[pos]); positive = true)
181+
C, temp = right_orth!(_transpose_tail.AC[pos]); trunc = alg_trscheme)
175182
ψ.C[pos - 1], ψ.AC[pos - 1], temp = changebonds_right(C, ψ.AC[pos - 1], temp, expscheme)
176183
ψ.AR[pos] = _transpose_front(temp)
177184
if pos == 1 # AC needed in next sweep
@@ -207,8 +214,8 @@ function _localupdate_sweep_idmrg2!(ψ, H, envs, alg_eigsolve, alg_trscheme, alg
207214
# update the edge
208215
ψ.AL[end] = ψ.AC[end] / ψ.C[end]
209216
ψ.AC[1] = _mul_tail.AL[1], ψ.C[1])
210-
ac2 = AC2(ψ, 0; kind = :ALAC)
211-
h_ac2 = AC2_hamiltonian(0, ψ, H, ψ, envs)
217+
ac2 = AC2(ψ, length(ψ); kind = :ALAC)
218+
h_ac2 = AC2_hamiltonian(length(ψ), ψ, H, ψ, envs)
212219
_, ac2′ = fixedpoint(h_ac2, ac2, :SR, alg_eigsolve)
213220

214221
al, c, ar = svd_trunc!(ac2′; trunc = alg_trscheme, alg = alg_svd)
@@ -217,11 +224,11 @@ function _localupdate_sweep_idmrg2!(ψ, H, envs, alg_eigsolve, alg_trscheme, alg
217224

218225
ψ.AL[end] = al
219226
ψ.C[end] = complex(c)
220-
ψ.AR[1] = _transpose_front(ar)
227+
ψ.AR[end+1] = _transpose_front(ar)
221228

222229
ψ.AC[end] = _mul_tail(al, c)
223-
ψ.AC[1] = _transpose_front(c * ar)
224-
ψ.AL[1] = ψ.AC[1] / ψ.C[1]
230+
ψ.AC[end+1] = _transpose_front(c * ar)
231+
ψ.AL[end+1] = ψ.AC[end+1] / ψ.C[end+1]
225232

226233
C_old = complex(c)
227234

@@ -250,20 +257,20 @@ function _localupdate_sweep_idmrg2!(ψ, H, envs, alg_eigsolve, alg_trscheme, alg
250257
end
251258

252259
# update the edge
253-
ψ.AC[end] = _mul_front.C[end - 1], ψ.AR[end])
254-
ψ.AR[1] = _transpose_front.C[end] \ _transpose_tail.AC[1]))
260+
ψ.AC[0] = _mul_front.C[- 1], ψ.AR[0])
261+
ψ.AR[1] = _transpose_front.C[0] \ _transpose_tail.AC[1]))
255262
ac2 = AC2(ψ, 0; kind = :ACAR)
256263
h_ac2 = AC2_hamiltonian(0, ψ, H, ψ, envs)
257264
E, ac2′ = fixedpoint(h_ac2, ac2, :SR, alg_eigsolve)
258265
al, c, ar = svd_trunc!(ac2′; trunc = alg_trscheme, alg = alg_svd)
259266
al, c, ar = changebonds(al, c, ar, expscheme)
260267
normalize!(c)
261268

262-
ψ.AL[end] = al
263-
ψ.C[end] = complex(c)
269+
ψ.AL[0] = al
270+
ψ.C[0] = complex(c)
264271
ψ.AR[1] = _transpose_front(ar)
265272

266-
ψ.AR[end] = _transpose_front.C[end - 1] \ _transpose_tail(al * c))
273+
ψ.AR[0] = _transpose_front.C[-1] \ _transpose_tail(al * c))
267274
ψ.AC[1] = _transpose_front(c * ar)
268275

269276
transfer_leftenv!(envs, ψ, H, ψ, 1)

src/algorithms/groundstate/vumps.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ $(TYPEDFIELDS)
3333

3434
"callback function applied after each iteration, of signature `finalize(iter, ψ, H, envs) -> ψ, envs`"
3535
finalize::F = Defaults._finalize
36+
37+
parallel::Bool = false
3638
end
3739

3840
# Internal state of the VUMPS algorithm
@@ -118,7 +120,7 @@ function localupdate_step!(
118120
tforeach(eachsite(mps), src_ACs, src_Cs; scheduler) do site, AC₀, C₀
119121
dst_ACs[site] = _localupdate_vumps_step!(
120122
site, mps, state.operator, state.envs, AC₀, C₀;
121-
parallel = false, alg_orth, state.which, alg_eigsolve
123+
parallel = it.parallel, alg_orth, state.which, alg_eigsolve
122124
)
123125
return nothing
124126
end

0 commit comments

Comments
 (0)