Skip to content

Commit 16a6155

Browse files
authored
Add TimerOutputs instrumentation to ground-state algorithms (QuantumKitHub#431)
* add TimerOutputs dependency * add VUMPS TimerOutputs support * add DMRG TimerOutputs support * add IDMRG TimerOutputs support * add GrassMann TimerOutputs support * add parallel annotations
1 parent 8c7fdb4 commit 16a6155

10 files changed

Lines changed: 366 additions & 183 deletions

File tree

Project.toml

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@ uuid = "bb1c41ca-d63c-52ed-829e-0820dda26502"
33
version = "0.13.12"
44
authors = "Lukas Devos, Maarten Van Damme and contributors"
55

6+
[workspace]
7+
projects = ["test", "docs"]
8+
69
[deps]
710
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
811
BlockTensorKit = "5f87ffc2-9cf1-4a46-8172-465d160bd8cd"
@@ -21,6 +24,7 @@ RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
2124
TensorKit = "07d1fe3e-3e46-537d-9eac-e9e13d0d4cec"
2225
TensorKitManifolds = "11fa318c-39cb-4a83-b1ed-cdc7ba1e3684"
2326
TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2"
27+
TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
2428
VectorInterface = "409d34a3-91d5-4945-b6ec-7529ddf182d8"
2529

2630
[weakdeps]
@@ -29,9 +33,6 @@ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
2933
[extensions]
3034
MPSKitAdaptExt = "Adapt"
3135

32-
[workspace]
33-
projects = ["test", "docs"]
34-
3536
[compat]
3637
Accessors = "0.1"
3738
Adapt = "4"
@@ -51,5 +52,6 @@ RecipesBase = "1.1"
5152
TensorKit = "0.16.5"
5253
TensorKitManifolds = "0.7, 0.8"
5354
TensorOperations = "5.5.1"
55+
TimerOutputs = "0.5.29"
5456
VectorInterface = "0.2, 0.3, 0.4, 0.5"
5557
julia = "1.10"

src/MPSKit.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ using Random
8585
using Base: @kwdef, @propagate_inbounds
8686
using LoggingExtras
8787
using OhMyThreads
88+
using TimerOutputs: TimerOutput, @timeit, timeit, reset_timer!, disable_timer!, enable_timer!
8889

8990
# Includes
9091
# --------

src/algorithms/algorithm.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,13 @@ function Base.show(io::IO, ::MIME"text/plain", alg::Algorithm)
2020
end
2121
return nothing
2222
end
23+
24+
# TIMEROUTPUT utility
25+
# -------------------
26+
# Shared sentinel passed as the default `timeroutput` kwarg by functions that optionally accept a timer.
27+
# `@timeit` is a no-op when the destination timer is disabled, so unrelated callers pay no instrumentation cost.
28+
# The merge sites that mutate `timeroutput` must gate on `timeroutput.enabled` so they don't pollute this shared object.
29+
const DISABLED_TIMER = let t = TimerOutput("DISABLED")
30+
disable_timer!(t)
31+
t
32+
end

src/algorithms/grassmann.jl

Lines changed: 26 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ module GrassmannMPS
1212

1313
using ..MPSKit
1414
using ..MPSKit: AbstractMPSEnvironments, InfiniteEnvironments, MultilineEnvironments,
15-
AC_projection, recalculate!
15+
AC_projection, recalculate!, TimerOutput, DISABLED_TIMER, @timeit
1616
using TensorKit
1717
using OhMyThreads
1818
import TensorKitManifolds.Grassmann
@@ -143,11 +143,12 @@ Compute the cost function and the tangent vector with respect to the `AL` parame
143143
"""
144144
function fg(
145145
state::FiniteMPS, operator::Union{O, LazySum{O}},
146-
envs::AbstractMPSEnvironments = environments(state, operator)
146+
envs::AbstractMPSEnvironments = environments(state, operator);
147+
timeroutput::TimerOutput = DISABLED_TIMER,
147148
) where {O <: FiniteMPOHamiltonian}
148-
f = expectation_value(state, operator, envs)
149+
f = @timeit timeroutput "expval" expectation_value(state, operator, envs)
149150
isapprox(imag(f), 0; atol = eps(abs(f))^(3 / 4)) || @warn "MPO might not be Hermitian: $f"
150-
gs = map(1:length(state)) do i
151+
gs = @timeit timeroutput "gradient" map(1:length(state)) do i
151152
AC′ = AC_projection(i, state, operator, state, envs)
152153
g = Grassmann.project(AC′, state.AL[i])
153154
return rmul(g, state.C[i]')
@@ -156,15 +157,18 @@ function fg(
156157
end
157158
function fg(
158159
state::InfiniteMPS, operator::Union{O, LazySum{O}},
159-
envs::AbstractMPSEnvironments = environments(state, operator)
160+
envs::AbstractMPSEnvironments = environments(state, operator);
161+
timeroutput::TimerOutput = DISABLED_TIMER,
160162
) where {O <: InfiniteMPOHamiltonian}
161-
recalculate!(envs, state, operator, state)
162-
f = expectation_value(state, operator, envs)
163+
@timeit timeroutput "envs (parallel)" recalculate!(envs, state, operator, state; timeroutput)
164+
f = @timeit timeroutput "expval" expectation_value(state, operator, envs)
163165
isapprox(imag(f), 0; atol = eps(abs(f))^(3 / 4)) || @warn "MPO might not be Hermitian: $f"
164166

165167
A = Core.Compiler.return_type(Grassmann.project, Tuple{eltype(state), eltype(state)})
166168
gs = Vector{A}(undef, length(state))
167-
tmap!(gs, 1:length(state); scheduler = MPSKit.Defaults.scheduler[]) do i
169+
@timeit timeroutput "gradient" tmap!(
170+
gs, 1:length(state); scheduler = MPSKit.Defaults.scheduler[]
171+
) do i
168172
AC′ = AC_projection(i, state, operator, state, envs)
169173
g = Grassmann.project(AC′, state.AL[i])
170174
return rmul(g, state.C[i]')
@@ -173,15 +177,18 @@ function fg(
173177
end
174178
function fg(
175179
state::InfiniteMPS, operator::Union{O, LazySum{O}},
176-
envs::AbstractMPSEnvironments = environments(state, operator)
180+
envs::AbstractMPSEnvironments = environments(state, operator);
181+
timeroutput::TimerOutput = DISABLED_TIMER,
177182
) where {O <: InfiniteMPO}
178-
recalculate!(envs, state, operator, state)
179-
f = expectation_value(state, operator, envs)
183+
@timeit timeroutput "envs (parallel)" recalculate!(envs, state, operator, state; timeroutput)
184+
f = @timeit timeroutput "expval" expectation_value(state, operator, envs)
180185
isapprox(imag(f), 0; atol = eps(abs(f))^(3 / 4)) || @warn "MPO might not be Hermitian: $f"
181186

182187
A = Core.Compiler.return_type(Grassmann.project, Tuple{eltype(state), eltype(state)})
183188
gs = Vector{A}(undef, length(state))
184-
tmap!(gs, eachindex(state); scheduler = MPSKit.Defaults.scheduler[]) do i
189+
@timeit timeroutput "gradient" tmap!(
190+
gs, eachindex(state); scheduler = MPSKit.Defaults.scheduler[]
191+
) do i
185192
AC′ = AC_projection(i, state, operator, state, envs)
186193
g = rmul!(Grassmann.project(AC′, state.AL[i]), -inv(f))
187194
return rmul(g, state.C[i]')
@@ -190,16 +197,19 @@ function fg(
190197
end
191198
function fg(
192199
state::MultilineMPS, operator::MultilineMPO,
193-
envs::MultilineEnvironments = environments(state, operator)
200+
envs::MultilineEnvironments = environments(state, operator);
201+
timeroutput::TimerOutput = DISABLED_TIMER,
194202
)
195203
@assert length(state) == 1 "not implemented"
196-
recalculate!(envs, state, operator, state)
197-
f = expectation_value(state, operator, envs)
204+
@timeit timeroutput "envs (parallel)" recalculate!(envs, state, operator, state; timeroutput)
205+
f = @timeit timeroutput "expval" expectation_value(state, operator, envs)
198206
isapprox(imag(f), 0; atol = eps(abs(f))^(3 / 4)) || @warn "MPO might not be Hermitian: $f"
199207

200208
A = Core.Compiler.return_type(Grassmann.project, Tuple{eltype(state), eltype(state)})
201209
gs = Matrix{A}(undef, size(state))
202-
tforeach(eachindex(state); scheduler = MPSKit.Defaults.scheduler[]) do i
210+
@timeit timeroutput "gradient" tforeach(
211+
eachindex(state); scheduler = MPSKit.Defaults.scheduler[]
212+
) do i
203213
AC′ = AC_projection(i, state, operator, state, envs)
204214
g = rmul!(Grassmann.project(AC′, state.AL[i]), -inv(f))
205215
gs[i] = rmul(g, state.C[i]')

src/algorithms/groundstate/dmrg.jl

Lines changed: 64 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -36,28 +36,39 @@ function find_groundstate!(ψ::AbstractFiniteMPS, H, alg::DMRG, envs = environme
3636
ϵs = map(pos -> calc_galerkin(pos, ψ, H, ψ, envs), 1:length(ψ))
3737
ϵ = maximum(ϵs)
3838
log = IterLog("DMRG")
39+
timeroutput = TimerOutput("DMRG")
40+
alg.verbosity > 3 || disable_timer!(timeroutput)
3941

4042
LoggingExtras.withlevel(; alg.verbosity) do
4143
@infov 2 loginit!(log, ϵ, expectation_value(ψ, H, envs))
4244
for iter in 1:(alg.maxiter)
4345
alg_eigsolve = updatetol(alg.alg_eigsolve, iter, ϵ)
4446

4547
zerovector!(ϵs)
46-
for pos in [1:(length(ψ) - 1); length(ψ):-1:2]
47-
h = AC_hamiltonian(pos, ψ, H, ψ, envs)
48-
_, vec = fixedpoint(h, ψ.AC[pos], :SR, alg_eigsolve)
49-
ϵs[pos] = max(ϵs[pos], calc_galerkin(pos, ψ, H, ψ, envs))
50-
ψ.AC[pos] = vec
48+
@timeit timeroutput "sweep" begin
49+
for pos in [1:(length(ψ) - 1); length(ψ):-1:2]
50+
local vec
51+
@timeit timeroutput "AC_eigsolve" begin
52+
h = AC_hamiltonian(pos, ψ, H, ψ, envs)
53+
_, vec = fixedpoint(h, ψ.AC[pos], :SR, alg_eigsolve)
54+
end
55+
ϵs[pos] = max(ϵs[pos], calc_galerkin(pos, ψ, H, ψ, envs))
56+
@timeit timeroutput "AC_update" ψ.AC[pos] = vec
57+
end
5158
end
5259
ϵ = maximum(ϵs)
5360

54-
ψ, envs = alg.finalize(iter, ψ, H, envs)::Tuple{typeof(ψ), typeof(envs)}
61+
ψ, envs = @timeit timeroutput "finalize" alg.finalize(
62+
iter, ψ, H, envs
63+
)::Tuple{typeof(ψ), typeof(envs)}
5564

5665
if ϵ <= alg.tol
66+
@infov 4 timeroutput
5767
@infov 2 logfinish!(log, iter, ϵ, expectation_value(ψ, H, envs))
5868
break
5969
end
6070
if iter == alg.maxiter
71+
@infov 4 timeroutput
6172
@warnv 1 logcancel!(log, iter, ϵ, expectation_value(ψ, H, envs))
6273
else
6374
@infov 3 logiter!(log, iter, ϵ, expectation_value(ψ, H, envs))
@@ -113,50 +124,68 @@ function find_groundstate!(ψ::AbstractFiniteMPS, H, alg::DMRG2, envs = environm
113124
ϵs = map(pos -> calc_galerkin(pos, ψ, H, ψ, envs), 1:length(ψ))
114125
ϵ = maximum(ϵs)
115126
log = IterLog("DMRG2")
127+
timeroutput = TimerOutput("DMRG2")
128+
alg.verbosity > 3 || disable_timer!(timeroutput)
116129

117130
LoggingExtras.withlevel(; alg.verbosity) do
118131
for iter in 1:(alg.maxiter)
119132
alg_eigsolve = updatetol(alg.alg_eigsolve, iter, ϵ)
120133
zerovector!(ϵs)
121134

122-
# left to right sweep
123-
for pos in 1:(length(ψ) - 1)
124-
@plansor ac2[-1 -2; -3 -4] := ψ.AC[pos][-1 -2; 1] * ψ.AR[pos + 1][1 -4; -3]
125-
Hac2 = AC2_hamiltonian(pos, ψ, H, ψ, envs)
126-
_, newA2center = fixedpoint(Hac2, ac2, :SR, alg_eigsolve)
127-
128-
al, c, ar = svd_trunc!(newA2center; trunc = alg.trscheme, alg = alg.alg_svd)
129-
normalize!(c)
130-
v = @plansor ac2[1 2; 3 4] * conj(al[1 2; 5]) * conj(c[5; 6]) * conj(ar[6; 3 4])
131-
ϵs[pos] = max(ϵs[pos], abs(1 - abs(v)))
132-
133-
ψ.AC[pos] = (al, complex(c))
134-
ψ.AC[pos + 1] = (complex(c), _transpose_front(ar))
135-
end
136-
137-
# right to left sweep
138-
for pos in (length(ψ) - 2):-1:1
139-
@plansor ac2[-1 -2; -3 -4] := ψ.AL[pos][-1 -2; 1] * ψ.AC[pos + 1][1 -4; -3]
140-
Hac2 = AC2_hamiltonian(pos, ψ, H, ψ, envs)
141-
_, newA2center = fixedpoint(Hac2, ac2, :SR, alg_eigsolve)
142-
143-
al, c, ar = svd_trunc!(newA2center; trunc = alg.trscheme, alg = alg.alg_svd)
144-
normalize!(c)
145-
v = @plansor ac2[1 2; 3 4] * conj(al[1 2; 5]) * conj(c[5; 6]) * conj(ar[6; 3 4])
146-
ϵs[pos] = max(ϵs[pos], abs(1 - abs(v)))
147-
148-
ψ.AC[pos + 1] = (complex(c), _transpose_front(ar))
149-
ψ.AC[pos] = (al, complex(c))
135+
@timeit timeroutput "sweep" begin
136+
# left to right sweep
137+
for pos in 1:(length(ψ) - 1)
138+
local ac2, newA2center, al, c, ar
139+
@timeit timeroutput "AC2_eigsolve" begin
140+
@plansor ac2[-1 -2; -3 -4] := ψ.AC[pos][-1 -2; 1] * ψ.AR[pos + 1][1 -4; -3]
141+
Hac2 = AC2_hamiltonian(pos, ψ, H, ψ, envs)
142+
_, newA2center = fixedpoint(Hac2, ac2, :SR, alg_eigsolve)
143+
end
144+
@timeit timeroutput "svd_trunc" begin
145+
al, c, ar = svd_trunc!(newA2center; trunc = alg.trscheme, alg = alg.alg_svd)
146+
normalize!(c)
147+
v = @plansor ac2[1 2; 3 4] * conj(al[1 2; 5]) * conj(c[5; 6]) * conj(ar[6; 3 4])
148+
ϵs[pos] = max(ϵs[pos], abs(1 - abs(v)))
149+
end
150+
@timeit timeroutput "update_AC" begin
151+
ψ.AC[pos] = (al, complex(c))
152+
ψ.AC[pos + 1] = (complex(c), _transpose_front(ar))
153+
end
154+
end
155+
156+
# right to left sweep
157+
for pos in (length(ψ) - 2):-1:1
158+
local ac2, newA2center, al, c, ar
159+
@timeit timeroutput "AC2_eigsolve" begin
160+
@plansor ac2[-1 -2; -3 -4] := ψ.AL[pos][-1 -2; 1] * ψ.AC[pos + 1][1 -4; -3]
161+
Hac2 = AC2_hamiltonian(pos, ψ, H, ψ, envs)
162+
_, newA2center = fixedpoint(Hac2, ac2, :SR, alg_eigsolve)
163+
end
164+
@timeit timeroutput "svd_trunc" begin
165+
al, c, ar = svd_trunc!(newA2center; trunc = alg.trscheme, alg = alg.alg_svd)
166+
normalize!(c)
167+
v = @plansor ac2[1 2; 3 4] * conj(al[1 2; 5]) * conj(c[5; 6]) * conj(ar[6; 3 4])
168+
ϵs[pos] = max(ϵs[pos], abs(1 - abs(v)))
169+
end
170+
@timeit timeroutput "update_AC" begin
171+
ψ.AC[pos + 1] = (complex(c), _transpose_front(ar))
172+
ψ.AC[pos] = (al, complex(c))
173+
end
174+
end
150175
end
151176

152177
ϵ = maximum(ϵs)
153-
ψ, envs = alg.finalize(iter, ψ, H, envs)::Tuple{typeof(ψ), typeof(envs)}
178+
ψ, envs = @timeit timeroutput "finalize" alg.finalize(
179+
iter, ψ, H, envs
180+
)::Tuple{typeof(ψ), typeof(envs)}
154181

155182
if ϵ <= alg.tol
183+
@infov 4 timeroutput
156184
@infov 2 logfinish!(log, iter, ϵ, expectation_value(ψ, H, envs))
157185
break
158186
end
159187
if iter == alg.maxiter
188+
@infov 4 timeroutput
160189
@warnv 1 logcancel!(log, iter, ϵ, expectation_value(ψ, H, envs))
161190
else
162191
@infov 3 logiter!(log, iter, ϵ, expectation_value(ψ, H, envs))

src/algorithms/groundstate/gradient_grassmann.jl

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -62,17 +62,34 @@ function find_groundstate(
6262
@warn "This is not fully supported - split the mps up in a sum of mps's and optimize separately"
6363
normalize!(ψ)
6464

65-
fg(x) = GrassmannMPS.fg(x, H, envs)
65+
timeroutput = TimerOutput("GradientGrassmann")
66+
method_verbosity = hasproperty(alg.method, :verbosity) ? alg.method.verbosity : 0
67+
method_verbosity > 3 || disable_timer!(timeroutput)
68+
69+
fg(x) = timeit(() -> GrassmannMPS.fg(x, H, envs; timeroutput), timeroutput, "fg")
70+
retract(state, g, α) = timeit(
71+
() -> GrassmannMPS.retract(state, g, α), timeroutput, "retract",
72+
)
73+
transport!(h, state, g, α, state′) = timeit(
74+
() -> GrassmannMPS.transport!(h, state, g, α, state′), timeroutput, "transport!",
75+
)
76+
precondition(state, g) = timeit(
77+
() -> GrassmannMPS.precondition(state, g), timeroutput, "precondition",
78+
)
79+
6680
x, _, _, _, normgradhistory = optimize(
6781
fg, ψ, alg.method;
68-
GrassmannMPS.transport!,
69-
GrassmannMPS.retract,
82+
retract, transport!, precondition,
7083
GrassmannMPS.inner,
7184
GrassmannMPS.scale!,
7285
GrassmannMPS.add!,
73-
GrassmannMPS.precondition,
7486
alg.finalize!,
75-
isometrictransport = true
87+
isometrictransport = true,
7688
)
89+
90+
LoggingExtras.withlevel(; verbosity = method_verbosity) do
91+
@infov 4 timeroutput
92+
end
93+
7794
return x, envs, normgradhistory[end]
7895
end

0 commit comments

Comments
 (0)