@@ -64,6 +64,15 @@ function _build_op_from_solution(op::LinearizationOpPoint{S, <:AbstractVector})
6464 end
6565end
6666
67+ function _linearization_wrap_odeproblem_f (@nospecialize (prob:: ODEProblem ), :: Type{T} ) where {T}
68+ f = SciMLBase. Void {Any} (prob. f. f)
69+ u0 = T .(prob. u0)
70+ t = T (prob. tspan[1 ])
71+ f = SciMLBase. wrapfun_iip (f, (u0, u0, prob. p, t))
72+ odef = remake (prob. f; f = f)
73+ return remake (prob; f = odef)
74+ end
75+
6776"""
6877 lin_fun, simplified_sys = linearization_function(sys::AbstractSystem, inputs, outputs; simplify = false, initialize = true, initialization_solver_alg = nothing, kwargs...)
6978
@@ -170,6 +179,7 @@ function linearization_function(
170179
171180 ps = parameters (sys)
172181 h = build_explicit_observed_function (sys, outputs; eval_expression, eval_module)
182+ h = SciMLBase. Void {Any} (h)
173183
174184 initialization_kwargs = (;
175185 abstol = initialization_abstol, reltol = initialization_reltol,
@@ -180,24 +190,21 @@ function linearization_function(
180190 t0 = current_time (prob)
181191 inputvals = [prob. ps[i] for i in inputs]
182192
183- hp_fun = let fun = h, setter = setp_oop (sys, inputs)
184- function hpf (du, input, u, p, t)
185- p = setter (p, input)
186- fun (du, u, p, t)
187- return du
188- end
189- end
190193 if u0 === nothing
191194 T = typeof (t0)
192195 else
193196 T = promote_type (eltype (u0), typeof (t0))
194197 end
198+ prob = _linearization_wrap_odeproblem_f (prob, T)
195199 ct0 = DI. Constant (T (t0))
196200 u0T = if u0 === nothing
197201 u0
198202 else
199203 T .(u0)
200204 end
205+ h = SciMLBase. wrapfun_iip (h, (u0T, u0T, p, T (t0)))
206+ hp_fun = HPFun (h, setp_oop (sys, inputs))
207+
201208 cu0T = DI. Constant (u0T)
202209 cp = DI. Constant (p)
203210
@@ -209,11 +216,7 @@ function linearization_function(
209216 cu0T, cp, DI. Constant (t0)
210217 )
211218 else
212- uf_fun = let fun = prob. f
213- function uff (du, u, p, t)
214- return SciMLBase. UJacobianWrapper (fun, t, p)(du, u)
215- end
216- end
219+ uf_fun = UFFun (prob. f)
217220
218221 uf_jac = PreparedJacobian {true} (
219222 uf_fun, similar (prob. u0, T), autodiff, u0T, cp, ct0
@@ -223,12 +226,8 @@ function linearization_function(
223226 h, similar (prob. u0, T, size (outputs)), autodiff,
224227 u0T, cp, ct0
225228 )
226- pf_fun = let fun = prob. f, setter = setp_oop (sys, inputs)
227- function pff (du, input, u, p, t)
228- p = setter (p, input)
229- return SciMLBase. ParamJacobianWrapper (fun, t, u)(du, p)
230- end
231- end
229+
230+ pf_fun = PFFun (prob. f, setp_oop (sys, inputs))
232231 pf_jac = PreparedJacobian {true} (
233232 pf_fun, similar (prob. u0, T), autodiff, inputvals,
234233 cu0T, cp, ct0
@@ -239,14 +238,45 @@ function linearization_function(
239238 )
240239 end
241240
241+ input_getter = getsym (prob, inputs)
242+
242243 lin_fun = LinearizationFunction (
243- diff_idxs, alge_idxs, inputs, length (unknowns (sys)),
244+ diff_idxs, alge_idxs, input_getter, length ( inputs) , length (unknowns (sys)),
244245 prob, h, u0 === nothing ? nothing : similar (u0, T), uf_jac, h_jac, pf_jac,
245246 hp_jac, initializealg, initialization_kwargs
246247 )
247248 return lin_fun, sys
248249end
249250
251+ struct HPFun{F, S}
252+ fn:: F
253+ setter:: S
254+ end
255+
256+ function (hpf:: HPFun )(du, input, u, p, t)
257+ p = hpf. setter (p, input)
258+ hpf. fn (du, u, p, t)
259+ return du
260+ end
261+
262+ struct UFFun{F}
263+ fn:: F
264+ end
265+
266+ function (uff:: UFFun )(du, u, p, t)
267+ return SciMLBase. UJacobianWrapper (uff. fn, t, p)(du, u)
268+ end
269+
270+ struct PFFun{F, S}
271+ fn:: F
272+ setter:: S
273+ end
274+
275+ function (pff:: PFFun )(du, input, u, p, t)
276+ p = pff. setter (p, input)
277+ return SciMLBase. ParamJacobianWrapper (pff. fn, t, u)(du, p)
278+ end
279+
250280"""
251281Return the set of indexes of differential equations and algebraic equations in the simplified system.
252282"""
@@ -319,22 +349,26 @@ A callable struct which linearizes a system.
319349$(TYPEDFIELDS)
320350"""
321351struct LinearizationFunction{
322- DI <: AbstractVector{Int} , AI <: AbstractVector{Int} , I, P <: ODEProblem ,
352+ I, P <: ODEProblem ,
323353 H, C, J1, J2, J3, J4, IA <: SciMLBase.DAEInitializationAlgorithm , IK,
324354 }
325355 """
326356 The indexes of differential equations in the linearized system.
327357 """
328- diff_idxs:: DI
358+ diff_idxs:: Vector{Int}
329359 """
330360 The indexes of algebraic equations in the linearized system.
331361 """
332- alge_idxs:: AI
362+ alge_idxs:: Vector{Int}
333363 """
334- The indexes of parameters in the linearized system which represent
364+ Getter function for parameters in the linearized system which represent
335365 input variables.
336366 """
337- inputs:: I
367+ inputs_getter:: I
368+ """
369+ Number of input variables.
370+ """
371+ num_inputs:: Int
338372 """
339373 The number of unknowns in the linearized system.
340374 """
@@ -409,7 +443,7 @@ function (linfun::LinearizationFunction)(u, p, t)
409443 end
410444
411445 fun = linfun. prob. f
412- input_vals = [ linfun. prob . ps[i] for i in linfun. inputs]
446+ input_vals = linfun. inputs_getter ( linfun. prob)
413447 if u != = nothing # Handle systems without unknowns
414448 linfun. num_states == length (u) ||
415449 error (" Number of unknown variables ($(linfun. num_states) ) does not match the number of input unknowns ($(length (u)) )" )
@@ -420,6 +454,9 @@ function (linfun::LinearizationFunction)(u, p, t)
420454 linfun. prob, integ, fun, linfun. initializealg, Val (true );
421455 linfun. initialize_kwargs...
422456 )
457+ u = u:: typeof (linfun. prob. u0)
458+ p = p:: typeof (linfun. prob. p)
459+ success = success:: Bool
423460 if ! success
424461 error (" Initialization algorithm $(linfun. initializealg) failed with `unknowns = $u ` and `p = $p `." )
425462 end
@@ -433,7 +470,7 @@ function (linfun::LinearizationFunction)(u, p, t)
433470 linfun. num_states == 0 ||
434471 error (" Number of unknown variables (0) does not match the expected number of unknowns ($(linfun. num_states) )" )
435472 fg_xz = zeros (0 , 0 )
436- h_xz = fg_u = zeros (0 , length (linfun. inputs ))
473+ h_xz = fg_u = zeros (0 , length (linfun. num_inputs ))
437474 end
438475 h_u = linfun. hp_jac (
439476 input_vals,
@@ -889,6 +926,18 @@ function linearize(
889926 return solve (prob; allow_input_derivatives)
890927end
891928
929+ function __linearize_multiple_op_barrier (ssys, lin_fun; ops, ts, allow_input_derivatives)
930+ T = eltype (lin_fun. prob. u0)
931+ results = @NamedTuple {A:: Matrix{T} , B:: Matrix{T} , C:: Matrix{T} , D:: Matrix{T} }[]
932+ xpts = @NamedTuple {x:: typeof (lin_fun. prob. u0), p:: typeof (lin_fun. prob. p), t:: typeof (lin_fun. prob. tspan[1 ])}[]
933+ for (op, t) in zip (ops, ts)
934+ res, xpt = linearize (ssys, lin_fun; op, t, allow_input_derivatives):: Tuple{eltype(results), eltype(xpts)}
935+ push! (results, res)
936+ push! (xpts, xpt)
937+ end
938+ return results, xpts
939+ end
940+
892941function linearize (
893942 sys, inputs, outputs; op = Dict (), t = 0.0 ,
894943 allow_input_derivatives = false ,
@@ -906,10 +955,8 @@ function linearize(
906955 zero_dummy_der, op = ops[1 ], t = ts[1 ],
907956 ignore_system_initial_conditions = true , kwargs...
908957 )
909- results = map (zip (ops, ts)) do (op_i, ti)
910- linearize (ssys, lin_fun; op = op_i, t = ti, allow_input_derivatives)
911- end
912- return first .(results), ssys, last .(results)
958+ ress, ops = __linearize_multiple_op_barrier (ssys, lin_fun; ops, ts, allow_input_derivatives)
959+ return ress, ssys, ops
913960 end
914961 ignore_system_ics = false
915962 if op isa LinearizationOpPoint
0 commit comments