diff --git a/Project.toml b/Project.toml index 6a3e077b..afd02cb4 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ITensorNetworks" uuid = "2919e153-833c-4bdc-8836-1ea460a35fc7" -version = "0.15.26" +version = "0.15.27" authors = ["Matthew Fishman , Joseph Tindall and contributors"] [workspace] diff --git a/src/solvers/iterators.jl b/src/solvers/iterators.jl index 19b99c9a..c1be61fe 100644 --- a/src/solvers/iterators.jl +++ b/src/solvers/iterators.jl @@ -8,7 +8,10 @@ abstract type AbstractNetworkIterator end islaststep(iterator::AbstractNetworkIterator) = state(iterator) >= length(iterator) function Base.iterate(iterator::AbstractNetworkIterator, init = true) - islaststep(iterator) && return nothing + # The assumption is that first "increment!" is implicit, therefore we must skip the + # the termination check for the first iteration, i.e. `AbstractNetworkIterator` is not + # defined when length < 1, + init || islaststep(iterator) && return nothing # We seperate increment! from step! and demand that any AbstractNetworkIterator *must* # define a method for increment! This way we avoid cases where one may wish to nest # calls to different step! methods accidentaly incrementing multiple times. @@ -37,6 +40,9 @@ mutable struct RegionIterator{Problem, RegionPlan} <: AbstractNetworkIterator which_region::Int const which_sweep::Int function RegionIterator(problem::P, region_plan::R, sweep::Int) where {P, R} + if length(region_plan) == 0 + throw(BoundsError("Cannot construct a region iterator with 0 elements.")) + end return new{P, R}(problem, region_plan, 1, sweep) end end @@ -112,8 +118,15 @@ mutable struct SweepIterator{Problem, Iter} <: AbstractNetworkIterator which_sweep::Int function SweepIterator(problem::Prob, sweep_kwargs::Iter) where {Prob, Iter} stateful_sweep_kwargs = Iterators.Stateful(sweep_kwargs) - first_kwargs, _ = Iterators.peel(stateful_sweep_kwargs) + first_state = Iterators.peel(stateful_sweep_kwargs) + + if isnothing(first_state) + throw(BoundsError("Cannot construct a sweep iterator with 0 elements.")) + end + + first_kwargs, _ = first_state region_iter = RegionIterator(problem; sweep = 1, first_kwargs...) + return new{Prob, Iter}(region_iter, stateful_sweep_kwargs, 1) end end diff --git a/test/solvers/test_iterators.jl b/test/solvers/test_iterators.jl index 8095f81d..de9522b2 100644 --- a/test/solvers/test_iterators.jl +++ b/test/solvers/test_iterators.jl @@ -1,5 +1,6 @@ -using ITensorNetworks: SweepIterator, compute!, eachregion, increment!, islaststep, state -using Test: @test, @testset +using ITensorNetworks: + RegionIterator, SweepIterator, compute!, eachregion, increment!, islaststep, state +using Test: @test, @test_throws, @testset module TestIteratorUtils @@ -47,6 +48,23 @@ end import .TestIteratorUtils @testset "`AbstractNetworkIterator` Interface" begin + @testset "Edge cases" begin + TI = TestIteratorUtils.TestIterator(1, 1, []) + cb = [] + @test islaststep(TI) + for _ in TI + @test islaststep(TI) + push!(cb, state(TI)) + end + @test length(cb) == 1 + @test length(TI.output) == 1 + @test only(cb) == 1 + + prob = TestIteratorUtils.TestProblem([]) + @test_throws BoundsError SweepIterator(prob, 0) + @test_throws BoundsError RegionIterator(prob, [], 1) + end + TI = TestIteratorUtils.TestIterator(1, 4, []) @test !islaststep((TI)) @@ -165,6 +183,16 @@ end @test prob.data[1:2:end] == fill(1, 5) @test prob.data[2:2:end] == fill(2, 5) + + let i = 1, prob = TestIteratorUtils.TestProblem([]) + SI = SweepIterator(prob, 1) + cb = [] + for _ in eachregion(SI) + push!(cb, i) + i += 1 + end + @test length(cb) == 2 + end end end end