Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,14 @@ VectorInterface = "409d34a3-91d5-4945-b6ec-7529ddf182d8"
Accessors = "0.1"
Aqua = "0.8.9"
BlockTensorKit = "0.1.4"
Compat = "3.47, 4.10"
Combinatorics = "1"
Compat = "3.47, 4.10"
DocStringExtensions = "0.9.3"
HalfIntegers = "1.6.0"
KrylovKit = "0.8.3, 0.9.2"
LinearAlgebra = "1.6"
LoggingExtras = "~1.0"
OhMyThreads = "0.7.0"
OhMyThreads = "0.7, 0.8"
OptimKit = "0.3.1, 0.4"
Pkg = "1"
Plots = "1.40"
Expand Down
119 changes: 80 additions & 39 deletions src/algorithms/approximate/vomps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,63 +5,104 @@ Base.@deprecate(approximate(ψ::MultilineMPS, toapprox::Tuple{<:MultilineMPO,<:M
alg.verbosity, alg.alg_gauge, alg.alg_environments),
envs...; kwargs...))

function approximate(ψ::MultilineMPS, toapprox::Tuple{<:MultilineMPO,<:MultilineMPS},
alg::VOMPS, envs=environments(ψ, toapprox))
ϵ::Float64 = calc_galerkin(ψ, toapprox..., envs)
temp_ACs = similar.(ψ.AC)
scheduler = Defaults.scheduler[]
function approximate(mps::MultilineMPS, toapprox::Tuple{<:MultilineMPO,<:MultilineMPS},
alg::VOMPS, envs=environments(mps, toapprox))
log = IterLog("VOMPS")
alg_environments = updatetol(alg.alg_environments, 0, ϵ)
recalculate!(envs, ψ, toapprox...; alg_environments.tol)
iter = 0
ϵ = calc_galerkin(mps, toapprox..., envs)
alg_environments = updatetol(alg.alg_environments, iter, ϵ)
recalculate!(envs, mps, toapprox...; alg_environments.tol)

LoggingExtras.withlevel(; alg.verbosity) do
state = VOMPSState(mps, toapprox, envs, iter, ϵ)
it = IterativeSolver(alg, state)

return LoggingExtras.withlevel(; alg.verbosity) do
@infov 2 loginit!(log, ϵ)
for iter in 1:(alg.maxiter)
tmap!(eachcol(temp_ACs), 1:size(ψ, 2); scheduler) do col
return _vomps_localupdate(col, ψ, toapprox, envs)

for (mps, envs, ϵ) in it
if ϵ ≤ alg.tol
@infov 2 logfinish!(log, it.iter, ϵ)
return mps, envs, ϵ
end
if it.iter ≥ alg.maxiter
@warnv 1 logcancel!(log, it.iter, ϵ)
return mps, envs, ϵ
end
@infov 3 logiter!(log, it.iter, ϵ)
end

alg_gauge = updatetol(alg.alg_gauge, iter, ϵ)
ψ = MultilineMPS(temp_ACs, ψ.C[:, end]; alg_gauge.tol, alg_gauge.maxiter)
# this should never be reached
return it.state.mps, it.state.envs, it.state.ϵ
end
end

alg_environments = updatetol(alg.alg_environments, iter, ϵ)
recalculate!(envs, ψ, toapprox...; alg_environments.tol)
# need to specialize a bunch of functions because different arguments are passed with tuples
# TODO: can we avoid this?
function Base.iterate(it::IterativeSolver{<:VOMPS}, state::VOMPSState{<:Any,<:Tuple})
ACs = localupdate_step!(it, state)
mps = gauge_step!(it, state, ACs)
envs = envs_step!(it, state, mps)

ψ, envs = alg.finalize(iter, ψ, toapprox, envs)::Tuple{typeof(ψ),typeof(envs)}
# finalizer step
mps, envs = it.finalize(state.iter, mps, state.operator, envs)::typeof((mps, envs))

ϵ = calc_galerkin(ψ, toapprox..., envs)
# error criterion
ϵ = calc_galerkin(mps, state.operator..., envs)

