Skip to content

Commit 2053909

Browse files
committed
fix: promote the output array of initialization to the tunable eltype
This adresses the cases where the we generate output arrays from fully constant RHS, which would end up being incompatible with the elype of parameters in ForwardDiff and other similar contexts.
1 parent 48732b1 commit 2053909

2 files changed

Lines changed: 68 additions & 3 deletions

File tree

lib/ModelingToolkitBase/src/systems/problem_utils.jl

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1159,6 +1159,41 @@ end
11591159
safe_float(x) = x
11601160
safe_float(x::AbstractArray) = isempty(x) ? x : float(x)
11611161

1162+
"""
1163+
PromoteToTunableEltype(observed)
1164+
1165+
Wraps an `initializeprob` observed function so its output array is promoted to an
1166+
eltype compatible with the current tunable parameters. Addresses the case where
1167+
the observed function is generated from fully constant RHS (e.g. `initialization_eqs
1168+
= [s ~ 0]`): the resulting `create_array(Array, nothing, …, 0, 0)` would otherwise
1169+
produce `Vector{Int64}`, which — when downstream `remake` reinstalls it as `u0` —
1170+
silently defeats ForwardDiff/Tracker/Measurements promotion of `u0`.
1171+
1172+
Replaces the previous `safe_float` layer by subsuming it: `promote_type(Int, Float64)
1173+
== Float64`, so plain problems still get `Vector{Float64}`; `promote_type(Int,
1174+
ForwardDiff.Dual)` yields the Dual type.
1175+
"""
1176+
struct PromoteToTunableEltype{F}
1177+
observed::F
1178+
end
1179+
1180+
function (p::PromoteToTunableEltype)(nlsol)
1181+
raw = p.observed(nlsol)
1182+
raw isa AbstractArray || return raw
1183+
isempty(raw) && return raw
1184+
T = promote_type(eltype(raw), _tunable_eltype(parameter_values(nlsol)), Float64)
1185+
T === eltype(raw) ? raw : convert(AbstractArray{T}, raw)
1186+
end
1187+
1188+
_tunable_eltype(p::MTKParameters) = isempty(p.tunable) ? Bool : eltype(p.tunable)
1189+
function _tunable_eltype(p)
1190+
if SciMLStructures.isscimlstructure(p)
1191+
tun = SciMLStructures.canonicalize(SciMLStructures.Tunable(), p)[1]
1192+
return isempty(tun) ? Bool : eltype(tun)
1193+
end
1194+
return Bool
1195+
end
1196+
11621197
"""
11631198
$(TYPEDSIGNATURES)
11641199
@@ -1258,8 +1293,8 @@ function maybe_build_initialization_problem(
12581293
if isempty(solved_unknowns)
12591294
initializeprobmap = Returns(nothing)
12601295
else
1261-
initializeprobmap = u0_constructor safe_float
1262-
getu(initializeprob, solved_unknowns)
1296+
initializeprobmap = u0_constructor PromoteToTunableEltype(
1297+
getu(initializeprob, solved_unknowns))
12631298
end
12641299
else
12651300
initializeprobmap = nothing

lib/ModelingToolkitBase/test/initializationsystem.jl

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -911,7 +911,7 @@ end
911911

912912
@testset "No initialization for variables" begin
913913
@variables x = 1.0
914-
@parameters p = 10.0
914+
@parameters p = 10.0
915915

916916
eqs = [
917917
0 ~ x^2 + 2p * x + 3p
@@ -2013,3 +2013,33 @@ end
20132013
@test !(prob.f.initialization_data.initializeprob isa SCCNonlinearProblem)
20142014
end
20152015
end
2016+
2017+
@testset "Issue #4457" begin
2018+
@parameters m=1.5 d=9.0
2019+
@variables s(t) v(t)
2020+
2021+
eqs = [
2022+
D(s) ~ v
2023+
m * D(v) ~ 1 - d * v
2024+
]
2025+
2026+
sys = mtkcompile(System(eqs, t;
2027+
name = :model,
2028+
initialization_eqs = [s ~ 0, v ~ 0],
2029+
))
2030+
2031+
prob = ODEProblem{true, FullSpecialize}(sys, [], (0.0, 200.0))
2032+
sol = solve(prob, Tsit5(); saveat = 0.1)
2033+
@test SciMLBase.successful_retcode(sol)
2034+
2035+
setter = setp_oop(prob, [sys.m, sys.d])
2036+
2037+
function loss(x)
2038+
p = setter(prob, x)
2039+
newprob = remake(prob; p)
2040+
newsol = solve(newprob, Tsit5(); saveat = 0.1)
2041+
sum(abs2, newsol[sys.s])
2042+
end
2043+
2044+
ForwardDiff.gradient(loss, [3.0, 20.0])
2045+
end

0 commit comments

Comments
 (0)