Skip to content

Commit 32299e7

Browse files
committed
add backend and allocator to DMRG and IDMRG
1 parent 38c9a8f commit 32299e7

2 files changed

Lines changed: 51 additions & 44 deletions

File tree

src/algorithms/groundstate/dmrg.jl

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ struct DMRG{A, F} <: Algorithm
2222
"algorithm used for the eigenvalue solvers"
2323
alg_eigsolve::A
2424

25+
prepare::Bool
26+
2527
trscheme::TruncationStrategy
2628
expscheme::TruncationStrategy
2729

@@ -31,19 +33,21 @@ end
3133
function DMRG(;
3234
tol = Defaults.tol, maxiter = Defaults.maxiter, alg_eigsolve = (;),
3335
verbosity = Defaults.verbosity, finalize = Defaults._finalize,
34-
miniter = 0, trscheme = notrunc(), expscheme = noexpand()
36+
miniter = 0, trscheme = notrunc(), expscheme = noexpand(),
37+
prepare=true
3538
)
3639
alg_eigsolve′ = alg_eigsolve isa NamedTuple ? Defaults.alg_eigsolve(; alg_eigsolve...) :
3740
alg_eigsolve
38-
return DMRG(tol, maxiter, miniter, verbosity, alg_eigsolve′, trscheme, expscheme, finalize)
41+
return DMRG(tol, maxiter, miniter, verbosity, alg_eigsolve′, prepare, trscheme, expscheme, finalize)
3942
end
4043

41-
function find_groundstate!(::FiniteChainStyle, ψ, H, alg::DMRG, envs = environments(ψ, H))
44+
function find_groundstate!(::FiniteChainStyle, ψ, H, alg::DMRG, envs = environments(ψ, H), backend=DefaultBackend(), allocator=BufferAllocator())
4245
ϵs = map(pos -> calc_galerkin(pos, ψ, H, ψ, envs), 1:length(ψ))
4346
ϵ = maximum(ϵs)
4447
log = IterLog("DMRG")
4548
timeroutput = TimerOutput("DMRG")
4649
alg.verbosity > 3 || disable_timer!(timeroutput)
50+
prepare = alg.prepare
4751