if ϵ <= alg.tol
@infov 2 logfinish!(log, iter, ϵ)
break
end
if iter == alg.maxiter
@warnv 1 logcancel!(log, iter, ϵ)
else
@infov 3 logiter!(log, iter, ϵ)
end
end
# update state
it.state = VOMPSState(mps, state.operator, envs, state.iter + 1, ϵ)

return (mps, envs, ϵ), it.state
end

# TODO: ac_proj and c_proj should be rewritten to also be simply ∂AC/∂C functions
# once these have better support for different above/below mps
function localupdate_step!(::IterativeSolver{<:VOMPS}, state::VOMPSState{<:Any,<:Tuple},
::SerialScheduler)
alg_orth = QRpos()
eachsite = 1:length(state.mps)
ACs = similar(state.mps.AC)
dst_ACs = state.mps isa Multiline ? eachcol(ACs) : ACs

foreach(eachsite) do site
AC = circshift([ac_proj(row, loc, state.mps, state.toapprox, state.envs)
for row in 1:size(state.mps, 1)], 1)
C = circshift([c_proj(row, loc, state.mps, state.toapprox, state.envs)
for row in 1:size(state.mps, 1)], 1)
dst_ACs[site] = regauge!(AC, C; alg=alg_orth)
return nothing
end

return ψ, envs, ϵ
return ACs
end
function localupdate_step!(::IterativeSolver{<:VOMPS}, state::VOMPSState{<:Any,<:Tuple},
scheduler)
alg_orth = QRpos()
eachsite = 1:length(state.mps)

function _vomps_localupdate(loc, ψ, Oϕ, envs, factalg=QRpos())
local tmp_AC, tmp_C
if Defaults.scheduler[] isa SerialScheduler
tmp_AC = circshift([ac_proj(row, loc, ψ, Oϕ, envs) for row in 1:size(ψ, 1)], 1)
tmp_C = circshift([c_proj(row, loc, ψ, Oϕ, envs) for row in 1:size(ψ, 1)], 1)
else
ACs = similar(state.mps.AC)
dst_ACs = state.mps isa Multiline ? eachcol(ACs) : ACs

tforeach(eachsite; scheduler) do site
local AC, C
@sync begin
Threads.@spawn begin
tmp_AC = circshift([ac_proj(row, loc, ψ, Oϕ, envs)
for row in 1:size(ψ, 1)], 1)
AC = circshift([ac_proj(row, site, state.mps, state.operator, state.envs)
for row in 1:size(state.mps, 1)], 1)
end
Threads.@spawn begin
tmp_C = circshift([c_proj(row, loc, ψ, Oϕ, envs) for row in 1:size(ψ, 1)],
1)
C = circshift([c_proj(row, site, state.mps, state.operator, state.envs)
for row in 1:size(state.mps, 1)], 1)
end
end
dst_ACs[site] = regauge!(AC, C; alg=alg_orth)
return nothing
end
return regauge!.(tmp_AC, tmp_C; alg=factalg)

return ACs
end

