Skip to content

Commit 19e47ce

Browse files
committed
feat(solve): allow registry injection in dispatch and add mock dispatch tests
1 parent b21ebfb commit 19e47ce

2 files changed

Lines changed: 259 additions & 1 deletion

File tree

src/solve/dispatch.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,10 @@ function CommonSolve.solve(
5151
normalized_init = CTModels.build_initial_guess(ocp, initial_guess)
5252

5353
# 3. Get registry for component completion
54-
registry = get_strategy_registry()
54+
registry = _extract_kwarg(kwargs, CTSolvers.StrategyRegistry)
55+
if isnothing(registry)
56+
registry = get_strategy_registry()
57+
end
5558

5659
# 4. Dispatch — asymmetric signatures:
5760
# ExplicitMode: extract typed components by type from kwargs (default nothing)
Lines changed: 255 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,255 @@
1+
module TestDispatchLogic
2+
3+
import Test
4+
import OptimalControl
5+
import CTModels
6+
import CTDirect
7+
import CTSolvers
8+
import CTBase
9+
import CommonSolve
10+
11+
const VERBOSE = isdefined(Main, :TestOptions) ? Main.TestOptions.VERBOSE : true
12+
const SHOWTIMING = isdefined(Main, :TestOptions) ? Main.TestOptions.SHOWTIMING : true
13+
14+
# ============================================================================
15+
# TOP-LEVEL: Parametric Mock types
16+
# ============================================================================
17+
18+
struct MockOCP <: CTModels.AbstractModel end
19+
struct MockInit <: CTModels.AbstractInitialGuess end
20+
struct MockSolution <: CTModels.AbstractSolution
21+
components::Tuple
22+
end
23+
24+
# Parametric mocks to simulate ANY strategy ID found in methods.jl
25+
struct MockDiscretizer{ID} <: CTDirect.AbstractDiscretizer
26+
options::CTSolvers.StrategyOptions
27+
end
28+
29+
struct MockModeler{ID} <: CTSolvers.AbstractNLPModeler
30+
options::CTSolvers.StrategyOptions
31+
end
32+
33+
struct MockSolver{ID} <: CTSolvers.AbstractNLPSolver
34+
options::CTSolvers.StrategyOptions
35+
end
36+
37+
# ----------------------------------------------------------------------------
38+
# Strategies Interface Implementation
39+
# ----------------------------------------------------------------------------
40+
41+
# ID accessors
42+
CTSolvers.Strategies.id(::Type{MockDiscretizer{ID}}) where {ID} = ID
43+
CTSolvers.Strategies.id(::Type{MockModeler{ID}}) where {ID} = ID
44+
CTSolvers.Strategies.id(::Type{MockSolver{ID}}) where {ID} = ID
45+
46+
# Metadata (required by registry)
47+
CTSolvers.Strategies.metadata(::Type{<:MockDiscretizer}) = CTSolvers.Strategies.StrategyMetadata()
48+
CTSolvers.Strategies.metadata(::Type{<:MockModeler}) = CTSolvers.Strategies.StrategyMetadata()
49+
CTSolvers.Strategies.metadata(::Type{<:MockSolver}) = CTSolvers.Strategies.StrategyMetadata()
50+
51+
# Options accessors
52+
CTSolvers.Strategies.options(d::MockDiscretizer) = d.options
53+
CTSolvers.Strategies.options(m::MockModeler) = m.options
54+
CTSolvers.Strategies.options(s::MockSolver) = s.options
55+
56+
# Constructors (required by _build_or_use_strategy)
57+
function MockDiscretizer{ID}(; mode::Symbol=:strict, kwargs...) where {ID}
58+
opts = CTSolvers.Strategies.build_strategy_options(MockDiscretizer{ID}; mode=mode, kwargs...)
59+
return MockDiscretizer{ID}(opts)
60+
end
61+
62+
function MockModeler{ID}(; mode::Symbol=:strict, kwargs...) where {ID}
63+
opts = CTSolvers.Strategies.build_strategy_options(MockModeler{ID}; mode=mode, kwargs...)
64+
return MockModeler{ID}(opts)
65+
end
66+
67+
function MockSolver{ID}(; mode::Symbol=:strict, kwargs...) where {ID}
68+
opts = CTSolvers.Strategies.build_strategy_options(MockSolver{ID}; mode=mode, kwargs...)
69+
return MockSolver{ID}(opts)
70+
end
71+
72+
# ----------------------------------------------------------------------------
73+
# Mock Registry Builder
74+
# ----------------------------------------------------------------------------
75+
76+
function build_mock_registry_from_methods()::CTSolvers.StrategyRegistry
77+
# 1. Get all valid triplets from methods()
78+
# e.g. ((:collocation, :adnlp, :ipopt), ...)
79+
valid_methods = OptimalControl.methods()
80+
81+
# 2. Extract unique symbols for each category
82+
disc_ids = unique(m[1] for m in valid_methods)
83+
mod_ids = unique(m[2] for m in valid_methods)
84+
sol_ids = unique(m[3] for m in valid_methods)
85+
86+
# 3. Create tuple of Mock types for each ID
87+
# We need to map AbstractType => (MockType{ID1}, MockType{ID2}, ...)
88+
disc_types = Tuple(MockDiscretizer{id} for id in disc_ids)
89+
mod_types = Tuple(MockModeler{id} for id in mod_ids)
90+
sol_types = Tuple(MockSolver{id} for id in sol_ids)
91+
92+
# 4. Create registry
93+
return CTSolvers.create_registry(
94+
CTDirect.AbstractDiscretizer => disc_types,
95+
CTSolvers.AbstractNLPModeler => mod_types,
96+
CTSolvers.AbstractNLPSolver => sol_types
97+
)
98+
end
99+
100+
# ----------------------------------------------------------------------------
101+
# Layer 3 Overrides (Mock Resolution)
102+
# ----------------------------------------------------------------------------
103+
104+
# Override CommonSolve.solve (Explicit Mode final step)
105+
# This intercepts the call after components have been completed/instantiated.
106+
function CommonSolve.solve(
107+
::MockOCP, ::MockInit,
108+
d::MockDiscretizer, m::MockModeler, s::MockSolver;
109+
display::Bool
110+
)::MockSolution
111+
return MockSolution((d, m, s))
112+
end
113+
114+
# Override OptimalControl.solve_descriptive (Descriptive Mode final step)
115+
# This intercepts the call after mode detection.
116+
function OptimalControl.solve_descriptive(
117+
ocp::MockOCP, description::Symbol...;
118+
initial_guess, display::Bool, registry::CTSolvers.StrategyRegistry, kwargs...
119+
)::MockSolution
120+
# For testing purposes, we return a MockSolution containing the description symbols
121+
# and the registry itself to verify they were passed correctly.
122+
return MockSolution((description, registry))
123+
end
124+
125+
# ============================================================================
126+
# TESTS
127+
# ============================================================================
128+
129+
function test_dispatch_logic()
130+
Test.@testset "Dispatch Logic & Completion" verbose=VERBOSE showtiming=SHOWTIMING begin
131+
132+
ocp = MockOCP()
133+
init = MockInit()
134+
mock_registry = build_mock_registry_from_methods()
135+
136+
# Iterate over all valid methods defined in OptimalControl
137+
# This ensures we cover every supported combination
138+
for (d_id, m_id, s_id) in OptimalControl.methods()
139+
140+
method_str = "($d_id, $m_id, $s_id)"
141+
142+
# ----------------------------------------------------------------
143+
# TEST 1: Explicit Mode with FULL Components
144+
# ----------------------------------------------------------------
145+
# Verify that we can explicitly target EVERY method supported.
146+
147+
Test.@testset "Explicit Full: $method_str" begin
148+
149+
d_instance = MockDiscretizer{d_id}(CTSolvers.StrategyOptions())
150+
m_instance = MockModeler{m_id}(CTSolvers.StrategyOptions())
151+
s_instance = MockSolver{s_id}(CTSolvers.StrategyOptions())
152+
153+
sol = OptimalControl.solve(
154+
ocp;
155+
initial_guess=init,
156+
display=false,
157+
registry=mock_registry,
158+
discretizer=d_instance,
159+
modeler=m_instance,
160+
solver=s_instance
161+
)
162+
163+
Test.@test sol isa MockSolution
164+
(d_res, m_res, s_res) = sol.components
165+
166+
Test.@test d_res isa MockDiscretizer{d_id}
167+
Test.@test m_res isa MockModeler{m_id}
168+
Test.@test s_res isa MockSolver{s_id}
169+
end
170+
171+
# ----------------------------------------------------------------
172+
# TEST 2: Descriptive Mode
173+
# ----------------------------------------------------------------
174+
# We pass symbols (:collocation, :adnlp, :ipopt)
175+
# Should dispatch to solve_descriptive with these symbols
176+
177+
Test.@testset "Descriptive: $method_str" begin
178+
179+
sol = OptimalControl.solve(
180+
ocp, d_id, m_id, s_id;
181+
initial_guess=init,
182+
display=false,
183+
registry=mock_registry
184+
)
185+
186+
Test.@test sol isa MockSolution
187+
(desc_res, reg_res) = sol.components
188+
189+
# Check that description was passed correctly
190+
Test.@test desc_res == (d_id, m_id, s_id)
191+
192+
# Check that registry was passed correctly
193+
Test.@test reg_res === mock_registry
194+
end
195+
end
196+
197+
# ----------------------------------------------------------------
198+
# TEST 3: Partial Explicit (Defaults)
199+
# ----------------------------------------------------------------
200+
# Verify that providing partial components triggers completion
201+
# to a valid default (usually the first match).
202+
203+
Test.@testset "Explicit Partial (Defaults)" begin
204+
# Case: Only Discretizer(:collocation) provided
205+
# Expectation: Defaults to :adnlp, :ipopt (based on methods order)
206+
207+
d_instance = MockDiscretizer{:collocation}(CTSolvers.StrategyOptions())
208+
209+
sol = OptimalControl.solve(
210+
ocp;
211+
initial_guess=init,
212+
display=false,
213+
registry=mock_registry,
214+
discretizer=d_instance
215+
)
216+
217+
Test.@test sol isa MockSolution
218+
(d_res, m_res, s_res) = sol.components
219+
220+
Test.@test d_res isa MockDiscretizer{:collocation}
221+
# Verify it filled in valid components
222+
Test.@test m_res isa MockModeler
223+
Test.@test s_res isa MockSolver
224+
end
225+
226+
# ----------------------------------------------------------------
227+
# TEST 4: Default Registry Fallback
228+
# ----------------------------------------------------------------
229+
# Verify that if we don't pass `registry`, it falls back to the real one.
230+
231+
Test.@testset "Default Registry Fallback" begin
232+
sol = OptimalControl.solve(
233+
ocp, :foo, :bar;
234+
initial_guess=init,
235+
display=false
236+
)
237+
238+
(_, reg_res) = sol.components
239+
# It should NOT be our mock registry
240+
Test.@test reg_res !== mock_registry
241+
242+
# It should look like the real registry (checking internal families)
243+
# Real registry has CTDirect.AbstractDiscretizer, etc.
244+
families = reg_res.families
245+
Test.@test haskey(families, CTDirect.AbstractDiscretizer)
246+
Test.@test haskey(families, CTSolvers.AbstractNLPModeler)
247+
end
248+
249+
end
250+
end
251+
252+
end # module
253+
254+
# Entry point for TestRunner
255+
test_dispatch_logic() = TestDispatchLogic.test_dispatch_logic()

0 commit comments

Comments
 (0)