Skip to content

Commit e89fd0b

Browse files
refactor: improve type-stability of linearization
1 parent 9823257 commit e89fd0b

1 file changed

Lines changed: 77 additions & 30 deletions

File tree

src/linearization.jl

Lines changed: 77 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,15 @@ function _build_op_from_solution(op::LinearizationOpPoint{S, <:AbstractVector})
6464
end
6565
end
6666

67+
function _linearization_wrap_odeproblem_f(@nospecialize(prob::ODEProblem), ::Type{T}) where {T}
68+
f = SciMLBase.Void{Any}(prob.f.f)
69+
u0 = T.(prob.u0)
70+
t = T(prob.tspan[1])
71+
f = SciMLBase.wrapfun_iip(f, (u0, u0, prob.p, t))
72+
odef = remake(prob.f; f = f)
73+
return remake(prob; f = odef)
74+
end
75+
6776
"""
6877
lin_fun, simplified_sys = linearization_function(sys::AbstractSystem, inputs, outputs; simplify = false, initialize = true, initialization_solver_alg = nothing, kwargs...)
6978
@@ -170,6 +179,7 @@ function linearization_function(
170179

171180
ps = parameters(sys)
172181
h = build_explicit_observed_function(sys, outputs; eval_expression, eval_module)
182+
h = SciMLBase.Void{Any}(h)
173183

174184
initialization_kwargs = (;
175185
abstol = initialization_abstol, reltol = initialization_reltol,
@@ -180,24 +190,21 @@ function linearization_function(
180190
t0 = current_time(prob)
181191
inputvals = [prob.ps[i] for i in inputs]
182192

183-
hp_fun = let fun = h, setter = setp_oop(sys, inputs)
184-
function hpf(du, input, u, p, t)
185-
p = setter(p, input)
186-
fun(du, u, p, t)
187-
return du
188-
end
189-
end
190193
if u0 === nothing
191194
T = typeof(t0)
192195
else
193196
T = promote_type(eltype(u0), typeof(t0))
194197
end
198+
prob = _linearization_wrap_odeproblem_f(prob, T)
195199
ct0 = DI.Constant(T(t0))
196200
u0T = if u0 === nothing
197201
u0
198202
else
199203
T.(u0)
200204
end
205+
h = SciMLBase.wrapfun_iip(h, (u0T, u0T, p, T(t0)))
206+
hp_fun = HPFun(h, setp_oop(sys, inputs))
207+
201208
cu0T = DI.Constant(u0T)
202209
cp = DI.Constant(p)
203210

@@ -209,11 +216,7 @@ function linearization_function(
209216
cu0T, cp, DI.Constant(t0)
210217
)
211218
else
212-
uf_fun = let fun = prob.f
213-
function uff(du, u, p, t)
214-
return SciMLBase.UJacobianWrapper(fun, t, p)(du, u)
215-
end
216-
end
219+
uf_fun = UFFun(prob.f)
217220

218221
uf_jac = PreparedJacobian{true}(
219222
uf_fun, similar(prob.u0, T), autodiff, u0T, cp, ct0
@@ -223,12 +226,8 @@ function linearization_function(
223226
h, similar(prob.u0, T, size(outputs)), autodiff,
224227
u0T, cp, ct0
225228
)
226-
pf_fun = let fun = prob.f, setter = setp_oop(sys, inputs)
227-
function pff(du, input, u, p, t)
228-
p = setter(p, input)
229-
return SciMLBase.ParamJacobianWrapper(fun, t, u)(du, p)
230-
end
231-
end
229+
230+
pf_fun = PFFun(prob.f, setp_oop(sys, inputs))
232231
pf_jac = PreparedJacobian{true}(
233232
pf_fun, similar(prob.u0, T), autodiff, inputvals,
234233
cu0T, cp, ct0
@@ -239,14 +238,45 @@ function linearization_function(
239238
)
240239
end
241240

241+
input_getter = getsym(prob, inputs)
242+
242243
lin_fun = LinearizationFunction(
243-
diff_idxs, alge_idxs, inputs, length(unknowns(sys)),
244+
diff_idxs, alge_idxs, input_getter, length(inputs), length(unknowns(sys)),
244245
prob, h, u0 === nothing ? nothing : similar(u0, T), uf_jac, h_jac, pf_jac,
245246
hp_jac, initializealg, initialization_kwargs
246247
)
247248
return lin_fun, sys
248249
end
249250

251+
struct HPFun{F, S}
252+
fn::F
253+
setter::S
254+
end
255+
256+
function (hpf::HPFun)(du, input, u, p, t)
257+
p = hpf.setter(p, input)
258+
hpf.fn(du, u, p, t)
259+
return du
260+
end
261+
262+
struct UFFun{F}
263+
fn::F
264+
end
265+
266+
function (uff::UFFun)(du, u, p, t)
267+
return SciMLBase.UJacobianWrapper(uff.fn, t, p)(du, u)
268+
end
269+
270+
struct PFFun{F, S}
271+
fn::F
272+
setter::S
273+
end
274+
275+
function (pff::PFFun)(du, input, u, p, t)
276+
p = pff.setter(p, input)
277+
return SciMLBase.ParamJacobianWrapper(pff.fn, t, u)(du, p)
278+
end
279+
250280
"""
251281
Return the set of indexes of differential equations and algebraic equations in the simplified system.
252282
"""
@@ -319,22 +349,26 @@ A callable struct which linearizes a system.
319349
$(TYPEDFIELDS)
320350
"""
321351
struct LinearizationFunction{
322-
DI <: AbstractVector{Int}, AI <: AbstractVector{Int}, I, P <: ODEProblem,
352+
I, P <: ODEProblem,
323353
H, C, J1, J2, J3, J4, IA <: SciMLBase.DAEInitializationAlgorithm, IK,
324354
}
325355
"""
326356
The indexes of differential equations in the linearized system.
327357
"""
328-
diff_idxs::DI
358+
diff_idxs::Vector{Int}
329359
"""
330360
The indexes of algebraic equations in the linearized system.
331361
"""
332-
alge_idxs::AI
362+
alge_idxs::Vector{Int}
333363
"""
334-
The indexes of parameters in the linearized system which represent
364+
Getter function for parameters in the linearized system which represent
335365
input variables.
336366
"""
337-
inputs::I
367+
inputs_getter::I
368+
"""
369+
Number of input variables.
370+
"""
371+
num_inputs::Int
338372
"""
339373
The number of unknowns in the linearized system.
340374
"""
@@ -409,7 +443,7 @@ function (linfun::LinearizationFunction)(u, p, t)
409443
end
410444