function envs_step!(it::IterativeSolver{<:VOMPS}, state::VOMPSState{<:Any,<:Tuple}, mps)
alg_environments = updatetol(it.alg_environments, state.iter, state.ϵ)
return recalculate!(state.envs, mps, state.operator...; alg_environments.tol)
end
6 changes: 3 additions & 3 deletions src/algorithms/derivatives.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ function ∂AC(x::GenericMPSTensor{S,3}, operator::MPOTensor{S}, leftenv::MPSTen
end

# mpo multiline
function ∂AC(x::Vector, opp, leftenv, rightenv)
function ∂AC(x::AbstractVector, opp, leftenv, rightenv)
return circshift(map(∂AC, x, opp, leftenv, rightenv), 1)
end

Expand All @@ -109,7 +109,7 @@ function ∂AC2(x::AbstractTensorMap{<:Any,<:Any,3,3}, operator1::MPOTensor,
operator2[7 -6; 4 5] * τ[5 -5; 2 3]
end

function ∂AC2(x::Vector, opp1, opp2, leftenv, rightenv)
function ∂AC2(x::AbstractVector, opp1, opp2, leftenv, rightenv)
return circshift(map(∂AC2, x, opp1, opp2, leftenv, rightenv), 1)
end

Expand All @@ -122,7 +122,7 @@ function ∂C(x::MPSBondTensor, leftenv::MPSBondTensor, rightenv::MPSBondTensor)
@plansor toret[-1; -2] := leftenv[-1; 1] * x[1; 2] * rightenv[2; -2]
end

function ∂C(x::Vector, leftenv, rightenv)
function ∂C(x::AbstractVector, leftenv, rightenv)
return circshift(map(t -> ∂C(t...), zip(x, leftenv, rightenv)), 1)
end

Expand Down
151 changes: 106 additions & 45 deletions src/algorithms/groundstate/vumps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,63 +35,124 @@ $(TYPEDFIELDS)
finalize::F = Defaults._finalize
end

function find_groundstate(ψ::InfiniteMPS, H, alg::VUMPS, envs=environments(ψ, H))
# initialization
scheduler = Defaults.scheduler[]
log = IterLog("VUMPS")
ϵ::Float64 = calc_galerkin(ψ, H, ψ, envs)
temp_ACs = similar.(ψ.AC)
alg_environments = updatetol(alg.alg_environments, 0, ϵ)
recalculate!(envs, ψ, H, ψ; alg_environments.tol)

LoggingExtras.withlevel(; alg.verbosity) do
@infov 2 loginit!(log, ϵ, sum(expectation_value(ψ, H, envs)))
for iter in 1:(alg.maxiter)
alg_eigsolve = updatetol(alg.alg_eigsolve, iter, ϵ)
tmap!(temp_ACs, 1:length(ψ); scheduler) do loc
return _vumps_localupdate(loc, ψ, H, envs, alg_eigsolve)
end
# Internal state of the VUMPS algorithm
struct VUMPSState{S,O,E}
mps::S
operator::O
envs::E
iter::Int
ϵ::Float64
which::Symbol
end

alg_gauge = updatetol(alg.alg_gauge, iter, ϵ)
ψ = InfiniteMPS(temp_ACs, ψ.C[end]; alg_gauge.tol, alg_gauge.maxiter)
function find_groundstate(mps::InfiniteMPS, operator, alg::VUMPS,
envs=environments(mps, operator))
return dominant_eigsolve(operator, mps, alg, envs; which=:SR)
end

alg_environments = updatetol(alg.alg_environments, iter, ϵ)
recalculate!(envs, ψ, H, ψ; alg_environments.tol)
function dominant_eigsolve(operator, mps, alg::VUMPS, envs=environments(mps, operator);
which)
log = IterLog("VUMPS")
iter = 0
ϵ = calc_galerkin(mps, operator, mps, envs)
alg_environments = updatetol(alg.alg_environments, iter, ϵ)
recalculate!(envs, mps, operator, mps; alg_environments.tol)

ψ, envs = alg.finalize(iter, ψ, H, envs)::Tuple{typeof(ψ),typeof(envs)}
state = VUMPSState(mps, operator, envs, iter, ϵ, which)
it = IterativeSolver(alg, state)

ϵ = calc_galerkin(ψ, H, ψ, envs)
return LoggingExtras.withlevel(; alg.verbosity) do
@infov 2 loginit!(log, ϵ, sum(expectation_value(mps, operator, envs)))

# breaking conditions
if ϵ <= alg.tol
@infov 2 logfinish!(log, iter, ϵ, expectation_value(ψ, H, envs))
break
for (mps, envs, ϵ) in it
if ϵ alg.tol
@infov 2 logfinish!(log, it.iter, ϵ, expectation_value(mps, operator, envs))
return mps, envs, ϵ
end
if iter == alg.maxiter
@warnv 1 logcancel!(log, iter, ϵ, expectation_value(ψ, H, envs))
else
@infov 3 logiter!(log, iter, ϵ, expectation_value(ψ, H, envs))
if it.iter ≥ alg.maxiter
@warnv 1 logcancel!(log, it.iter, ϵ, expectation_value(mps, operator, envs))
return mps, envs, ϵ
end
@infov 3 logiter!(log, it.iter, ϵ, expectation_value(mps, operator, envs))
end

# this should never be reached
return it.state.mps, it.state.envs, it.state.ϵ
end
end

return ψ, envs, ϵ
function Base.iterate(it::IterativeSolver{<:VUMPS}, state=it.state)
ACs = localupdate_step!(it, state)
mps = gauge_step!(it, state, ACs)
envs = envs_step!(it, state, mps)

# finalizer step
mps, envs = it.finalize(state.iter, mps, state.operator, envs)::typeof((mps, envs))

# error criterion
ϵ = calc_galerkin(mps, state.operator, mps, envs)

# update state
it.state = VUMPSState(mps, state.operator, envs, state.iter + 1, ϵ, state.which)

return (mps, envs, ϵ), it.state
end

function _vumps_localupdate(loc, ψ, H, envs, eigalg, factalg=QRpos())
local AC′, C′
if Defaults.scheduler[] isa SerialScheduler
_, AC′ = fixedpoint(∂∂AC(loc, ψ, H, envs), ψ.AC[loc], :SR, eigalg)
_, C′ = fixedpoint(∂∂C(loc, ψ, H, envs), ψ.C[loc], :SR, eigalg)
else
@sync begin
Threads.@spawn begin
_, AC′ = fixedpoint(∂∂AC(loc, ψ, H, envs), ψ.AC[loc], :SR, eigalg)
end
Threads.@spawn begin
_, C′ = fixedpoint(∂∂C(loc, ψ, H, envs), ψ.C[loc], :SR, eigalg)
end
function localupdate_step!(it::IterativeSolver{<:VUMPS}, state,
scheduler=Defaults.scheduler[])
alg_eigsolve = updatetol(it.alg_eigsolve, state.iter, state.ϵ)
alg_orth = QRpos()

mps = state.mps
eachsite = 1:length(mps)
src_Cs = mps isa Multiline ? eachcol(mps.C) : mps.C
src_ACs = mps isa Multiline ? eachcol(mps.AC) : mps.AC
ACs = similar(mps.AC)
dst_ACs = mps isa Multiline ? eachcol(ACs) : ACs

tforeach(eachsite, src_ACs, src_Cs; scheduler) do site, AC₀, C₀
dst_ACs[site] = _localupdate_vumps_step!(site, mps, state.operator, state.envs,
AC₀, C₀; parallel=false, alg_orth,
state.which, alg_eigsolve)
return nothing
end

return ACs
end

function _localupdate_vumps_step!(site, mps, operator, envs, AC₀, C₀;
parallel::Bool=false, alg_orth=QRpos(),
alg_eigsolve=Defaults.eigsolver, which)
if !parallel
_, AC = fixedpoint(∂∂AC(site, mps, operator, envs), AC₀, which, alg_eigsolve)
_, C = fixedpoint(∂∂C(site, mps, operator, envs), C₀, which, alg_eigsolve)
return regauge!(AC, C; alg=alg_orth)
end

local AC, C
@sync begin
@spawn begin
_, AC = fixedpoint(∂∂AC(site, mps, operator, envs),
AC₀, which, alg_eigsolve)
end
@spawn begin
_, C = fixedpoint(∂∂C(site, mps, operator, envs),
C₀, which, alg_eigsolve)
end
end
return regauge!(AC′, C′; alg=factalg)
return regauge!(AC, C; alg=alg_orth)
end

function gauge_step!(it::IterativeSolver{<:VUMPS}, state, ACs::AbstractVector)
alg_gauge = updatetol(it.alg_gauge, state.iter, state.ϵ)
return InfiniteMPS(ACs, state.mps.C[end]; alg_gauge.tol, alg_gauge.maxiter)
end
function gauge_step!(it::IterativeSolver{<:VUMPS}, state, ACs::AbstractMatrix)
alg_gauge = updatetol(it.alg_gauge, state.iter, state.ϵ)
return MultilineMPS(ACs, @view(state.mps.C[:, end]); alg_gauge.tol, alg_gauge.maxiter)
end

function envs_step!(it::IterativeSolver{<:VUMPS}, state, mps)
alg_environments = updatetol(it.alg_environments, state.iter, state.ϵ)
return recalculate!(state.envs, mps, state.operator, mps; alg_environments.tol)
end
Loading