Skip to content

Commit f1440e3

Browse files
refactor: improve type-stability of maybe_build_initialization_problem
1 parent cb06c49 commit f1440e3

2 files changed

Lines changed: 13 additions & 14 deletions

File tree

lib/ModelingToolkitBase/src/systems/nonlinear/initializesystem.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -633,8 +633,8 @@ function SciMLBase.remake_initialization_data(
633633
u0_constructor = get_u0_constructor(identity, typeof(newu0), floatT, false)
634634
p_constructor = get_p_constructor(identity, typeof(newu0), floatT)
635635
kws = maybe_build_initialization_problem(
636-
sys, SciMLBase.isinplace(odefn), op, t0, guesses;
637-
time_dependent_init, use_scc, initialization_eqs, floatT, fast_path = true,
636+
sys, Val{SciMLBase.isinplace(odefn)}(), op, t0, guesses, floatT;
637+
time_dependent_init, use_scc, initialization_eqs, fast_path = true,
638638
u0_constructor, p_constructor, allow_incomplete = true, check_units = false,
639639
missing_guess_value = meta.missing_guess_value
640640
)

lib/ModelingToolkitBase/src/systems/problem_utils.jl

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1088,16 +1088,15 @@ struct GetUpdatedU0{GG, GIU}
10881088
get_initial_unknowns::GIU
10891089
end
10901090

1091-
function GetUpdatedU0(sys::AbstractSystem, initprob::SciMLBase.AbstractNonlinearProblem, op::AbstractDict)
1092-
@nospecialize initprob
1091+
function GetUpdatedU0(sys::AbstractSystem, initsys::AbstractSystem, op::AbstractDict)
10931092
dvs = unknowns(sys)
10941093
eqs = equations(sys)
10951094
guessvars = trues(length(dvs))
10961095
for (i, var) in enumerate(dvs)
10971096
varval = get(op, var, COMMON_NOTHING)
10981097
guessvars[i] = varval === COMMON_NOTHING || !SU.isconst(varval)
10991098
end
1100-
get_guessvars = getu(initprob, dvs[guessvars])
1099+
get_guessvars = getu(initsys, dvs[guessvars])
11011100
get_initial_unknowns = getu(sys, Initial.(dvs))
11021101
return GetUpdatedU0(guessvars, get_guessvars, get_initial_unknowns)
11031102
end
@@ -1170,14 +1169,14 @@ constructed is in implicit DAE form (`DAEProblem`). All other keyword arguments
11701169
to `InitializationProblem`.
11711170
"""
11721171
function maybe_build_initialization_problem(
1173-
sys::AbstractSystem, iip, op::AbstractDict, t, guesses;
1172+
sys::AbstractSystem, ::Val{iip}, op::SymmapT, t, guesses, ::Type{floatT};
11741173
time_dependent_init = is_time_dependent(sys), u0_constructor = identity,
1175-
p_constructor = identity, floatT = Float64, initialization_eqs = [],
1174+
p_constructor = identity, initialization_eqs = [],
11761175
use_scc = true, eval_expression = false, eval_module = @__MODULE__,
11771176
missing_guess_value = default_missing_guess_value(),
11781177
# Intercept `expression` because we don't support it here yet
11791178
implicit_dae = false, is_steadystateprob = false, expression = Val{false}, kwargs...
1180-
)
1179+
) where {iip, floatT}
11811180
guesses = merge(ModelingToolkitBase.guesses(sys), todict(guesses))
11821181

11831182
if t === nothing && is_time_dependent(sys)
@@ -1190,6 +1189,7 @@ function maybe_build_initialization_problem(
11901189
use_scc, u0_constructor, p_constructor, eval_expression, eval_module,
11911190
missing_guess_value, is_steadystateprob, kwargs...
11921191
)
1192+
initsys = initializeprob.f.sys::System
11931193
needs_remake = false
11941194
_u0 = state_values(initializeprob)
11951195
if _u0 !== nothing
@@ -1236,7 +1236,7 @@ function maybe_build_initialization_problem(
12361236
end
12371237

12381238
get_initial_unknowns = if time_dependent_init
1239-
GetUpdatedU0(sys, initializeprob, op)
1239+
GetUpdatedU0(sys, initsys, op)
12401240
else
12411241
nothing
12421242
end
@@ -1246,7 +1246,7 @@ function maybe_build_initialization_problem(
12461246
Vector{Equation}(initialization_eqs),
12471247
use_scc, time_dependent_init,
12481248
ReconstructInitializeprob(
1249-
sys, initializeprob.f.sys; u0_constructor,
1249+
sys, initsys; u0_constructor,
12501250
p_constructor, eval_expression, eval_module, is_steadystateprob, kwargs...
12511251
),
12521252
get_initial_unknowns, SetInitialUnknowns(sys), missing_guess_value
@@ -1274,7 +1274,7 @@ function maybe_build_initialization_problem(
12741274
initializeprobpmap = nothing
12751275
else
12761276
initializeprobpmap = construct_initializeprobpmap(
1277-
sys, initializeprob.f.sys; p_constructor, eval_expression, eval_module, kwargs...
1277+
sys, initsys; p_constructor, eval_expression, eval_module, kwargs...
12781278
)
12791279
end
12801280

@@ -1302,7 +1302,6 @@ function maybe_build_initialization_problem(
13021302
end
13031303
end
13041304
if implicit_dae
1305-
initsys = initializeprob.f.sys
13061305
for v in unknowns(sys)
13071306
v = Differential(get_iv(sys))(v)
13081307
ttv = default_toterm(v)
@@ -1328,10 +1327,10 @@ function maybe_build_initialization_problem(
13281327
end
13291328
missingvars = collect(missingvars)
13301329

1331-
for (i, v) in enumerate(unknowns(initializeprob.f.sys))
1330+
for (i, v) in enumerate(unknowns(initsys))
13321331
write_possibly_indexed_array!(temp_op, v, SConst(_u0[i]), COMMON_NOTHING)
13331332
end
1334-
add_observed!(initializeprob.f.sys, temp_op)
1333+
add_observed!(initsys, temp_op)
13351334
left_merge!(temp_op, ModelingToolkitBase.guesses(sys))
13361335
subber = Symbolics.FixpointSubstituter{true}(AADSubWrapper(temp_op))
13371336
for p in missingvars

0 commit comments

Comments
 (0)