Skip to content

Commit 27798b7

Browse files
mtfishmanclaude
andauthored
Trim AlgorithmsInterfaceExtensions to a minimal NestedAlgorithm (#115)
Trim `AlgorithmsInterfaceExtensions` (AIE) to a focused set of helpers built on top of `AlgorithmsInterface`: - `NestedAlgorithm` — abstract `AI.Algorithm` whose `step!` delegates to a subsolve via `initialize_subsolve` / `finalize_substate!`. - `NestedState` — abstract `AI.State` that wraps an inner `substate` and forwards `:iterate` accesses to it, so the iterate is shared across nesting levels without duplication. - `StopWhenConverged` + `iterate_diff` — convergence-based stopping criterion plus the per-iterate-type hook it dispatches on. - `AbstractAlgorithm` + `select_algorithm` / `default_algorithm` — MatrixAlgebraKit-style algorithm-selection helpers. Selection-relevant inputs are packed into an `args` tuple so the value and type domains stay disjoint, and operations register their default by overloading `default_algorithm(::typeof(f), ::Type{<:Tuple})`. The belief-propagation code is refactored into two `AlgorithmsInterface` `Problem` / `Algorithm` / `State` triples: - `BeliefPropagation{Problem, Algorithm, State}` — outer iteration. The algorithm is an `AIE.NestedAlgorithm` carrying a single inner `subalgorithm`; the state is an `AIE.NestedState` wrapping the inner `substate` so the message store persists across sweeps. - `BeliefPropagationSweep{Problem, Algorithm, State}` — one sweep over edges. A plain `AI.Algorithm` whose `step!` performs one message update by delegating to a `MessageUpdateAlgorithm` strategy. A `MessageUpdateAlgorithm` is a lightweight strategy supertype (not an `AI.Algorithm`): `message_update!(algorithm, cache, factors, edge)` is the single per-edge entry point. `SimpleMessageUpdate` is the default, carrying `normalize` and `contraction_alg`. Iteration- or edge-dependent behavior is added by defining a new strategy subtype and overloading `message_update!` (for per-edge variation) or a new sweep algorithm subtype and overloading `AIE.initialize_subsolve` (for per-iteration variation). The top-level `beliefpropagation(factors, messages; …)` entry point exposes `stopping_criterion` and `message_update_algorithm` kwargs, each accepting either an explicit instance, a `NamedTuple` forwarded to a default constructor, or flat kwargs. `stopping_criterion = (; maxiter = 10, tol = 1.0e-10)` combines a `StopAfterIteration` and a `StopWhenConverged` via `|`. --------- Co-authored-by: Claude Opus 4.7 <noreply@anthropic.com>
1 parent bd91366 commit 27798b7

11 files changed

Lines changed: 478 additions & 983 deletions

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ITensorNetworksNext"
22
uuid = "302f2e75-49f0-4526-aef7-d8ba550cb06c"
3-
version = "0.4.2"
3+
version = "0.4.3"
44
authors = ["ITensor developers <support@itensor.org> and contributors"]
55

66
[workspace]

src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl

Lines changed: 106 additions & 198 deletions
Original file line numberDiff line numberDiff line change
@@ -2,250 +2,158 @@ module AlgorithmsInterfaceExtensions
22

33
import AlgorithmsInterface as AI
44

5-
# ========================== Patches for AlgorithmsInterface.jl ============================
5+
# ============================ NestedAlgorithm =============================================
66

7-
abstract type Problem <: AI.Problem end
8-
abstract type Algorithm <: AI.Algorithm end
9-
abstract type State <: AI.State end
7+
abstract type NestedAlgorithm <: AI.Algorithm end
108

11-
function AI.initialize_state!(
12-
problem::Problem, algorithm::Algorithm, state::State; iteration = 0, kwargs...
13-
)
14-
for (k, v) in pairs(kwargs)
15-
setproperty!(state, k, v)
16-
end
17-
state.iteration = iteration
18-
AI.initialize_state!(
19-
problem, algorithm, algorithm.stopping_criterion, state.stopping_criterion_state
9+
# Subtypes of `NestedAlgorithm` must override `initialize_subsolve` — it
10+
# returns the `(subproblem, subalgorithm, substate)` tuple that the next
11+
# inner `AI.solve!` call consumes. The default `finalize_substate!` copies
12+
# the substate's iterate back into the parent state; subtypes can override
13+
# when more is required.
14+
function initialize_subsolve(
15+
problem::AI.Problem, algorithm::AI.Algorithm, state::AI.State
2016
)
21-
return state
17+
return throw(MethodError(initialize_subsolve, (problem, algorithm, state)))
2218
end
2319

24-
function AI.initialize_state(
25-
problem::Problem, algorithm::Algorithm; iterate, kwargs...
26-
)
27-
stopping_criterion_state = AI.initialize_state(
28-
problem, algorithm, algorithm.stopping_criterion; iterate
20+
function finalize_substate!(
21+
problem::AI.Problem, algorithm::AI.Algorithm, state::AI.State, substate::AI.State
2922
)
30-
return DefaultState(; iterate, stopping_criterion_state, kwargs...)
31-
end
32-
33-
# ============================ DefaultState ================================================
34-
35-
@kwdef mutable struct DefaultState{
36-
Iterate, StoppingCriterionState <: AI.StoppingCriterionState,
37-
} <: State
38-
iterate::Iterate
39-
iteration::Int = 0
40-
stopping_criterion_state::StoppingCriterionState
23+
state.iterate = substate.iterate
24+
return state
4125
end
4226

43-
# ============================ increment! ==================================================
44-
45-
# Custom version of `increment!` that also takes the problem and algorithm as arguments.
46-
function AI.increment!(problem::Problem, algorithm::Algorithm, state::State)
47-
return AI.increment!(state)
27+
function AI.step!(problem::AI.Problem, algorithm::NestedAlgorithm, state::AI.State)
28+
subproblem, subalgorithm, substate = initialize_subsolve(problem, algorithm, state)
29+
AI.solve!(subproblem, subalgorithm, substate)
30+
finalize_substate!(problem, algorithm, state, substate)
31+
return state
4832
end
4933

50-
# ============================ AlgorithmIterator ===========================================
34+
# ============================ NestedState =================================================
5135

52-
abstract type AlgorithmIterator end
36+
# State that wraps an inner `substate` and forwards `:iterate` accesses to it,
37+
# so the inner-loop iterate is shared without duplicating storage on the outer
38+
# state. Subtypes must store the inner state as a field named `substate`.
39+
abstract type NestedState <: AI.State end
5340

54-
function algorithm_iterator(
55-
problem::Problem, algorithm::Algorithm, state::State
56-
)
57-
return DefaultAlgorithmIterator(problem, algorithm, state)
58-
end
59-
60-
function AI.is_finished!(iterator::AlgorithmIterator)
61-
return AI.is_finished!(iterator.problem, iterator.algorithm, iterator.state)
62-
end
63-
function AI.is_finished(iterator::AlgorithmIterator)
64-
return AI.is_finished(iterator.problem, iterator.algorithm, iterator.state)
41+
# Use `getfield` on the right-hand side so future edits to this forwarder
42+
# can't accidentally recurse through the overload.
43+
function Base.getproperty(state::NestedState, name::Symbol)
44+
name === :iterate && return getfield(state, :substate).iterate
45+
return getfield(state, name)
6546
end
66-
function AI.increment!(iterator::AlgorithmIterator)
67-
return AI.increment!(iterator.problem, iterator.algorithm, iterator.state)
47+
function Base.setproperty!(state::NestedState, name::Symbol, value)
48+
name === :iterate && return (getfield(state, :substate).iterate = value)
49+
return setfield!(state, name, value)
6850
end
69-
function AI.step!(iterator::AlgorithmIterator)
70-
return AI.step!(iterator.problem, iterator.algorithm, iterator.state)
71-
end
72-
function Base.iterate(iterator::AlgorithmIterator, init = nothing)
73-
AI.is_finished!(iterator) && return nothing
74-
AI.increment!(iterator)
75-
AI.step!(iterator)
76-
return iterator.state, nothing
51+
function Base.propertynames(state::NestedState)
52+
return (fieldnames(typeof(state))..., :iterate)
7753
end
7854

79-
struct DefaultAlgorithmIterator{Problem, Algorithm, State} <: AlgorithmIterator
80-
problem::Problem
81-
algorithm::Algorithm
82-
state::State
83-
end
55+
# ============================ select_algorithm / default_algorithm ========================
8456

85-
# ============================ with_algorithmlogger ========================================
57+
# Like `MatrixAlgebraKit.select_algorithm` / `default_algorithm`, but
58+
# selection-relevant inputs are packed into an `args` tuple so the value
59+
# and type domains stay disjoint: `(1.2,)` vs `Tuple{Float64}`. Strategy
60+
# types subtype `AbstractAlgorithm` so the passthrough overload is generic.
61+
abstract type AbstractAlgorithm end
8662

87-
# Allow passing functions, not just CallbackActions.
88-
@inline function with_algorithmlogger(f, args::Pair{Symbol, AI.LoggingAction}...)
89-
return AI.with_algorithmlogger(f, args...)
63+
function default_algorithm(f, ::Type{Args}; kwargs...) where {Args <: Tuple}
64+
return throw(MethodError(default_algorithm, (f, Args)))
9065
end
91-
@inline function with_algorithmlogger(f, args::Pair{Symbol}...)
92-
return AI.with_algorithmlogger(f, (first.(args) .=> AI.CallbackAction.(last.(args)))...)
66+
function default_algorithm(f, args::Tuple; kwargs...)
67+
return default_algorithm(f, typeof(args); kwargs...)
9368
end
9469

95-
# ============================ NestedAlgorithm =============================================
96-
97-
abstract type NestedAlgorithm <: Algorithm end
98-
99-
nested_algorithm(f::Function, int::Int; kwargs...) = nested_algorithm(f, 1:int; kwargs...)
100-
function nested_algorithm(f::Function, iterable; kwargs...)
101-
return DefaultNestedAlgorithm(f, iterable; kwargs...)
70+
function select_algorithm(f, alg, args::Tuple; kwargs...)
71+
return select_algorithm(f, alg, typeof(args); kwargs...)
10272
end
103-
104-
max_iterations(algorithm::NestedAlgorithm) = length(algorithm.algorithms)
105-
106-
function get_subproblem(
107-
problem::AI.Problem, algorithm::NestedAlgorithm, state::AI.State
73+
function select_algorithm(f, ::Nothing, ::Type{Args}; kwargs...) where {Args <: Tuple}
74+
return default_algorithm(f, Args; kwargs...)
75+
end
76+
function select_algorithm(f, alg::NamedTuple, ::Type{Args}; kwargs...) where {Args <: Tuple}
77+
isempty(kwargs) || throw(
78+
ArgumentError(
79+
"Additional keyword arguments are not allowed when `alg` is a `NamedTuple`."
80+
)
10881
)
109-
subproblem = problem
110-
subalgorithm = algorithm.algorithms[state.iteration]
111-
substate = AI.initialize_state(subproblem, subalgorithm; state.iterate)
112-
return subproblem, subalgorithm, substate
82+
return default_algorithm(f, Args; alg...)
11383
end
114-
115-
function set_substate!(
116-
problem::AI.Problem, algorithm::NestedAlgorithm, state::AI.State, substate::AI.State
84+
function select_algorithm(f, alg::AbstractAlgorithm, ::Type{<:Tuple}; kwargs...)
85+
isempty(kwargs) || throw(
86+
ArgumentError(
87+
"Additional keyword arguments are not allowed when `alg` is an `AbstractAlgorithm` instance."
88+
)
11789
)
118-
state.iterate = substate.iterate
119-
return state
90+
return alg
12091
end
12192

122-
function AI.step!(problem::AI.Problem, algorithm::NestedAlgorithm, state::AI.State)
123-
# Get the subproblem, subalgorithm, and substate.
124-
subproblem, subalgorithm, substate = get_subproblem(problem, algorithm, state)
125-
126-
# Solve the subproblem with the subalgorithm.
127-
AI.solve!(subproblem, subalgorithm, substate)
128-
129-
# Update the state with the substate.
130-
set_substate!(problem, algorithm, state, substate)
93+
# ============================ StopWhenConverged ===========================================
13194

132-
return state
95+
# Stopping criterion that fires once `iterate_diff(iterate, previous_iterate) < tol`.
96+
# Concrete iterate types must supply an `iterate_diff` method.
97+
function iterate_diff(a, b)
98+
return throw(MethodError(iterate_diff, (a, b)))
13399
end
134100

135-
#=
136-
DefaultNestedAlgorithm(sweeps::AbstractVector{<:Algorithm})
137-
138-
An algorithm that consists of running an algorithm at each iteration
139-
from a list of stored algorithms.
140-
=#
141-
@kwdef struct DefaultNestedAlgorithm{
142-
ChildAlgorithm <: Algorithm,
143-
Algorithms <: AbstractVector{ChildAlgorithm},
144-
StoppingCriterion <: AI.StoppingCriterion,
145-
} <: NestedAlgorithm
146-
algorithms::Algorithms
147-
stopping_criterion::StoppingCriterion = AI.StopAfterIteration(length(algorithms))
101+
@kwdef struct StopWhenConverged <: AI.StoppingCriterion
102+
tol::Float64
148103
end
149-
function DefaultNestedAlgorithm(f::Function, iterable; kwargs...)
150-
return DefaultNestedAlgorithm(; algorithms = f.(iterable), kwargs...)
151-
end
152-
153-
# ============================ FlattenedAlgorithm ==========================================
154104

155-
# Flatten a nested algorithm.
156-
abstract type FlattenedAlgorithm <: Algorithm end
157-
abstract type FlattenedAlgorithmState <: State end
158-
159-
function flattened_algorithm(f::Function, nalgorithms::Int; kwargs...)
160-
return DefaultFlattenedAlgorithm(f, nalgorithms; kwargs...)
105+
@kwdef mutable struct StopWhenConvergedState{Iterate} <: AI.StoppingCriterionState
106+
delta::Float64 = Inf
107+
at_iteration::Int = -1
108+
previous_iterate::Iterate
161109
end
162110

163-
function AI.initialize_state(
164-
problem::Problem, algorithm::FlattenedAlgorithm; kwargs...
165-
)
166-
stopping_criterion_state = AI.initialize_state(
167-
problem, algorithm, algorithm.stopping_criterion
168-
)
169-
return DefaultFlattenedAlgorithmState(; stopping_criterion_state, kwargs...)
170-
end
171-
function AI.increment!(
172-
problem::Problem, algorithm::Algorithm, state::FlattenedAlgorithmState
173-
)
174-
# Increment the total iteration count.
175-
state.iteration += 1
176-
# TODO: Use `is_finished!` instead?
177-
if state.child_iteration max_iterations(algorithm.algorithms[state.parent_iteration])
178-
# We're on the last iteration of the child algorithm, so move to the next
179-
# child algorithm.
180-
state.parent_iteration += 1
181-
state.child_iteration = 1
182-
else
183-
# Iterate the child algorithm.
184-
state.child_iteration += 1
185-
end
186-
return state
111+
function AI.initialize_state(::AI.Problem, ::AI.Algorithm, ::StopWhenConverged; iterate)
112+
return StopWhenConvergedState(; previous_iterate = copy(iterate))
187113
end
188-
function AI.step!(
189-
problem::AI.Problem, algorithm::FlattenedAlgorithm, state::FlattenedAlgorithmState
190-
)
191-
algorithm_sweep = algorithm.algorithms[state.parent_iteration]
192-
state_sweep = AI.initialize_state(
193-
problem, algorithm_sweep;
194-
state.iterate, iteration = state.child_iteration
114+
115+
function AI.initialize_state!(
116+
::AI.Problem, ::AI.Algorithm, ::StopWhenConverged, st::StopWhenConvergedState
195117
)
196-
AI.step!(problem, algorithm_sweep, state_sweep)
197-
state.iterate = state_sweep.iterate
198-
return state
118+
st.delta = Inf
119+
return st
199120
end
200121

201-
@kwdef struct DefaultFlattenedAlgorithm{
202-
ChildAlgorithm <: Algorithm,
203-
Algorithms <: AbstractVector{ChildAlgorithm},
204-
StoppingCriterion <: AI.StoppingCriterion,
205-
} <: FlattenedAlgorithm
206-
algorithms::Algorithms
207-
stopping_criterion::StoppingCriterion =
208-
AI.StopAfterIteration(sum(max_iterations, algorithms))
209-
end
210-
function DefaultFlattenedAlgorithm(f::Function, nalgorithms::Int; kwargs...)
211-
return DefaultFlattenedAlgorithm(; algorithms = f.(1:nalgorithms), kwargs...)
212-
end
122+
function AI.is_finished!(
123+
problem::AI.Problem,
124+
algorithm::AI.Algorithm,
125+
state::AI.State,
126+
c::StopWhenConverged,
127+
st::StopWhenConvergedState
128+
)
129+
iterate = state.iterate
130+
previous_iterate = st.previous_iterate
213131

214-
@kwdef mutable struct DefaultFlattenedAlgorithmState{
215-
Iterate, StoppingCriterionState <: AI.StoppingCriterionState,
216-
} <: FlattenedAlgorithmState
217-
iterate::Iterate
218-
iteration::Int = 0
219-
parent_iteration::Int = 1
220-
child_iteration::Int = 0
221-
stopping_criterion_state::StoppingCriterionState
222-
end
132+
delta = iterate_diff(iterate, previous_iterate)
223133

224-
# ============================ NonIterativeAlgorithm =======================================
134+
st.previous_iterate = copy(iterate)
225135

226-
# Algorithm that only performs a single step.
227-
abstract type NonIterativeAlgorithm <: Algorithm end
228-
abstract type NonIterativeAlgorithmState <: State end
136+
# delta = 0 initially, so skip this the first time.
137+
state.iteration == 0 && return false
229138

230-
function AI.initialize_state(problem::Problem, algorithm::NonIterativeAlgorithm; kwargs...)
231-
return DefaultNonIterativeAlgorithmState(; kwargs...)
232-
end
139+
st.delta = delta
233140

234-
function AI.initialize_state!(
235-
problem::Problem,
236-
algorithm::NonIterativeAlgorithm,
237-
state::NonIterativeAlgorithmState
238-
)
239-
return state
240-
end
141+
if AI.is_finished(problem, algorithm, state, c, st)
142+
st.at_iteration = state.iteration
143+
return true
144+
end
241145

242-
function AI.solve_loop!(problem::Problem, algorithm::NonIterativeAlgorithm, state::State)
243-
return throw(MethodError(AI.solve_loop!, (problem, algorithm, state)))
146+
return false
244147
end
245148

246-
@kwdef mutable struct DefaultNonIterativeAlgorithmState{Iterate} <:
247-
NonIterativeAlgorithmState
248-
iterate::Iterate
149+
function AI.is_finished(
150+
::AI.Problem,
151+
::AI.Algorithm,
152+
::AI.State,
153+
c::StopWhenConverged,
154+
st::StopWhenConvergedState
155+
)
156+
return st.delta < c.tol
249157
end
250158

251159
end

src/ITensorNetworksNext.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,8 @@ include("abstracttensornetwork.jl")
1212
include("tensornetwork.jl")
1313
include("TensorNetworkGenerators/TensorNetworkGenerators.jl")
1414
include("contract_network.jl")
15-
include("sweeping/utils.jl")
16-
include("sweeping/eigenproblem.jl")
1715

1816
include("beliefpropagation/messagecache.jl")
19-
include("beliefpropagation/beliefpropagationproblem.jl")
17+
include("beliefpropagation/beliefpropagation.jl")
2018

2119
end

0 commit comments

Comments
 (0)