4852
LoggingExtras.withlevel(; alg.verbosity) do
4953
@infov 2 loginit!(log, ϵ, expectation_value(ψ, H, envs))
@@ -59,7 +63,7 @@ function find_groundstate!(::FiniteChainStyle, ψ, H, alg::DMRG, envs = environm
5963
for pos in [1:(length(ψ) - 1); length(ψ):-1:2]
6064
@timeit timeroutput "AC_eigsolve" begin
6165
pos == length(ψ) && (dir = -1)
62-
h = AC_hamiltonian(pos, ψ, H, ψ, envs)
66+
h = AC_hamiltonian(pos, ψ, H, ψ, envs; prepare, backend, allocator)
6367
_, vec = fixedpoint(h, ψ.AC[pos], :SR, alg_eigsolve)
6468
end
6569
ϵs[pos] = max(ϵs[pos], calc_galerkin(pos, ψ, H, ψ, envs))
@@ -125,6 +129,7 @@ struct DMRG2{A, S, F} <: Algorithm
125129
"algorithm used for the singular value decomposition"
126130
alg_svd::S
127131

132+
prepare::Bool
128133
"algorithm used for [truncation](@extref MatrixAlgebraKit.TruncationStrategy) of the two-site update"
129134
trscheme::TruncationStrategy
130135
expscheme::TruncationStrategy
@@ -136,14 +141,14 @@ end
136141
function DMRG2(;
137142
tol = Defaults.tol, maxiter = Defaults.maxiter, verbosity = Defaults.verbosity,
138143
miniter = 0, alg_eigsolve = (;), alg_svd = Defaults.alg_svd(), trscheme = notrunc(),
139-
expscheme = noexpand(), finalize = Defaults._finalize
144+
expscheme = noexpand(), finalize = Defaults._finalize, prepare=true
140145
)
141146
alg_eigsolve′ = alg_eigsolve isa NamedTuple ? Defaults.alg_eigsolve(; alg_eigsolve...) :
142147
alg_eigsolve
143-
return DMRG2(tol, maxiter, miniter, verbosity, alg_eigsolve′, alg_svd, trscheme, expscheme, finalize)
148+
return DMRG2(tol, maxiter, miniter, verbosity, alg_eigsolve′, alg_svd, prepare, trscheme, expscheme, finalize)
144149
end
145150

146-
function find_groundstate!(::FiniteChainStyle, ψ, H, alg::DMRG2, envs = environments(ψ, H), prepare=false)
151+
function find_groundstate!(::FiniteChainStyle, ψ, H, alg::DMRG2, envs = environments(ψ, H), backend=DefaultBackend(), allocator=BufferAllocator())
147152
ϵs = map(pos -> calc_galerkin(pos, ψ, H, ψ, envs), 1:length(ψ))
148153
ϵ = maximum(ϵs)
149154
log = IterLog("DMRG2")
@@ -164,15 +169,15 @@ function find_groundstate!(::FiniteChainStyle, ψ, H, alg::DMRG2, envs = environ
164169
for pos in 1:(length(ψ) - 1)
165170
local ac2, newA2center, al, c, ar
166171
@timeit timeroutput "AC2_eigsolve" begin
167-
@plansor ac2[-1 -2; -3 -4] := ψ.AC[pos][-1 -2; 1] * ψ.AR[pos + 1][1 -4; -3]
168-
Hac2 = AC2_hamiltonian(pos, ψ, H, ψ, envs; prepare)
172+
@plansor backend=backend allocator=allocator ac2[-1 -2; -3 -4] := ψ.AC[pos][-1 -2; 1] * ψ.AR[pos + 1][1 -4; -3]
173+
Hac2 = AC2_hamiltonian(pos, ψ, H, ψ, envs; prepare, backend, allocator)
169174
_, newA2center = fixedpoint(Hac2, ac2, :SR, alg_eigsolve)
170175
end
171176
@timeit timeroutput "svd_trunc" begin
172177
al, c, ar = svd_trunc!(newA2center; trunc = trscheme, alg = alg.alg_svd)
173178
al, c = changebonds_left(al, c, expscheme; ac2 = ac2)
174179
normalize!(c)
175-
v = @plansor ac2[1 2; 3 4] * conj(al[1 2; 5]) * conj(c[5; 6]) * conj(ar[6; 3 4])
180+
v = @plansor backend=backend allocator=allocator ac2[1 2; 3 4] * conj(al[1 2; 5]) * conj(c[5; 6]) * conj(ar[6; 3 4])
176181
ϵs[pos] = max(ϵs[pos], abs(1 - abs(v)))
177182
end
178183
@timeit timeroutput "update_AC" begin
@@ -185,15 +190,15 @@ function find_groundstate!(::FiniteChainStyle, ψ, H, alg::DMRG2, envs = environ
185190
for pos in (length(ψ) - 2):-1:1
186191
local ac2, newA2center, al, c, ar
187192
@timeit timeroutput "AC2_eigsolve" begin
188-
@plansor ac2[-1 -2; -3 -4] := ψ.AL[pos][-1 -2; 1] * ψ.AC[pos + 1][1 -4; -3]
189-
Hac2 = AC2_hamiltonian(pos, ψ, H, ψ, envs; prepare)
193+
@plansor backend=backend allocator=allocator ac2[-1 -2; -3 -4] := ψ.AL[pos][-1 -2; 1] * ψ.AC[pos + 1][1 -4; -3]
194+
Hac2 = AC2_hamiltonian(pos, ψ, H, ψ, envs; prepare = prepare, backend = backend, allocator = allocator)
190195
_, newA2center = fixedpoint(Hac2, ac2, :SR, alg_eigsolve)
191196
end
192197
@timeit timeroutput "svd_trunc" begin
193198
al, c, ar = svd_trunc!(newA2center; trunc = trscheme, alg = alg.alg_svd)
194199
c, ar = changebonds_right(c, ar, expscheme; ac2 = ac2)
195-
normalize!(c)
196-
v = @plansor ac2[1 2; 3 4] * conj(al[1 2; 5]) * conj(c[5; 6]) * conj(ar[6; 3 4])
200+
normalize!(c)
201+
v = @plansor backend=backend allocator=allocator ac2[1 2; 3 4] * conj(al[1 2; 5]) * conj(c[5; 6]) * conj(ar[6; 3 4])
197202
ϵs[pos] = max(ϵs[pos], abs(1 - abs(v)))
198203
end
199204
@timeit timeroutput "update_AC" begin

src/algorithms/groundstate/idmrg.jl

Lines changed: 32 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -69,20 +69,22 @@ end
6969

7070

7171
# Internal state of the IDMRG algorithm
72-
struct IDMRGState{S, O, E, T}
72+
struct IDMRGState{S, O, E, T, B, A}
7373
mps::S
7474
operator::O
7575
envs::E
7676
iter::Int
7777
ϵ::Float64 # TODO: Could be any <:Real
7878
energy::T
7979
timeroutput::TimerOutput
80+
backend::B
81+
allocator::A
8082
end
8183
function IDMRGState{T}(
8284
mps::S, operator::O, envs::E, iter::Int, ϵ::Float64, energy,
83-
timeroutput::TimerOutput,
84-
) where {S, O, E, T}
85-
return IDMRGState{S, O, E, T}(mps, operator, envs, iter, ϵ, T(energy), timeroutput)
85+
timeroutput::TimerOutput, backend::B, allocator::A = BufferAllocator()
86+
) where {S, O, E, T, B, A}
87+
return IDMRGState{S, O, E, T, B, A}(mps, operator, envs, iter, ϵ, T(energy), timeroutput, backend, allocator)
8688
end
8789

8890
function find_groundstate!(
@@ -105,7 +107,7 @@ function find_groundstate!(
105107
end
106108
end
107109

108-
state = IDMRGState(mps, operator, envs, iter, ϵ, E, timeroutput)
110+
state = IDMRGState(mps, operator, envs, iter, ϵ, E, timeroutput, DefaultBackend(), BufferAllocator())
109111
it = IterativeSolver(alg, state)
110112

111113
return LoggingExtras.withlevel(; alg.verbosity) do
@@ -146,7 +148,7 @@ function Base.iterate(
146148

147149
# update state
148150
it.state = IDMRGState{T}(
149-
mps, state.operator, envs, state.iter + 1, ϵ, E_new, timeroutput,
151+
mps, state.operator, envs, state.iter + 1, ϵ, E_new, timeroutput, state.backend, state.allocator
150152
)
151153

152154
return (mps, envs, ϵ, ΔE), it.state
@@ -160,7 +162,7 @@ function localupdate_step!(
160162
trscheme = updatetruncation(it.trscheme; iter = state.iter)
161163
return _localupdate_sweep_idmrg!(
162164
state.mps, state.operator, state.envs, alg_eigsolve, state.timeroutput,
163-
trscheme, expscheme
165+
trscheme, expscheme, state.backend, state.allocator
164166
)
165167
end
166168

@@ -173,46 +175,46 @@ function localupdate_step!(
173175
return _localupdate_sweep_idmrg2!(
174176
state.mps, state.operator, state.envs, alg_eigsolve,
175177
trscheme, it.alg_svd, state.timeroutput,
176-
expscheme)
178+
expscheme, state.backend, state.allocator)
177179
end
178180

179-
function _localupdate_sweep_idmrg!(ψ, H, envs, alg_eigsolve, timeroutput::TimerOutput, alg_trscheme, expscheme)
181+
function _localupdate_sweep_idmrg!(ψ, H, envs, alg_eigsolve, timeroutput::TimerOutput, alg_trscheme, expscheme, backend = DefaultBackend(), allocator = BufferAllocator())
180182
local E
181183
C_old = ψ.C[0]
182184

183185
# left to right sweep
184186
_idmrg_move_right!(ψ, H, envs, 1, alg_trscheme, expscheme) # We don't update the first site, as the backwards sweep will update it!
185187
for pos in 2:length(ψ)
186188
println("Left: pos = $pos")
187-
h = AC_hamiltonian(pos, ψ, H, ψ, envs)
189+
h = AC_hamiltonian(pos, ψ, H, ψ, envs; backend, allocator)
188190
_, ψ.AC[pos] = fixedpoint(h, ψ.AC[pos], :SR, alg_eigsolve)
189191
_idmrg_move_right!(ψ, H, envs, pos, alg_trscheme, expscheme)
190192
end
191-
_idmrg_move_left!(ψ, H, envs, length(ψ), alg_trscheme, expscheme)
193+
_idmrg_move_left!(ψ, H, envs, length(ψ), alg_trscheme, expscheme, backend, allocator)
192194

193195
# right to left sweep
194196
for pos in length(ψ)-1:-1:1
195197
println("Right: pos = $pos")
196-
h = AC_hamiltonian(pos, ψ, H, ψ, envs)
198+
h = AC_hamiltonian(pos, ψ, H, ψ, envs; backend, allocator)
197199
E, ψ.AC[pos] = fixedpoint(h, ψ.AC[pos], :SR, alg_eigsolve)
198200

199-
_idmrg_move_left!(ψ, H, envs, pos, alg_trscheme, expscheme)
201+
_idmrg_move_left!(ψ, H, envs, pos, alg_trscheme, expscheme, backend, allocator)
200202
end
201203

202204

203205
return ψ, envs, C_old, E
204206
end
205207

206-
function _idmrg_move_right!(ψ, H,envs, pos, alg_trscheme, expscheme)
208+
function _idmrg_move_right!(ψ, H,envs, pos, alg_trscheme, expscheme, backend = DefaultBackend(), allocator = BufferAllocator())
207209
ψ.AL[pos], ψ.C[pos] = left_orth!.AC[pos]; trunc = alg_trscheme)
208210
ψ.AL[pos], ψ.C[pos] = changebonds_left.AL[pos], ψ.C[pos], expscheme)
209211
ψ.AC[pos + 1] = _mul_front.C[pos], ψ.AR[pos + 1])
210212
if pos == length(ψ)
211213
ψ.AL[pos + 1] = ψ.AC[pos + 1] / ψ.C[pos + 1]
212214
end
213-
transfer_leftenv!(envs, ψ, H, ψ, pos + 1)
215+
transfer_leftenv!(envs, ψ, H, ψ, pos + 1, backend, allocator)
214216
end
215-
function _idmrg_move_left!(ψ, H, envs, pos, alg_trscheme, expscheme)
217+
function _idmrg_move_left!(ψ, H, envs, pos, alg_trscheme, expscheme, backend = DefaultBackend(), allocator = BufferAllocator())
216218
C, temp = right_orth!(_transpose_tail.AC[pos]); trunc = alg_trscheme)
217219
C, temp = changebonds_right(C, temp, expscheme)
218220
ψ.C[pos] = C
@@ -222,15 +224,15 @@ function _idmrg_move_left!(ψ, H, envs, pos, alg_trscheme, expscheme)
222224
ψ.AR[pos - 1] = ψ.AC[pos - 1] / ψ.C[pos - 2]
223225
end
224226

225-
transfer_rightenv!(envs, ψ, H, ψ, pos - 1)
227+
transfer_rightenv!(envs, ψ, H, ψ, pos - 1, backend, allocator)
226228
end
227229

228-
function _localupdate_sweep_idmrg2!(ψ, H, envs, alg_eigsolve, alg_trscheme, alg_svd, expscheme, tol=1e-8)
230+
function _localupdate_sweep_idmrg2!(ψ, H, envs, alg_eigsolve, alg_trscheme, alg_svd, expscheme, tol=1e-8, backend = DefaultBackend(), allocator = BufferAllocator())
229231
# sweep from left to right
230232
for pos in 1:(length(ψ) - 1)
231233
@timeit timeroutput "AC2_eigsolve" begin
232234
ac2 = AC2(ψ, pos; kind = :ACAR)
233-
h_ac2 = AC2_hamiltonian(pos, ψ, H, ψ, envs)
235+
h_ac2 = AC2_hamiltonian(pos, ψ, H, ψ, envs; backend, allocator)
234236
_, ac2′ = fixedpoint(h_ac2, ac2, :SR, alg_eigsolve)
235237
end
236238
@timeit timeroutput "svd_trunc" begin
@@ -244,8 +246,8 @@ function _localupdate_sweep_idmrg2!(ψ, H, envs, alg_eigsolve, alg_trscheme, alg
244246
ψ.AC[pos + 1] = _transpose_front(c * ar)
245247
end
246248
@timeit timeroutput "transfer_env" begin
247-
transfer_leftenv!(envs, ψ, H, ψ, pos + 1)
248-
transfer_rightenv!(envs, ψ, H, ψ, pos)
249+
transfer_leftenv!(envs, ψ, H, ψ, pos + 1, backend, allocator)
250+
transfer_rightenv!(envs, ψ, H, ψ, pos, backend, allocator)
249251
end
250252
end
251253

@@ -256,7 +258,7 @@ function _localupdate_sweep_idmrg2!(ψ, H, envs, alg_eigsolve, alg_trscheme, alg
256258
ψ.AC[1] = _mul_tail.AL[1], ψ.C[1])
257259
@timeit timeroutput "AC2_eigsolve" begin
258260
ac2 = AC2(ψ, length(ψ); kind = :ALAC)
259-
h_ac2 = AC2_hamiltonian(length(ψ), ψ, H, ψ, envs)
261+
h_ac2 = AC2_hamiltonian(length(ψ), ψ, H, ψ, envs; backend, allocator)
260262
_, ac2′ = fixedpoint(h_ac2, ac2, :SR, alg_eigsolve)
261263
end
262264
@timeit timeroutput "svd_trunc" begin
@@ -279,15 +281,15 @@ function _localupdate_sweep_idmrg2!(ψ, H, envs, alg_eigsolve, alg_trscheme, alg
279281

280282
# update environments
281283
@timeit timeroutput "transfer_env" begin
282-
transfer_leftenv!(envs, ψ, H, ψ, 1)
283-
transfer_rightenv!(envs, ψ, H, ψ, 0)
284+
transfer_leftenv!(envs, ψ, H, ψ, 1, backend, allocator)
285+
transfer_rightenv!(envs, ψ, H, ψ, 0, backend, allocator)
284286
end
285287

286288
# sweep from right to left
287289
for pos in (length(ψ) - 1):-1:1
288290
@timeit timeroutput "AC2_eigsolve" begin
289291
ac2 = AC2(ψ, pos; kind = :ALAC)
290-
h_ac2 = AC2_hamiltonian(pos, ψ, H, ψ, envs)
292+
h_ac2 = AC2_hamiltonian(pos, ψ, H, ψ, envs; backend, allocator)
291293
_, ac2′ = fixedpoint(h_ac2, ac2, :SR, alg_eigsolve)
292294
end
293295
@timeit timeroutput "svd_trunc" begin
@@ -302,8 +304,8 @@ function _localupdate_sweep_idmrg2!(ψ, H, envs, alg_eigsolve, alg_trscheme, alg
302304
ψ.AC[pos + 1] = _transpose_front(c * ar)
303305
end
304306
@timeit timeroutput "transfer_env" begin
305-
transfer_leftenv!(envs, ψ, H, ψ, pos + 1)
306-
transfer_rightenv!(envs, ψ, H, ψ, pos)
307+
transfer_leftenv!(envs, ψ, H, ψ, pos + 1, backend, allocator)
308+
transfer_rightenv!(envs, ψ, H, ψ, pos, backend, allocator)
307309
end
308310
end
309311

@@ -314,7 +316,7 @@ function _localupdate_sweep_idmrg2!(ψ, H, envs, alg_eigsolve, alg_trscheme, alg
314316
ψ.AR[1] = _transpose_front(pinv.C[0]; atol=tol) * _transpose_tail.AC[1]))
315317

316318
ac2 = AC2(ψ, 0; kind = :ACAR)
317-
h_ac2 = AC2_hamiltonian(0, ψ, H, ψ, envs)
319+
h_ac2 = AC2_hamiltonian(0, ψ, H, ψ, envs; backend, allocator)
318320
E, ac2′ = fixedpoint(h_ac2, ac2, :SR, alg_eigsolve)
319321
end
320322
@timeit timeroutput "svd_trunc" begin
@@ -332,8 +334,8 @@ function _localupdate_sweep_idmrg2!(ψ, H, envs, alg_eigsolve, alg_trscheme, alg
332334
end
333335

334336
@timeit timeroutput "transfer_env" begin
335-
transfer_leftenv!(envs, ψ, H, ψ, 1)
336-
transfer_rightenv!(envs, ψ, H, ψ, 0)
337+
transfer_leftenv!(envs, ψ, H, ψ, 1, backend, allocator)
338+
transfer_rightenv!(envs, ψ, H, ψ, 0, backend, allocator)
337339
end
338340
return ψ, envs, C_old, E
339341
end

0 commit comments

Comments
 (0)