411445
fun = linfun.prob.f
412-
input_vals = [linfun.prob.ps[i] for i in linfun.inputs]
446+
input_vals = linfun.inputs_getter(linfun.prob)
413447
if u !== nothing # Handle systems without unknowns
414448
linfun.num_states == length(u) ||
415449
error("Number of unknown variables ($(linfun.num_states)) does not match the number of input unknowns ($(length(u)))")
@@ -420,6 +454,9 @@ function (linfun::LinearizationFunction)(u, p, t)
420454
linfun.prob, integ, fun, linfun.initializealg, Val(true);
421455
linfun.initialize_kwargs...
422456
)
457+
u = u::typeof(linfun.prob.u0)
458+
p = p::typeof(linfun.prob.p)
459+
success = success::Bool
423460
if !success
424461
error("Initialization algorithm $(linfun.initializealg) failed with `unknowns = $u` and `p = $p`.")
425462
end
@@ -433,7 +470,7 @@ function (linfun::LinearizationFunction)(u, p, t)
433470
linfun.num_states == 0 ||
434471
error("Number of unknown variables (0) does not match the expected number of unknowns ($(linfun.num_states))")
435472
fg_xz = zeros(0, 0)
436-
h_xz = fg_u = zeros(0, length(linfun.inputs))
473+
h_xz = fg_u = zeros(0, length(linfun.num_inputs))
437474
end
438475
h_u = linfun.hp_jac(
439476
input_vals,
@@ -889,6 +926,18 @@ function linearize(
889926
return solve(prob; allow_input_derivatives)
890927
end
891928

929+
function __linearize_multiple_op_barrier(ssys, lin_fun; ops, ts, allow_input_derivatives)
930+
T = eltype(lin_fun.prob.u0)
931+
results = @NamedTuple{A::Matrix{T}, B::Matrix{T}, C::Matrix{T}, D::Matrix{T}}[]
932+
xpts = @NamedTuple{x::typeof(lin_fun.prob.u0), p::typeof(lin_fun.prob.p), t::typeof(lin_fun.prob.tspan[1])}[]
933+
for (op, t) in zip(ops, ts)
934+
res, xpt = linearize(ssys, lin_fun; op, t, allow_input_derivatives)::Tuple{eltype(results), eltype(xpts)}
935+
push!(results, res)
936+
push!(xpts, xpt)
937+
end
938+
return results, xpts
939+
end
940+
892941
function linearize(
893942
sys, inputs, outputs; op = Dict(), t = 0.0,
894943
allow_input_derivatives = false,
@@ -906,10 +955,8 @@ function linearize(
906955
zero_dummy_der, op = ops[1], t = ts[1],
907956
ignore_system_initial_conditions = true, kwargs...
908957
)
909-
results = map(zip(ops, ts)) do (op_i, ti)
910-
linearize(ssys, lin_fun; op = op_i, t = ti, allow_input_derivatives)
911-
end
912-
return first.(results), ssys, last.(results)
958+
ress, ops = __linearize_multiple_op_barrier(ssys, lin_fun; ops, ts, allow_input_derivatives)
959+
return ress, ssys, ops
913960
end
914961
ignore_system_ics = false
915962
if op isa LinearizationOpPoint

0 commit comments

Comments
 (0)