1+ import . AlgorithmsInterfaceExtensions as AIE
12import AlgorithmsInterface as AI
23import MatrixAlgebraKit as MAK
34import NamedDimsArrays as NDA
@@ -9,49 +10,41 @@ using NamedDimsArrays: AbstractNamedDimsArray, dimnames, domainnames, nameddims,
910 replacedimnames, state
1011using NamedGraphs. GraphsExtensions: all_edges, boundary_edges
1112
12- # === NestedAlgorithm framework ===
13+ # === Top-level user entry point (singular) ===
1314
14- abstract type NestedAlgorithm <: AI.Algorithm end
15+ abstract type ApplyOperatorAlgorithm <: AIE.AbstractAlgorithm end
1516
16- function initialize_subsolve (
17- problem:: AI.Problem , algorithm:: AI.Algorithm , state:: AI.State
18- )
19- return throw (MethodError (initialize_subsolve, (problem, algorithm, state)))
17+ function apply_operator! end
18+
19+ function apply_operator (algorithm:: ApplyOperatorAlgorithm , operator, state; kwargs... )
20+ dest = AIE. initialize_output (apply_operator!, operator, state, algorithm)
21+ return apply_operator! (algorithm, dest, operator, state; kwargs... )
2022end
2123
22- function finalize_substate! (
23- problem:: AI.Problem , algorithm:: AI.Algorithm , state:: AI.State , substate:: AI.State
24+ # Convenience entries that pick the strategy via `AIE.select_algorithm`.
25+ function apply_operator! (dest, operator, state; alg = nothing , cache! = nothing , kwargs... )
26+ algorithm = AIE. select_algorithm (
27+ apply_operator!, alg, (dest, operator, state); kwargs...
2428 )
25- state. iterate = substate. iterate
26- return state
29+ return apply_operator! (algorithm, dest, operator, state; cache!)
2730end
28-
29- function AI. step! (problem:: AI.Problem , algorithm:: NestedAlgorithm , state:: AI.State )
30- subproblem, subalgorithm, substate = initialize_subsolve (problem, algorithm, state)
31- AI. solve! (subproblem, subalgorithm, substate)
32- finalize_substate! (problem, algorithm, state, substate)
33- return state
31+ function apply_operator (operator, state; alg = nothing , cache! = nothing , kwargs... )
32+ algorithm = AIE. select_algorithm (apply_operator!, alg, (operator, state); kwargs... )
33+ return apply_operator (algorithm, operator, state; cache!)
3434end
3535
36- # === apply_operators (plural, iterative over a list of operators) ===
37-
38- function apply_operators (ops, state; op_alg = BPApplyGate (), kwargs... )
39- problem = ApplyOperatorsProblem (; operators = ops, init = state)
40- algorithm = ApplyOperators (;
41- operator_algorithm = op_alg,
42- stopping_criterion = AI. StopAfterIteration (length (ops))
43- )
44- return AI. solve (problem, algorithm; iterate = copy (state), kwargs... )
45- end
36+ # === apply_operators (plural, still AI-based) ===
4637
4738@kwdef struct ApplyOperatorsProblem{Ops, Init} <: AI.Problem
4839 operators:: Ops
4940 init:: Init
5041end
5142
52- @kwdef struct ApplyOperators{OpAlg} <: NestedAlgorithm
43+ @kwdef struct ApplyOperators{
44+ OpAlg <: ApplyOperatorAlgorithm , SC <: AI.StoppingCriterion ,
45+ } <: AI.Algorithm
5346 operator_algorithm:: OpAlg
54- stopping_criterion:: AI.StopAfterIteration
47+ stopping_criterion:: SC
5548end
5649
5750@kwdef mutable struct ApplyOperatorsState{
6659function AI. initialize_state (
6760 problem:: ApplyOperatorsProblem , algorithm:: ApplyOperators ;
6861 iterate,
69- cache! = initialize_cache (problem , algorithm, iterate),
62+ cache! = initialize_cache (nothing , algorithm. operator_algorithm , iterate),
7063 iteration:: Int = 0
7164 )
7265 stopping_criterion_state = AI. initialize_state (
@@ -89,100 +82,65 @@ function AI.initialize_state!(
8982 return state
9083end
9184
92- function initialize_subsolve (
85+ function AI . step! (
9386 problem:: ApplyOperatorsProblem , algorithm:: ApplyOperators ,
9487 state:: ApplyOperatorsState
9588 )
96- op_i = problem. operators[state. iteration]
97- subproblem = ApplyOperatorProblem (; op = op_i, init = state. iterate)
98- subalgorithm = algorithm. operator_algorithm
99- substate = AI. initialize_state (
100- subproblem, subalgorithm; state. iterate, cache! = state. cache
89+ op = problem. operators[state. iteration]
90+ apply_operator! (
91+ algorithm. operator_algorithm, state. iterate, op, state. iterate;
92+ cache! = state. cache
10193 )
102- return subproblem, subalgorithm, substate
103- end
104-
105- function initialize_cache (problem:: AI.Problem , algorithm:: AI.Algorithm , iterate)
106- return throw (MethodError (initialize_cache, (problem, algorithm, iterate)))
94+ return state
10795end
10896
109- function initialize_cache (
110- problem:: ApplyOperatorsProblem , algorithm:: ApplyOperators , iterate
97+ function apply_operators (operators, state; op_alg = nothing , kwargs... )
98+ op_alg = AIE. select_algorithm (apply_operator!, op_alg, (state,))
99+ problem = ApplyOperatorsProblem (; operators, init = state)
100+ algorithm = ApplyOperators (;
101+ operator_algorithm = op_alg,
102+ stopping_criterion = AI. StopAfterIteration (length (operators))
111103 )
112- subproblem = ApplyOperatorProblem (; op = first (problem. operators), init = iterate)
113- subalgorithm = algorithm. operator_algorithm
114- return initialize_cache (subproblem, subalgorithm, iterate)
115- end
116-
117- # === apply_operator (singular, one gate application) ===
118-
119- @kwdef struct ApplyOperatorProblem{Op, Init} <: AI.Problem
120- op:: Op
121- init:: Init
122- end
123-
124- function apply_operator (op, state; alg = BPApplyGate (), kwargs... )
125- problem = ApplyOperatorProblem (; op, init = state)
126- return AI. solve (problem, alg; iterate = copy (state), kwargs... )
127- end
128-
129- function apply_operator! (dest, op, state; alg = BPApplyGate (), kwargs... )
130- problem = ApplyOperatorProblem (; op, init = state)
131- alg_state = AI. initialize_state (problem, alg; iterate = dest, kwargs... )
132- return AI. solve! (problem, alg, alg_state)
104+ return AI. solve (problem, algorithm; iterate = copy (state), kwargs... )
133105end
134106
135- # === BPApplyGate (non-iterative; overloads solve_loop! directly) ===
107+ # === BPApplyGate strategy ===
136108
137- @kwdef struct BPApplyGate{Trunc, PinvKwargs <: NamedTuple } <: AI.Algorithm
109+ @kwdef struct BPApplyGate{Trunc, PinvKwargs <: NamedTuple } <: ApplyOperatorAlgorithm
138110 trunc:: Trunc = nothing
139111 pinv_kwargs:: PinvKwargs = (; tol = 0 )
140112 normalize:: Bool = false
141113end
142114
143- @kwdef mutable struct BPApplyGateState{Iterate, Cache} <: AI.State
144- iterate:: Iterate
145- cache:: Cache
115+ function AIE. default_algorithm (:: typeof (apply_operator!), :: Type{<:Tuple} ; kwargs... )
116+ return BPApplyGate (; kwargs... )
146117end
147-
148- function AI. initialize_state (
149- problem:: ApplyOperatorProblem , algorithm:: BPApplyGate ;
150- iterate, cache! = initialize_cache (problem, algorithm, iterate)
118+ function AIE. initialize_output (
119+ :: typeof (apply_operator!), operator, state, :: BPApplyGate
151120 )
152- return BPApplyGateState (; iterate, cache = cache! )
121+ return copy (state )
153122end
154123
155- # Non-iterative algorithm: no per-call state to reset.
156- function AI. initialize_state! (
157- :: ApplyOperatorProblem , :: BPApplyGate , state:: BPApplyGateState
124+ function apply_operator! (
125+ algorithm:: BPApplyGate , dest, operator, state; cache! = nothing
158126 )
159- return state
127+ cache! = initialize_cache (cache!, algorithm, state)
128+ apply_gate_bp! (
129+ dest, operator, state;
130+ cache!, algorithm. trunc, algorithm. pinv_kwargs, algorithm. normalize
131+ )
132+ return dest
160133end
161134
135+ initialize_cache (cache!, :: BPApplyGate , iterate:: AbstractTensorNetwork ) = cache!
162136# Initialize the BP message cache to identity square-root messages.
163- function initialize_cache (
164- :: ApplyOperatorProblem , :: BPApplyGate , iterate:: AbstractTensorNetwork
165- )
137+ function initialize_cache (:: Nothing , :: BPApplyGate , iterate:: AbstractTensorNetwork )
166138 return sqrtmessagecache (all_edges (iterate)) do edge
167139 factor = iterate[dst (edge)]
168140 return state (one (similar_operator (factor, linkaxes (iterate, edge))))
169141 end
170142end
171143
172- # Non-iterative algorithm: bypass the step!/stopping-criterion loop.
173- function AI. solve_loop! (
174- problem:: ApplyOperatorProblem , algorithm:: BPApplyGate ,
175- state:: BPApplyGateState
176- )
177- apply_gate_bp! (
178- state. iterate, problem. op, problem. init;
179- cache! = state. cache,
180- trunc = algorithm. trunc, pinv_kwargs = algorithm. pinv_kwargs,
181- normalize = algorithm. normalize
182- )
183- return state
184- end
185-
186144# === BP simple-update implementation ===
187145#
188146# The `cache!` here is assumed to be a `SqrtMessageCache`: messages on each
0 commit comments