Skip to content

Commit fddea41

Browse files
mtfishmanclaude
andcommitted
Redesign apply_operator as plain function with strategy dispatch
Drop the `AlgorithmsInterface`-based framing for the singular `apply_operator`: it is now a regular function that takes an `ApplyOperatorAlgorithm` strategy and dispatches on it, in the same spirit as `message_update!` in the BP rewrite. The plural `apply_operators` keeps its AI-based Problem/Algorithm/State triple but now delegates to `apply_operator!` per step instead of going through `NestedAlgorithm`. `BPApplyGate` is the default strategy (registered via `AIE.default_algorithm(::typeof(apply_operator!), ::Type{<:Tuple})`), and destination allocation goes through a MAK-style `AIE.initialize_output` hook. The `cache!` keyword threads through all call sites with `nothing` meaning "build a fresh cache"; the nothing-handling lives in `initialize_cache(cache!, algorithm, state)` via a `::Nothing` overload. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
1 parent 73d9859 commit fddea41

2 files changed

Lines changed: 58 additions & 94 deletions

File tree

src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,12 @@ function select_algorithm(f, alg::NamedTuple, ::Type{Args}; kwargs...) where {Ar
8181
)
8282
return default_algorithm(f, Args; alg...)
8383
end
84+
# Allocate the destination for an in-place call to `f`. Operations overload
85+
# `initialize_output(::typeof(f), args..., alg)` to control allocation.
86+
function initialize_output(f, args...; kwargs...)
87+
return throw(MethodError(initialize_output, (f, args...)))
88+
end
89+
8490
function select_algorithm(f, alg::AbstractAlgorithm, ::Type{<:Tuple}; kwargs...)
8591
isempty(kwargs) || throw(
8692
ArgumentError(

src/apply/apply_operators.jl

Lines changed: 52 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import .AlgorithmsInterfaceExtensions as AIE
12
import AlgorithmsInterface as AI
23
import MatrixAlgebraKit as MAK
34
import NamedDimsArrays as NDA
@@ -9,49 +10,41 @@ using NamedDimsArrays: AbstractNamedDimsArray, dimnames, domainnames, nameddims,
910
replacedimnames, state
1011
using NamedGraphs.GraphsExtensions: all_edges, boundary_edges
1112

12-
# === NestedAlgorithm framework ===
13+
# === Top-level user entry point (singular) ===
1314

14-
abstract type NestedAlgorithm <: AI.Algorithm end
15+
abstract type ApplyOperatorAlgorithm <: AIE.AbstractAlgorithm end
1516

16-
function initialize_subsolve(
17-
problem::AI.Problem, algorithm::AI.Algorithm, state::AI.State
18-
)
19-
return throw(MethodError(initialize_subsolve, (problem, algorithm, state)))
17+
function apply_operator! end
18+
19+
function apply_operator(algorithm::ApplyOperatorAlgorithm, operator, state; kwargs...)
20+
dest = AIE.initialize_output(apply_operator!, operator, state, algorithm)
21+
return apply_operator!(algorithm, dest, operator, state; kwargs...)
2022
end
2123

22-
function finalize_substate!(
23-
problem::AI.Problem, algorithm::AI.Algorithm, state::AI.State, substate::AI.State
24+
# Convenience entries that pick the strategy via `AIE.select_algorithm`.
25+
function apply_operator!(dest, operator, state; alg = nothing, cache! = nothing, kwargs...)
26+
algorithm = AIE.select_algorithm(
27+
apply_operator!, alg, (dest, operator, state); kwargs...
2428
)
25-
state.iterate = substate.iterate
26-
return state
29+
return apply_operator!(algorithm, dest, operator, state; cache!)
2730
end
28-
29-
function AI.step!(problem::AI.Problem, algorithm::NestedAlgorithm, state::AI.State)
30-
subproblem, subalgorithm, substate = initialize_subsolve(problem, algorithm, state)
31-
AI.solve!(subproblem, subalgorithm, substate)
32-
finalize_substate!(problem, algorithm, state, substate)
33-
return state
31+
function apply_operator(operator, state; alg = nothing, cache! = nothing, kwargs...)
32+
algorithm = AIE.select_algorithm(apply_operator!, alg, (operator, state); kwargs...)
33+
return apply_operator(algorithm, operator, state; cache!)
3434
end
3535

36-
# === apply_operators (plural, iterative over a list of operators) ===
37-
38-
function apply_operators(ops, state; op_alg = BPApplyGate(), kwargs...)
39-
problem = ApplyOperatorsProblem(; operators = ops, init = state)
40-
algorithm = ApplyOperators(;
41-
operator_algorithm = op_alg,
42-
stopping_criterion = AI.StopAfterIteration(length(ops))
43-
)
44-
return AI.solve(problem, algorithm; iterate = copy(state), kwargs...)
45-
end
36+
# === apply_operators (plural, still AI-based) ===
4637

4738
@kwdef struct ApplyOperatorsProblem{Ops, Init} <: AI.Problem
4839
operators::Ops
4940
init::Init
5041
end
5142

52-
@kwdef struct ApplyOperators{OpAlg} <: NestedAlgorithm
43+
@kwdef struct ApplyOperators{
44+
OpAlg <: ApplyOperatorAlgorithm, SC <: AI.StoppingCriterion,
45+
} <: AI.Algorithm
5346
operator_algorithm::OpAlg
54-
stopping_criterion::AI.StopAfterIteration
47+
stopping_criterion::SC
5548
end
5649

5750
@kwdef mutable struct ApplyOperatorsState{
@@ -66,7 +59,7 @@ end
6659
function AI.initialize_state(
6760
problem::ApplyOperatorsProblem, algorithm::ApplyOperators;
6861
iterate,
69-
cache! = initialize_cache(problem, algorithm, iterate),
62+
cache! = initialize_cache(nothing, algorithm.operator_algorithm, iterate),
7063
iteration::Int = 0
7164
)
7265
stopping_criterion_state = AI.initialize_state(
@@ -89,100 +82,65 @@ function AI.initialize_state!(
8982
return state
9083
end
9184

92-
function initialize_subsolve(
85+
function AI.step!(
9386
problem::ApplyOperatorsProblem, algorithm::ApplyOperators,
9487
state::ApplyOperatorsState
9588
)
96-
op_i = problem.operators[state.iteration]
97-
subproblem = ApplyOperatorProblem(; op = op_i, init = state.iterate)
98-
subalgorithm = algorithm.operator_algorithm
99-
substate = AI.initialize_state(
100-
subproblem, subalgorithm; state.iterate, cache! = state.cache
89+
op = problem.operators[state.iteration]
90+
apply_operator!(
91+
algorithm.operator_algorithm, state.iterate, op, state.iterate;
92+
cache! = state.cache
10193
)
102-
return subproblem, subalgorithm, substate
103-
end
104-
105-
function initialize_cache(problem::AI.Problem, algorithm::AI.Algorithm, iterate)
106-
return throw(MethodError(initialize_cache, (problem, algorithm, iterate)))
94+
return state
10795
end
10896

109-
function initialize_cache(
110-
problem::ApplyOperatorsProblem, algorithm::ApplyOperators, iterate
97+
function apply_operators(operators, state; op_alg = nothing, kwargs...)
98+
op_alg = AIE.select_algorithm(apply_operator!, op_alg, (state,))
99+
problem = ApplyOperatorsProblem(; operators, init = state)
100+
algorithm = ApplyOperators(;
101+
operator_algorithm = op_alg,
102+
stopping_criterion = AI.StopAfterIteration(length(operators))
111103
)
112-
subproblem = ApplyOperatorProblem(; op = first(problem.operators), init = iterate)
113-
subalgorithm = algorithm.operator_algorithm
114-
return initialize_cache(subproblem, subalgorithm, iterate)
115-
end
116-
117-
# === apply_operator (singular, one gate application) ===
118-
119-
@kwdef struct ApplyOperatorProblem{Op, Init} <: AI.Problem
120-
op::Op
121-
init::Init
122-
end
123-
124-
function apply_operator(op, state; alg = BPApplyGate(), kwargs...)
125-
problem = ApplyOperatorProblem(; op, init = state)
126-
return AI.solve(problem, alg; iterate = copy(state), kwargs...)
127-
end
128-
129-
function apply_operator!(dest, op, state; alg = BPApplyGate(), kwargs...)
130-
problem = ApplyOperatorProblem(; op, init = state)
131-
alg_state = AI.initialize_state(problem, alg; iterate = dest, kwargs...)
132-
return AI.solve!(problem, alg, alg_state)
104+
return AI.solve(problem, algorithm; iterate = copy(state), kwargs...)
133105
end
134106

135-
# === BPApplyGate (non-iterative; overloads solve_loop! directly) ===
107+
# === BPApplyGate strategy ===
136108

137-
@kwdef struct BPApplyGate{Trunc, PinvKwargs <: NamedTuple} <: AI.Algorithm
109+
@kwdef struct BPApplyGate{Trunc, PinvKwargs <: NamedTuple} <: ApplyOperatorAlgorithm
138110
trunc::Trunc = nothing
139111
pinv_kwargs::PinvKwargs = (; tol = 0)
140112
normalize::Bool = false
141113
end
142114

143-
@kwdef mutable struct BPApplyGateState{Iterate, Cache} <: AI.State
144-
iterate::Iterate
145-
cache::Cache
115+
function AIE.default_algorithm(::typeof(apply_operator!), ::Type{<:Tuple}; kwargs...)
116+
return BPApplyGate(; kwargs...)
146117
end
147-
148-
function AI.initialize_state(
149-
problem::ApplyOperatorProblem, algorithm::BPApplyGate;
150-
iterate, cache! = initialize_cache(problem, algorithm, iterate)
118+
function AIE.initialize_output(
119+
::typeof(apply_operator!), operator, state, ::BPApplyGate
151120
)
152-
return BPApplyGateState(; iterate, cache = cache!)
121+
return copy(state)
153122
end
154123

155-
# Non-iterative algorithm: no per-call state to reset.
156-
function AI.initialize_state!(
157-
::ApplyOperatorProblem, ::BPApplyGate, state::BPApplyGateState
124+
function apply_operator!(
125+
algorithm::BPApplyGate, dest, operator, state; cache! = nothing
158126
)
159-
return state
127+
cache! = initialize_cache(cache!, algorithm, state)
128+
apply_gate_bp!(
129+
dest, operator, state;
130+
cache!, algorithm.trunc, algorithm.pinv_kwargs, algorithm.normalize
131+
)
132+
return dest
160133
end
161134

135+
initialize_cache(cache!, ::BPApplyGate, iterate::AbstractTensorNetwork) = cache!
162136
# Initialize the BP message cache to identity square-root messages.
163-
function initialize_cache(
164-
::ApplyOperatorProblem, ::BPApplyGate, iterate::AbstractTensorNetwork
165-
)
137+
function initialize_cache(::Nothing, ::BPApplyGate, iterate::AbstractTensorNetwork)
166138
return sqrtmessagecache(all_edges(iterate)) do edge
167139
factor = iterate[dst(edge)]
168140
return state(one(similar_operator(factor, linkaxes(iterate, edge))))
169141
end
170142
end
171143

172-
# Non-iterative algorithm: bypass the step!/stopping-criterion loop.
173-
function AI.solve_loop!(
174-
problem::ApplyOperatorProblem, algorithm::BPApplyGate,
175-
state::BPApplyGateState
176-
)
177-
apply_gate_bp!(
178-
state.iterate, problem.op, problem.init;
179-
cache! = state.cache,
180-
trunc = algorithm.trunc, pinv_kwargs = algorithm.pinv_kwargs,
181-
normalize = algorithm.normalize
182-
)
183-
return state
184-
end
185-
186144
# === BP simple-update implementation ===
187145
#
188146
# The `cache!` here is assumed to be a `SqrtMessageCache`: messages on each

0 commit comments

Comments
 (0)