Skip to content

Commit 77047c8

Browse files
Merge pull request #4702 from SebastianM-C/smc/codegen_input_defaults
[AI] Use declared inputs for input-aware codegen on compiled systems
2 parents a0f8eda + 17839c4 commit 77047c8

10 files changed

Lines changed: 120 additions & 48 deletions

File tree

lib/ModelingToolkitBase/ext/MTKCasADiDynamicOptExt.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,12 +71,15 @@ struct CasADiDynamicOptProblem{uType, tType, isinplace, P, F, K} <:
7171
wrapped_model::CasADiModel
7272
kwargs::K
7373

74-
function CasADiDynamicOptProblem(f, u0, tspan, p, model, kwargs...)
74+
function CasADiDynamicOptProblem(f, u0, tspan, p, model, kwargs)
7575
return new{
7676
typeof(u0), typeof(tspan), SciMLBase.isinplace(f, 5),
7777
typeof(p), typeof(f), typeof(kwargs),
7878
}(f, u0, tspan, p, model, kwargs)
7979
end
80+
function CasADiDynamicOptProblem(f, u0, tspan, p, model; kwargs...)
81+
return CasADiDynamicOptProblem(f, u0, tspan, p, model, kwargs)
82+
end
8083
end
8184

8285
function (M::MXLinearInterpolation)(τ)

lib/ModelingToolkitBase/ext/MTKInfiniteOptExt.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,12 +56,15 @@ struct JuMPDynamicOptProblem{uType, tType, isinplace, P, F, K} <:
5656
wrapped_model::InfiniteOptModel
5757
kwargs::K
5858

59-
function JuMPDynamicOptProblem(f, u0, tspan, p, model, kwargs...)
59+
function JuMPDynamicOptProblem(f, u0, tspan, p, model, kwargs)
6060
return new{
6161
typeof(u0), typeof(tspan), SciMLBase.isinplace(f, 5),
6262
typeof(p), typeof(f), typeof(kwargs),
6363
}(f, u0, tspan, p, model, kwargs)
6464
end
65+
function JuMPDynamicOptProblem(f, u0, tspan, p, model; kwargs...)
66+
return JuMPDynamicOptProblem(f, u0, tspan, p, model, kwargs)
67+
end
6568
end
6669

6770
struct InfiniteOptDynamicOptProblem{uType, tType, isinplace, P, F, K} <:
@@ -73,12 +76,15 @@ struct InfiniteOptDynamicOptProblem{uType, tType, isinplace, P, F, K} <:
7376
wrapped_model::InfiniteOptModel
7477
kwargs::K
7578

76-
function InfiniteOptDynamicOptProblem(f, u0, tspan, p, model, kwargs...)
79+
function InfiniteOptDynamicOptProblem(f, u0, tspan, p, model, kwargs)
7780
return new{
7881
typeof(u0), typeof(tspan), SciMLBase.isinplace(f),
7982
typeof(p), typeof(f), typeof(kwargs),
8083
}(f, u0, tspan, p, model, kwargs)
8184
end
85+
function InfiniteOptDynamicOptProblem(f, u0, tspan, p, model; kwargs...)
86+
return InfiniteOptDynamicOptProblem(f, u0, tspan, p, model, kwargs)
87+
end
8288
end
8389

8490
MTK.generate_internal_model(m::Type{InfiniteOptModel}) = InfiniteModel()

lib/ModelingToolkitBase/ext/MTKPyomoDynamicOptExt.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,12 +83,15 @@ struct PyomoDynamicOptProblem{uType, tType, isinplace, P, F, K} <:
8383
wrapped_model::PyomoDynamicOptModel
8484
kwargs::K
8585

86-
function PyomoDynamicOptProblem(f, u0, tspan, p, model, kwargs...)
86+
function PyomoDynamicOptProblem(f, u0, tspan, p, model, kwargs)
8787
return new{
8888
typeof(u0), typeof(tspan), SciMLBase.isinplace(f, 5),
8989
typeof(p), typeof(f), typeof(kwargs),
9090
}(f, u0, tspan, p, model, kwargs)
9191
end
92+
function PyomoDynamicOptProblem(f, u0, tspan, p, model; kwargs...)
93+
return PyomoDynamicOptProblem(f, u0, tspan, p, model, kwargs)
94+
end
9295
end
9396

