Skip to content

Commit 5db718d

Browse files
Merge pull request #4469 from SebastianM-C/smc/init_promotion
fix: promote the output array of initialization to the tunable eltype
2 parents 530eb54 + d0904f9 commit 5db718d

2 files changed

Lines changed: 106 additions & 3 deletions

File tree

lib/ModelingToolkitBase/src/systems/problem_utils.jl

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1163,6 +1163,46 @@ end
11631163
safe_float(x) = x
11641164
safe_float(x::AbstractArray) = isempty(x) ? x : float(x)
11651165

1166+
"""
1167+
PromoteToTunableEltype(observed, floatT)
1168+
1169+
Wraps an `initializeprob` observed function so its output array is promoted to an
1170+
eltype compatible with the current tunable parameters. Addresses the case where
1171+
the observed function is generated from fully constant RHS (e.g. `initialization_eqs
1172+
= [s ~ 0]`): the resulting `create_array(Array, nothing, …, 0, 0)` would otherwise
1173+
produce `Vector{Int64}`, which — when downstream `remake` reinstalls it as `u0` —
1174+
silently defeats ForwardDiff/Tracker/Measurements promotion of `u0`.
1175+
1176+
`floatT` is the static floor (the same `floatT` the rest of the construction pipeline
1177+
commits to, derived from the user's varmap). It guarantees integer RHS gets lifted
1178+
to a float without overriding the user's chosen precision (e.g. `Float32`). The
1179+
dynamic tunable eltype is read fresh from `parameter_values(nlsol)` on every call,
1180+
so a later `remake` with `ForwardDiff.Dual` parameters still wins via `promote_type`.
1181+
"""
1182+
struct PromoteToTunableEltype{F, floatT}
1183+
observed::F
1184+
end
1185+
1186+
PromoteToTunableEltype(observed, ::Type{T}) where {T} =
1187+
PromoteToTunableEltype{typeof(observed), T}(observed)
1188+
1189+
function (p::PromoteToTunableEltype{F, floatT})(nlsol) where {F, floatT}
1190+
raw = p.observed(nlsol)
1191+
raw isa AbstractArray || return raw
1192+
isempty(raw) && return raw
1193+
T = promote_type(eltype(raw), _tunable_eltype(parameter_values(nlsol)), floatT)
1194+
T === eltype(raw) ? raw : convert(AbstractArray{T}, raw)
1195+
end
1196+
1197+
_tunable_eltype(p::MTKParameters) = isempty(p.tunable) ? Bool : eltype(p.tunable)
1198+
function _tunable_eltype(p)
1199+
if SciMLStructures.isscimlstructure(p)
1200+
tun = SciMLStructures.canonicalize(SciMLStructures.Tunable(), p)[1]
1201+
return isempty(tun) ? Bool : eltype(tun)
1202+
end
1203+
return Bool
1204+
end
1205+
11661206
"""
11671207
$(TYPEDSIGNATURES)
11681208
@@ -1262,8 +1302,8 @@ function maybe_build_initialization_problem(
12621302
if isempty(solved_unknowns)
12631303
initializeprobmap = nothing
12641304
else
1265-
initializeprobmap = u0_constructor safe_float
1266-
getu(initializeprob, solved_unknowns)
1305+
initializeprobmap = u0_constructor PromoteToTunableEltype(
1306+
getu(initializeprob, solved_unknowns), floatT)
12671307
end
12681308
else
12691309
initializeprobmap = nothing

lib/ModelingToolkitBase/test/initializationsystem.jl

Lines changed: 64 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -910,7 +910,7 @@ end
910910

911911
@testset "No initialization for variables" begin
912912
@variables x = 1.0
913-
@parameters p = 10.0
913+
@parameters p = 10.0
914914

915915
eqs = [
916916
0 ~ x^2 + 2p * x + 3p
@@ -2039,3 +2039,66 @@ if @isdefined(ModelingToolkit)
20392039
@test_nowarn ForwardDiff.gradient(costfn, [1.2])
20402040
end
20412041
end
2042+
2043+
@testset "Output arrays from constant RHS under ForwardDiff" begin
2044+
# Issue #4457
2045+
@parameters m=1.5 d=9.0
2046+
@variables s(t) v(t)
2047+
2048+
eqs = [
2049+
D(s) ~ v
2050+
D(v) ~ (1 - d * v) / m
2051+
]
2052+
2053+
sys = mtkcompile(System(eqs, t;
2054+
name = :model,
2055+
initialization_eqs = [s ~ 0, v ~ 0],
2056+
))
2057+
2058+
prob = ODEProblem(sys, [], (0.0, 200.0))
2059+
sol = solve(prob, Tsit5(); saveat = 0.1)
2060+
@test SciMLBase.successful_retcode(sol)
2061+
2062+
setter = setp_oop(prob, [sys.m, sys.d])
2063+
2064+
function loss1(x)
2065+
p = setter(prob, x)
2066+
newprob = remake(prob; p)
2067+
newsol = solve(newprob, Tsit5(); saveat = 0.1)
2068+
sum(abs2, newsol[sys.s])
2069+
end
2070+
2071+
@test_nowarn ForwardDiff.gradient(loss1, [3.0, 20.0])
2072+
2073+
# Issue 3924
2074+
function create_sys()
2075+
@parameters p1 = 0.5 [tunable = true] (p23[1:2] = [1, 3.0]) [tunable = true] p4 = 3 * p1 [tunable = false] y0 = 1.2 [tunable = true]
2076+
@variables x(t) = 2p1 y(t) = y0 z(t) = x + y
2077+
2078+
eqs = [D(x) ~ p1 * x - p23[1] * x * y
2079+
D(y) ~ -p23[2] * y + p4 * x * y
2080+
z ~ x + y]
2081+
2082+
mtkcompile(System(eqs, t, name=:sys))
2083+
end
2084+
2085+
sys = create_sys()
2086+
2087+
sub_sys = subset_tunables(sys, [sys.p23])
2088+
2089+
prob = ODEProblem(sub_sys, [], (0, 1.))
2090+
2091+
setter = setsym_oop(prob, Symbolics.scalarize(sys.p23));
2092+
2093+
function loss2(x, ps)
2094+
setter, prob = ps
2095+
u0, p = setter(prob, x)
2096+
new_prob = remake(prob; u0, p)
2097+
sol = solve(new_prob, Tsit5())
2098+
sum(sol)
2099+
end
2100+
2101+
@test_nowarn loss2([1., 2], (setter, prob))
2102+
2103+
@test_nowarn ForwardDiff.gradient(Base.Fix2(loss2, (setter, prob)), [1, 2.])
2104+
end

0 commit comments

Comments
 (0)