@@ -2,250 +2,158 @@ module AlgorithmsInterfaceExtensions
22
33import AlgorithmsInterface as AI
44
5- # ========================== Patches for AlgorithmsInterface.jl ============================
5+ # ============================ NestedAlgorithm ================= ============================
66
7- abstract type Problem <: AI.Problem end
8- abstract type Algorithm <: AI.Algorithm end
9- abstract type State <: AI.State end
7+ abstract type NestedAlgorithm <: AI.Algorithm end
108
11- function AI. initialize_state! (
12- problem:: Problem , algorithm:: Algorithm , state:: State ; iteration = 0 , kwargs...
13- )
14- for (k, v) in pairs (kwargs)
15- setproperty! (state, k, v)
16- end
17- state. iteration = iteration
18- AI. initialize_state! (
19- problem, algorithm, algorithm. stopping_criterion, state. stopping_criterion_state
9+ # Subtypes of `NestedAlgorithm` must override `initialize_subsolve` — it
10+ # returns the `(subproblem, subalgorithm, substate)` tuple that the next
11+ # inner `AI.solve!` call consumes. The default `finalize_substate!` copies
12+ # the substate's iterate back into the parent state; subtypes can override
13+ # when more is required.
14+ function initialize_subsolve (
15+ problem:: AI.Problem , algorithm:: AI.Algorithm , state:: AI.State
2016 )
21- return state
17+ return throw ( MethodError (initialize_subsolve, (problem, algorithm, state)))
2218end
2319
24- function AI. initialize_state (
25- problem:: Problem , algorithm:: Algorithm ; iterate, kwargs...
26- )
27- stopping_criterion_state = AI. initialize_state (
28- problem, algorithm, algorithm. stopping_criterion; iterate
20+ function finalize_substate! (
21+ problem:: AI.Problem , algorithm:: AI.Algorithm , state:: AI.State , substate:: AI.State
2922 )
30- return DefaultState (; iterate, stopping_criterion_state, kwargs... )
31- end
32-
33- # ============================ DefaultState ================================================
34-
35- @kwdef mutable struct DefaultState{
36- Iterate, StoppingCriterionState <: AI.StoppingCriterionState ,
37- } <: State
38- iterate:: Iterate
39- iteration:: Int = 0
40- stopping_criterion_state:: StoppingCriterionState
23+ state. iterate = substate. iterate
24+ return state
4125end
4226
43- # ============================ increment! ==================================================
44-
45- # Custom version of `increment!` that also takes the problem and algorithm as arguments.
46- function AI . increment ! (problem:: Problem , algorithm:: Algorithm , state:: State )
47- return AI . increment! ( state)
27+ function AI . step! (problem :: AI.Problem , algorithm :: NestedAlgorithm , state :: AI.State )
28+ subproblem, subalgorithm, substate = initialize_subsolve (problem, algorithm, state)
29+ AI . solve! (subproblem, subalgorithm, substate)
30+ finalize_substate ! (problem, algorithm, state, substate )
31+ return state
4832end
4933
50- # ============================ AlgorithmIterator ===========================================
34+ # ============================ NestedState ====== ===========================================
5135
52- abstract type AlgorithmIterator end
36+ # State that wraps an inner `substate` and forwards `:iterate` accesses to it,
37+ # so the inner-loop iterate is shared without duplicating storage on the outer
38+ # state. Subtypes must store the inner state as a field named `substate`.
39+ abstract type NestedState <: AI.State end
5340
54- function algorithm_iterator (
55- problem:: Problem , algorithm:: Algorithm , state:: State
56- )
57- return DefaultAlgorithmIterator (problem, algorithm, state)
58- end
59-
60- function AI. is_finished! (iterator:: AlgorithmIterator )
61- return AI. is_finished! (iterator. problem, iterator. algorithm, iterator. state)
62- end
63- function AI. is_finished (iterator:: AlgorithmIterator )
64- return AI. is_finished (iterator. problem, iterator. algorithm, iterator. state)
41+ # Use `getfield` on the right-hand side so future edits to this forwarder
42+ # can't accidentally recurse through the overload.
43+ function Base. getproperty (state:: NestedState , name:: Symbol )
44+ name === :iterate && return getfield (state, :substate ). iterate
45+ return getfield (state, name)
6546end
66- function AI. increment! (iterator:: AlgorithmIterator )
67- return AI. increment! (iterator. problem, iterator. algorithm, iterator. state)
47+ function Base. setproperty! (state:: NestedState , name:: Symbol , value)
48+ name === :iterate && return (getfield (state, :substate ). iterate = value)
49+ return setfield! (state, name, value)
6850end
69- function AI. step! (iterator:: AlgorithmIterator )
70- return AI. step! (iterator. problem, iterator. algorithm, iterator. state)
71- end
72- function Base. iterate (iterator:: AlgorithmIterator , init = nothing )
73- AI. is_finished! (iterator) && return nothing
74- AI. increment! (iterator)
75- AI. step! (iterator)
76- return iterator. state, nothing
51+ function Base. propertynames (state:: NestedState )
52+ return (fieldnames (typeof (state))... , :iterate )
7753end
7854
79- struct DefaultAlgorithmIterator{Problem, Algorithm, State} <: AlgorithmIterator
80- problem:: Problem
81- algorithm:: Algorithm
82- state:: State
83- end
55+ # ============================ select_algorithm / default_algorithm ========================
8456
85- # ============================ with_algorithmlogger ========================================
57+ # Like `MatrixAlgebraKit.select_algorithm` / `default_algorithm`, but
58+ # selection-relevant inputs are packed into an `args` tuple so the value
59+ # and type domains stay disjoint: `(1.2,)` vs `Tuple{Float64}`. Strategy
60+ # types subtype `AbstractAlgorithm` so the passthrough overload is generic.
61+ abstract type AbstractAlgorithm end
8662
87- # Allow passing functions, not just CallbackActions.
88- @inline function with_algorithmlogger (f, args:: Pair{Symbol, AI.LoggingAction} ...)
89- return AI. with_algorithmlogger (f, args... )
63+ function default_algorithm (f, :: Type{Args} ; kwargs... ) where {Args <: Tuple }
64+ return throw (MethodError (default_algorithm, (f, Args)))
9065end
91- @inline function with_algorithmlogger (f, args:: Pair{Symbol} ...)
92- return AI . with_algorithmlogger (f, ( first .( args) .=> AI . CallbackAction .( last .(args))) . .. )
66+ function default_algorithm (f, args:: Tuple ; kwargs ... )
67+ return default_algorithm (f, typeof ( args); kwargs ... )
9368end
9469
95- # ============================ NestedAlgorithm =============================================
96-
97- abstract type NestedAlgorithm <: Algorithm end
98-
99- nested_algorithm (f:: Function , int:: Int ; kwargs... ) = nested_algorithm (f, 1 : int; kwargs... )
100- function nested_algorithm (f:: Function , iterable; kwargs... )
101- return DefaultNestedAlgorithm (f, iterable; kwargs... )
70+ function select_algorithm (f, alg, args:: Tuple ; kwargs... )
71+ return select_algorithm (f, alg, typeof (args); kwargs... )
10272end
103-
104- max_iterations (algorithm:: NestedAlgorithm ) = length (algorithm. algorithms)
105-
106- function get_subproblem (
107- problem:: AI.Problem , algorithm:: NestedAlgorithm , state:: AI.State
73+ function select_algorithm (f, :: Nothing , :: Type{Args} ; kwargs... ) where {Args <: Tuple }
74+ return default_algorithm (f, Args; kwargs... )
75+ end
76+ function select_algorithm (f, alg:: NamedTuple , :: Type{Args} ; kwargs... ) where {Args <: Tuple }
77+ isempty (kwargs) || throw (
78+ ArgumentError (
79+ " Additional keyword arguments are not allowed when `alg` is a `NamedTuple`."
80+ )
10881 )
109- subproblem = problem
110- subalgorithm = algorithm. algorithms[state. iteration]
111- substate = AI. initialize_state (subproblem, subalgorithm; state. iterate)
112- return subproblem, subalgorithm, substate
82+ return default_algorithm (f, Args; alg... )
11383end
114-
115- function set_substate! (
116- problem:: AI.Problem , algorithm:: NestedAlgorithm , state:: AI.State , substate:: AI.State
84+ function select_algorithm (f, alg:: AbstractAlgorithm , :: Type{<:Tuple} ; kwargs... )
85+ isempty (kwargs) || throw (
86+ ArgumentError (
87+ " Additional keyword arguments are not allowed when `alg` is an `AbstractAlgorithm` instance."
88+ )
11789 )
118- state. iterate = substate. iterate
119- return state
90+ return alg
12091end
12192
122- function AI. step! (problem:: AI.Problem , algorithm:: NestedAlgorithm , state:: AI.State )
123- # Get the subproblem, subalgorithm, and substate.
124- subproblem, subalgorithm, substate = get_subproblem (problem, algorithm, state)
125-
126- # Solve the subproblem with the subalgorithm.
127- AI. solve! (subproblem, subalgorithm, substate)
128-
129- # Update the state with the substate.
130- set_substate! (problem, algorithm, state, substate)
93+ # ============================ StopWhenConverged ===========================================
13194
132- return state
95+ # Stopping criterion that fires once `iterate_diff(iterate, previous_iterate) < tol`.
96+ # Concrete iterate types must supply an `iterate_diff` method.
97+ function iterate_diff (a, b)
98+ return throw (MethodError (iterate_diff, (a, b)))
13399end
134100
135- #=
136- DefaultNestedAlgorithm(sweeps::AbstractVector{<:Algorithm})
137-
138- An algorithm that consists of running an algorithm at each iteration
139- from a list of stored algorithms.
140- =#
141- @kwdef struct DefaultNestedAlgorithm{
142- ChildAlgorithm <: Algorithm ,
143- Algorithms <: AbstractVector{ChildAlgorithm} ,
144- StoppingCriterion <: AI.StoppingCriterion ,
145- } <: NestedAlgorithm
146- algorithms:: Algorithms
147- stopping_criterion:: StoppingCriterion = AI. StopAfterIteration (length (algorithms))
101+ @kwdef struct StopWhenConverged <: AI.StoppingCriterion
102+ tol:: Float64
148103end
149- function DefaultNestedAlgorithm (f:: Function , iterable; kwargs... )
150- return DefaultNestedAlgorithm (; algorithms = f .(iterable), kwargs... )
151- end
152-
153- # ============================ FlattenedAlgorithm ==========================================
154104
155- # Flatten a nested algorithm.
156- abstract type FlattenedAlgorithm <: Algorithm end
157- abstract type FlattenedAlgorithmState <: State end
158-
159- function flattened_algorithm (f:: Function , nalgorithms:: Int ; kwargs... )
160- return DefaultFlattenedAlgorithm (f, nalgorithms; kwargs... )
105+ @kwdef mutable struct StopWhenConvergedState{Iterate} <: AI.StoppingCriterionState
106+ delta:: Float64 = Inf
107+ at_iteration:: Int = - 1
108+ previous_iterate:: Iterate
161109end
162110
163- function AI. initialize_state (
164- problem:: Problem , algorithm:: FlattenedAlgorithm ; kwargs...
165- )
166- stopping_criterion_state = AI. initialize_state (
167- problem, algorithm, algorithm. stopping_criterion
168- )
169- return DefaultFlattenedAlgorithmState (; stopping_criterion_state, kwargs... )
170- end
171- function AI. increment! (
172- problem:: Problem , algorithm:: Algorithm , state:: FlattenedAlgorithmState
173- )
174- # Increment the total iteration count.
175- state. iteration += 1
176- # TODO : Use `is_finished!` instead?
177- if state. child_iteration ≥ max_iterations (algorithm. algorithms[state. parent_iteration])
178- # We're on the last iteration of the child algorithm, so move to the next
179- # child algorithm.
180- state. parent_iteration += 1
181- state. child_iteration = 1
182- else
183- # Iterate the child algorithm.
184- state. child_iteration += 1
185- end
186- return state
111+ function AI. initialize_state (:: AI.Problem , :: AI.Algorithm , :: StopWhenConverged ; iterate)
112+ return StopWhenConvergedState (; previous_iterate = copy (iterate))
187113end
188- function AI. step! (
189- problem:: AI.Problem , algorithm:: FlattenedAlgorithm , state:: FlattenedAlgorithmState
190- )
191- algorithm_sweep = algorithm. algorithms[state. parent_iteration]
192- state_sweep = AI. initialize_state (
193- problem, algorithm_sweep;
194- state. iterate, iteration = state. child_iteration
114+
115+ function AI. initialize_state! (
116+ :: AI.Problem , :: AI.Algorithm , :: StopWhenConverged , st:: StopWhenConvergedState
195117 )
196- AI. step! (problem, algorithm_sweep, state_sweep)
197- state. iterate = state_sweep. iterate
198- return state
118+ st. delta = Inf
119+ return st
199120end
200121
201- @kwdef struct DefaultFlattenedAlgorithm{
202- ChildAlgorithm <: Algorithm ,
203- Algorithms <: AbstractVector{ChildAlgorithm} ,
204- StoppingCriterion <: AI.StoppingCriterion ,
205- } <: FlattenedAlgorithm
206- algorithms:: Algorithms
207- stopping_criterion:: StoppingCriterion =
208- AI. StopAfterIteration (sum (max_iterations, algorithms))
209- end
210- function DefaultFlattenedAlgorithm (f:: Function , nalgorithms:: Int ; kwargs... )
211- return DefaultFlattenedAlgorithm (; algorithms = f .(1 : nalgorithms), kwargs... )
212- end
122+ function AI. is_finished! (
123+ problem:: AI.Problem ,
124+ algorithm:: AI.Algorithm ,
125+ state:: AI.State ,
126+ c:: StopWhenConverged ,
127+ st:: StopWhenConvergedState
128+ )
129+ iterate = state. iterate
130+ previous_iterate = st. previous_iterate
213131
214- @kwdef mutable struct DefaultFlattenedAlgorithmState{
215- Iterate, StoppingCriterionState <: AI.StoppingCriterionState ,
216- } <: FlattenedAlgorithmState
217- iterate:: Iterate
218- iteration:: Int = 0
219- parent_iteration:: Int = 1
220- child_iteration:: Int = 0
221- stopping_criterion_state:: StoppingCriterionState
222- end
132+ delta = iterate_diff (iterate, previous_iterate)
223133
224- # ============================ NonIterativeAlgorithm =======================================
134+ st . previous_iterate = copy (iterate)
225135
226- # Algorithm that only performs a single step.
227- abstract type NonIterativeAlgorithm <: Algorithm end
228- abstract type NonIterativeAlgorithmState <: State end
136+ # delta = 0 initially, so skip this the first time.
137+ state. iteration == 0 && return false
229138
230- function AI. initialize_state (problem:: Problem , algorithm:: NonIterativeAlgorithm ; kwargs... )
231- return DefaultNonIterativeAlgorithmState (; kwargs... )
232- end
139+ st. delta = delta
233140
234- function AI. initialize_state! (
235- problem:: Problem ,
236- algorithm:: NonIterativeAlgorithm ,
237- state:: NonIterativeAlgorithmState
238- )
239- return state
240- end
141+ if AI. is_finished (problem, algorithm, state, c, st)
142+ st. at_iteration = state. iteration
143+ return true
144+ end
241145
242- function AI. solve_loop! (problem:: Problem , algorithm:: NonIterativeAlgorithm , state:: State )
243- return throw (MethodError (AI. solve_loop!, (problem, algorithm, state)))
146+ return false
244147end
245148
246- @kwdef mutable struct DefaultNonIterativeAlgorithmState{Iterate} < :
247- NonIterativeAlgorithmState
248- iterate:: Iterate
149+ function AI. is_finished (
150+ :: AI.Problem ,
151+ :: AI.Algorithm ,
152+ :: AI.State ,
153+ c:: StopWhenConverged ,
154+ st:: StopWhenConvergedState
155+ )
156+ return st. delta < c. tol
249157end
250158
251159end
0 commit comments