9497
function pysym_getproperty(s::Union{Num, SymbolicT}, name::Symbol)

lib/ModelingToolkitBase/src/inputoutput.jl

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,23 @@ See also [`bound_inputs`](@ref), [`unbound_inputs`](@ref), [`bound_outputs`](@re
3333
"""
3434
unbound_inputs(sys) = filter(x -> !is_bound(sys, x), inputs(sys))
3535

36+
"""
37+
default_codegen_inputs(sys)
38+
39+
The inputs to use by default for input-aware code generation.
40+
41+
For scheduled (compiled) systems this is `inputs(sys)`: the inputs declared to
42+
`mtkcompile` are the contract the simplified system was built around. The
43+
[`unbound_inputs`](@ref) heuristic cannot be used there — after flattening and
44+
simplification an effective input appears in equations together with variables
45+
from other namespaces and is therefore classified as bound, so `unbound_inputs`
46+
is empty for compiled hierarchical systems.
47+
48+
For unscheduled systems this is [`unbound_inputs`](@ref), which inspects the
49+
connection structure of the hierarchy to find external inputs.
50+
"""
51+
default_codegen_inputs(sys) = isscheduled(sys) ? inputs(sys) : unbound_inputs(sys)
52+
3653
"""
3754
bound_outputs(sys)
3855
@@ -184,7 +201,7 @@ has_var(ex, x) = x ∈ Set(get_variables(ex))
184201
"""
185202
(f_oop, f_ip), x_sym, p_sym, io_sys = generate_control_function(
186203
sys::System,
187-
inputs = unbound_inputs(sys),
204+
inputs = default_codegen_inputs(sys),
188205
disturbance_inputs = disturbances(sys);
189206
known_disturbance_inputs = nothing,
190207
implicit_dae = false,
@@ -222,7 +239,7 @@ f[1](x, inputs, p, t)
222239
```
223240
"""
224241
function generate_control_function(
225-
sys::AbstractSystem, inputs = unbound_inputs(sys),
242+
sys::AbstractSystem, inputs = default_codegen_inputs(sys),
226243
disturbance_inputs = disturbances(sys);
227244
known_disturbance_inputs = nothing,
228245
disturbance_argument = false,

lib/ModelingToolkitBase/src/problems/bvproblem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
wrap_gfw = Val{true}, cse, checkbounds
3737
)
3838

39-
n_controls = length(unbound_inputs(sys))
39+
n_controls = length(default_codegen_inputs(sys))
4040
f_prototype = n_controls > 0 ? zeros(eltype(u0), length(dvs) - n_controls) : nothing
4141
bcresid_prototype = zeros(eltype(u0), length(u0_idxs) + length(constraints(sys)))
4242

lib/ModelingToolkitBase/src/systems/codegen.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1024,7 +1024,7 @@ function calculate_control_jacobian(
10241024
sparse = false, simplify = false
10251025
)
10261026
rhs = [eq.rhs for eq in full_equations(sys)]
1027-
ctrls = unbound_inputs(sys)
1027+
ctrls = default_codegen_inputs(sys)
10281028

10291029
if sparse
10301030
jac = sparsejacobian(rhs, ctrls, simplify = simplify)

lib/ModelingToolkitBase/src/systems/optimal_control_interface.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ is_explicit(tableau) = tableau isa DiffEqBase.ExplicitRKTableau
120120

121121
@fallback_iip_specialize function SciMLBase.ODEInputFunction{iip, specialize}(
122122
sys::System;
123-
inputs = inputs(sys),
123+
inputs = default_codegen_inputs(sys),
124124
disturbance_inputs = disturbances(sys),
125125
u0 = nothing, tgrad = false,
126126
jac = false, controljac = false,
@@ -363,7 +363,7 @@ function process_DynamicOptProblem(
363363
add_user_constraints!(fullmodel, sys, tspan, pmap)
364364
add_initial_constraints!(fullmodel, u0, u0_idxs, model_tspan[1])
365365

366-
return prob_type(f, u0, tspan, p, fullmodel, kwargs...), pmap
366+
return prob_type(f, u0, tspan, p, fullmodel; kwargs...), pmap
367367
end
368368

369369
function generate_time_variable! end

lib/ModelingToolkitBase/test/input_output_handling.jl

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -391,6 +391,44 @@ eqs = [D(x) ~ u]
391391
@test isequal(ModelingToolkitBase.outputs(ss1), [x[1], x[2], x[3]])
392392
end
393393

394+
@testset "default_codegen_inputs: declared inputs on compiled hierarchical systems" begin
395+
# After mtkcompile flattens a hierarchical model, an effective input appears
396+
# in equations together with variables from other namespaces and is therefore
397+
# classified as bound, so `unbound_inputs` is empty. Input-aware codegen must
398+
# fall back to the inputs declared to `mtkcompile` instead of silently
399+
# generating input-free dynamics.
400+
function TestActuator(; name)
401+
@variables u(t) [input = true] o(t) [output = true]
402+
return System([o ~ 2u], t; name)
403+
end
404+
function TestPlant(; name)
405+
@variables y(t) i(t)
406+
return System([D(y) ~ -y + i], t; name)
407+
end
408+
@named act = TestActuator()
409+
@named plant = TestPlant()
410+
@named hier = System([plant.i ~ act.o], t; systems = [act, plant])
411+
hier = complete(hier)
412+
ss = mtkcompile(hier; inputs = [hier.act.u])
413+
414+
@test isempty(unbound_inputs(ss))
415+
@test length(ModelingToolkitBase.default_codegen_inputs(ss)) == 1
416+
@test isequal(
417+
collect(ModelingToolkitBase.default_codegen_inputs(ss)),
418+
ModelingToolkitBase.inputs(ss)
419+
)
420+
421+
# The `ODEInputFunction` default must pick up the declared input rather than
422+
# generating dynamics with the input bound to its operating-point value.
423+
f = ModelingToolkitBase.SciMLBase.ODEInputFunction(ss)
424+
p = ModelingToolkitBase.get_p(ss, Dict(hier.act.u => 0.0, hier.plant.y => 1.0))
425+
@test f([1.0], [0.0], p, 0.0) == [-1.0]
426+
@test f([1.0], [5.0], p, 0.0) == [9.0]
427+
428+
# The control jacobian wrt the declared inputs is non-empty.
429+
@test size(ModelingToolkitBase.calculate_control_jacobian(ss)) == (1, 1)
430+
end
431+
394432
using ModelingToolkitStandardLibrary.Blocks
395433

396434
if @isdefined(ModelingToolkit)

lib/ModelingToolkitBase/test/optimization/Project.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ AmplNLWriter = "7c4d4715-977e-5154-bfe0-e096adeac482"
33
CasADi = "c49709b8-5c63-11e9-2fb2-69db5844192f"
44
DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0"
55
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
6-
DiffEqDevTools = "f3b72e0c-5b89-59e1-b016-84e28bfd966d"
76
InfiniteOpt = "20393b10-9daf-11e9-18c9-8db751c92c57"
87
Ipopt = "b6b21f68-93f8-5de0-b562-5493be1d77c9"
98
Ipopt_jll = "9cc047cb-c261-5740-88fc-0cf96f7bdcc7"
@@ -14,7 +13,9 @@ Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
1413
OptimizationBase = "bca83a33-5cc9-4baa-983d-23429ab6bcbb"
1514
OptimizationMOI = "fd9f6733-72f4-499f-8506-86b2bdd0dea1"
1615
OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
16+
OrdinaryDiffEqExplicitTableaus = "3278f1b1-0f5c-4cde-98e0-ba5eb00db955"
1717
OrdinaryDiffEqFIRK = "5960d6e9-dd7a-4743-88e7-cf307b64f125"
18+
OrdinaryDiffEqImplicitTableaus = "75f66a49-58fc-43e3-9173-2340726368f7"
1819
OrdinaryDiffEqSDIRK = "2d112036-d095-4a1e-ab9a-08536f3ecdbf"
1920
OrdinaryDiffEqTsit5 = "b1df2697-797e-41e3-8120-5422d3b24e4a"
2021
OrdinaryDiffEqVerner = "79d7bb75-1356-48c1-b8c0-6832512096c2"
@@ -32,5 +33,7 @@ ModelingToolkitBase = {path = "../.."}
3233
[compat]
3334
CasADi = "1.0.7"
3435
DataInterpolations = "8.8"
36+
OrdinaryDiffEqExplicitTableaus = "2"
37+
OrdinaryDiffEqImplicitTableaus = "2"
3538
SafeTestsets = "0.1, 1"
3639
SciMLTesting = "1"

0 commit comments

Comments
 (0)