Skip to content
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ITensorNetworksNext"
uuid = "302f2e75-49f0-4526-aef7-d8ba550cb06c"
authors = ["ITensor developers <support@itensor.org> and contributors"]
version = "0.1.11"
version = "0.1.12"

[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
Expand Down
3 changes: 3 additions & 0 deletions src/ITensorNetworksNext.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,8 @@ include("lazynameddimsarrays.jl")
include("abstracttensornetwork.jl")
include("tensornetwork.jl")
include("contract_network.jl")
include("abstract_problem.jl")
include("iterators.jl")
include("adapters.jl")

end
1 change: 1 addition & 0 deletions src/abstract_problem.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
abstract type AbstractProblem end
45 changes: 45 additions & 0 deletions src/adapters.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
"""
struct IncrementOnly{S<:AbstractNetworkIterator} <: AbstractNetworkIterator

Iterator wrapper whos `compute!` function simply returns itself, doing nothing in the
process. This allows one to manually call a custom `compute!` or insert their own code it in
the loop body in place of `compute!`.
"""
struct IncrementOnly{S <: AbstractNetworkIterator} <: AbstractNetworkIterator
parent::S
end

islaststep(adapter::IncrementOnly) = islaststep(adapter.parent)
increment!(adapter::IncrementOnly) = increment!(adapter.parent)
compute!(adapter::IncrementOnly) = adapter

IncrementOnly(adapter::IncrementOnly) = adapter

"""
struct EachRegion{SweepIterator} <: AbstractNetworkIterator

Adapter that flattens each region iterator in the parent sweep iterator into a single
iterator.
"""
struct EachRegion{SI <: SweepIterator} <: AbstractNetworkIterator
parent::SI
end

# In keeping with Julia convention.
eachregion(iter::SweepIterator) = EachRegion(iter)

# Essential definitions
function islaststep(adapter::EachRegion)
region_iter = region_iterator(adapter.parent)
return islaststep(adapter.parent) && islaststep(region_iter)
end
function increment!(adapter::EachRegion)
region_iter = region_iterator(adapter.parent)
islaststep(region_iter) ? increment!(adapter.parent) : increment!(region_iter)
return adapter
end
function compute!(adapter::EachRegion)
region_iter = region_iterator(adapter.parent)
compute!(region_iter)
return adapter
end
170 changes: 170 additions & 0 deletions src/iterators.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
"""
abstract type AbstractNetworkIterator

A stateful iterator with two states: `increment!` and `compute!`. Each iteration begins
with a call to `increment!` before executing `compute!`, however the initial call to
`iterate` skips the `increment!` call as it is assumed the iterator is initalized such that
this call is implict. Termination of the iterator is controlled by the function `done`.
"""
abstract type AbstractNetworkIterator end

# We use greater than or equals here as we increment the state at the start of the iteration
islaststep(iterator::AbstractNetworkIterator) = state(iterator) >= length(iterator)

function Base.iterate(iterator::AbstractNetworkIterator, init = true)
# The assumption is that first "increment!" is implicit, therefore we must skip the
# the termination check for the first iteration, i.e. `AbstractNetworkIterator` is not
# defined when length < 1,
init || islaststep(iterator) && return nothing
# We seperate increment! from step! and demand that any AbstractNetworkIterator *must*
# define a method for increment! This way we avoid cases where one may wish to nest
# calls to different step! methods accidentaly incrementing multiple times.
init || increment!(iterator)
rv = compute!(iterator)
return rv, false
end

increment!(iterator::AbstractNetworkIterator) = throw(MethodError(increment!, Tuple{typeof(iterator)}))
compute!(iterator::AbstractNetworkIterator) = iterator

step!(iterator::AbstractNetworkIterator) = step!(identity, iterator)
function step!(f, iterator::AbstractNetworkIterator)
compute!(iterator)
f(iterator)
increment!(iterator)
return iterator
end

#
# RegionIterator
#
"""
struct RegionIterator{Problem, RegionPlan} <: AbstractNetworkIterator
"""
mutable struct RegionIterator{Problem, RegionPlan} <: AbstractNetworkIterator
problem::Problem
region_plan::RegionPlan
which_region::Int
const which_sweep::Int
function RegionIterator(problem::P, region_plan::R, sweep::Int) where {P, R}
if isempty(region_plan)
throw(ArgumentError("Cannot construct a region iterator with 0 elements."))
end
return new{P, R}(problem, region_plan, 1, sweep)
end
end

function RegionIterator(problem; sweep, sweep_kwargs...)
plan = region_plan(problem; sweep_kwargs...)
return RegionIterator(problem, plan, sweep)
end

state(region_iter::RegionIterator) = region_iter.which_region
Base.length(region_iter::RegionIterator) = length(region_iter.region_plan)

problem(region_iter::RegionIterator) = region_iter.problem

function current_region_plan(region_iter::RegionIterator)
return region_iter.region_plan[region_iter.which_region]
end

function current_region(region_iter::RegionIterator)
region, _ = current_region_plan(region_iter)
return region
end

function region_kwargs(region_iter::RegionIterator)
_, kwargs = current_region_plan(region_iter)
return kwargs
end
function region_kwargs(f::Function, iter::RegionIterator)
return get(region_kwargs(iter), Symbol(f, :_kwargs), (;))
end

function prev_region(region_iter::RegionIterator)
state(region_iter) <= 1 && return nothing
prev, _ = region_iter.region_plan[region_iter.which_region - 1]
return prev
end

function next_region(region_iter::RegionIterator)
islaststep(region_iter) && return nothing
next, _ = region_iter.region_plan[region_iter.which_region + 1]
return next
end

#
# Functions associated with RegionIterator
#
function increment!(region_iter::RegionIterator)
region_iter.which_region += 1
return region_iter
end

function compute!(iter::RegionIterator)
extract!(iter; region_kwargs(extract!, iter)...)
update!(iter; region_kwargs(update!, iter)...)
insert!(iter; region_kwargs(insert!, iter)...)

return iter
end

region_plan(problem; sweep_kwargs...) = euler_sweep(state(problem); sweep_kwargs...)

#
# SweepIterator
#

mutable struct SweepIterator{Problem, Iter} <: AbstractNetworkIterator
region_iter::RegionIterator{Problem}
sweep_kwargs::Iterators.Stateful{Iter}
which_sweep::Int
function SweepIterator(problem::Prob, sweep_kwargs::Iter) where {Prob, Iter}
stateful_sweep_kwargs = Iterators.Stateful(sweep_kwargs)

first_state = Iterators.peel(stateful_sweep_kwargs)

if isnothing(first_state)
throw(ArgumentError("Cannot construct a sweep iterator with 0 elements."))
end

first_kwargs, _ = first_state
region_iter = RegionIterator(problem; sweep = 1, first_kwargs...)

return new{Prob, Iter}(region_iter, stateful_sweep_kwargs, 1)
end
end

islaststep(sweep_iter::SweepIterator) = isnothing(peek(sweep_iter.sweep_kwargs))

region_iterator(sweep_iter::SweepIterator) = sweep_iter.region_iter
problem(sweep_iter::SweepIterator) = problem(region_iterator(sweep_iter))

state(sweep_iter::SweepIterator) = sweep_iter.which_sweep
Base.length(sweep_iter::SweepIterator) = length(sweep_iter.sweep_kwargs)
function increment!(sweep_iter::SweepIterator)
sweep_iter.which_sweep += 1
sweep_kwargs, _ = Iterators.peel(sweep_iter.sweep_kwargs)
update_region_iterator!(sweep_iter; sweep_kwargs...)
return sweep_iter
end

function update_region_iterator!(iterator::SweepIterator; kwargs...)
sweep = state(iterator)
iterator.region_iter = RegionIterator(problem(iterator); sweep, kwargs...)
return iterator
end

function compute!(sweep_iter::SweepIterator)
for _ in sweep_iter.region_iter
# TODO: Is it sensible to execute the default region callback function?
end
return
end

# More basic constructor where sweep_kwargs are constant throughout sweeps
function SweepIterator(problem, nsweeps::Int; sweep_kwargs...)
# Initialize this to an empty RegionIterator
sweep_kwargs_iter = Iterators.repeated(sweep_kwargs, nsweeps)
return SweepIterator(problem, sweep_kwargs_iter)
end
Loading
Loading