Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 42 additions & 2 deletions lib/ModelingToolkitBase/src/systems/problem_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1163,6 +1163,46 @@ end
safe_float(x) = x
safe_float(x::AbstractArray) = isempty(x) ? x : float(x)

"""
PromoteToTunableEltype(observed, floatT)

Wraps an `initializeprob` observed function so its output array is promoted to an
eltype compatible with the current tunable parameters. Addresses the case where
the observed function is generated from fully constant RHS (e.g. `initialization_eqs
= [s ~ 0]`): the resulting `create_array(Array, nothing, …, 0, 0)` would otherwise
produce `Vector{Int64}`, which — when downstream `remake` reinstalls it as `u0` —
silently defeats ForwardDiff/Tracker/Measurements promotion of `u0`.

`floatT` is the static floor (the same `floatT` the rest of the construction pipeline
commits to, derived from the user's varmap). It guarantees integer RHS gets lifted
to a float without overriding the user's chosen precision (e.g. `Float32`). The
dynamic tunable eltype is read fresh from `parameter_values(nlsol)` on every call,
so a later `remake` with `ForwardDiff.Dual` parameters still wins via `promote_type`.
"""
struct PromoteToTunableEltype{F, floatT}
observed::F
end

PromoteToTunableEltype(observed, ::Type{T}) where {T} =
PromoteToTunableEltype{typeof(observed), T}(observed)

function (p::PromoteToTunableEltype{F, floatT})(nlsol) where {F, floatT}
raw = p.observed(nlsol)
raw isa AbstractArray || return raw
isempty(raw) && return raw
T = promote_type(eltype(raw), _tunable_eltype(parameter_values(nlsol)), floatT)
T === eltype(raw) ? raw : convert(AbstractArray{T}, raw)
end

_tunable_eltype(p::MTKParameters) = isempty(p.tunable) ? Bool : eltype(p.tunable)
function _tunable_eltype(p)
if SciMLStructures.isscimlstructure(p)
tun = SciMLStructures.canonicalize(SciMLStructures.Tunable(), p)[1]
return isempty(tun) ? Bool : eltype(tun)
end
return Bool
end

"""
$(TYPEDSIGNATURES)

Expand Down Expand Up @@ -1262,8 +1302,8 @@ function maybe_build_initialization_problem(
if isempty(solved_unknowns)
initializeprobmap = nothing
else
initializeprobmap = u0_constructor ∘ safe_float ∘
getu(initializeprob, solved_unknowns)
initializeprobmap = u0_constructor ∘ PromoteToTunableEltype(
getu(initializeprob, solved_unknowns), floatT)
end
else
initializeprobmap = nothing
Expand Down
65 changes: 64 additions & 1 deletion lib/ModelingToolkitBase/test/initializationsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -910,7 +910,7 @@ end

@testset "No initialization for variables" begin
@variables x = 1.0
@parameters p = 10.0
@parameters p = 10.0

eqs = [
0 ~ x^2 + 2p * x + 3p
Expand Down Expand Up @@ -2039,3 +2039,66 @@ if @isdefined(ModelingToolkit)
@test_nowarn ForwardDiff.gradient(costfn, [1.2])
end
end

@testset "Output arrays from constant RHS under ForwardDiff" begin
# Issue #4457
@parameters m=1.5 d=9.0
@variables s(t) v(t)

eqs = [
D(s) ~ v
D(v) ~ (1 - d * v) / m
]

sys = mtkcompile(System(eqs, t;
name = :model,
initialization_eqs = [s ~ 0, v ~ 0],
))

prob = ODEProblem(sys, [], (0.0, 200.0))
sol = solve(prob, Tsit5(); saveat = 0.1)
@test SciMLBase.successful_retcode(sol)

setter = setp_oop(prob, [sys.m, sys.d])

function loss1(x)
p = setter(prob, x)
newprob = remake(prob; p)
newsol = solve(newprob, Tsit5(); saveat = 0.1)
sum(abs2, newsol[sys.s])
end

@test_nowarn ForwardDiff.gradient(loss1, [3.0, 20.0])

# Issue 3924
function create_sys()
@parameters p1 = 0.5 [tunable = true] (p23[1:2] = [1, 3.0]) [tunable = true] p4 = 3 * p1 [tunable = false] y0 = 1.2 [tunable = true]
@variables x(t) = 2p1 y(t) = y0 z(t) = x + y

eqs = [D(x) ~ p1 * x - p23[1] * x * y
D(y) ~ -p23[2] * y + p4 * x * y
z ~ x + y]

mtkcompile(System(eqs, t, name=:sys))
end

sys = create_sys()

sub_sys = subset_tunables(sys, [sys.p23])

prob = ODEProblem(sub_sys, [], (0, 1.))

setter = setsym_oop(prob, Symbolics.scalarize(sys.p23));

function loss2(x, ps)
setter, prob = ps
u0, p = setter(prob, x)
new_prob = remake(prob; u0, p)
sol = solve(new_prob, Tsit5())
sum(sol)
end

@test_nowarn loss2([1., 2], (setter, prob))

@test_nowarn ForwardDiff.gradient(Base.Fix2(loss2, (setter, prob)), [1, 2.])
end
Loading