Skip to content

Commit 2054fc6

Browse files
Merge pull request #4459 from SciML/as/fix-dual-cache
fix: correctly promote cache buffers in `remake`
2 parents 6382316 + b8ac008 commit 2054fc6

4 files changed

Lines changed: 48 additions & 6 deletions

File tree

lib/ModelingToolkitBase/src/systems/nonlinear/initializesystem.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -670,6 +670,11 @@ function promote_with_nothing(::Type{T}, p::MTKParameters) where {T}
670670
p = SciMLStructures.replace(SciMLStructures.Tunable(), p, tunables)
671671
initials = promote_with_nothing(T, p.initials)
672672
p = SciMLStructures.replace(SciMLStructures.Initials(), p, initials)
673+
for i in eachindex(p.caches)
674+
if eltype(p.caches[i]) <: AbstractFloat
675+
@set! p.caches[i] = promote_with_nothing(T, p.caches[i])
676+
end
677+
end
673678
return p
674679
end
675680

lib/ModelingToolkitBase/src/systems/problem_utils.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -845,15 +845,19 @@ function get_mtkparameters_reconstructor(
845845
getter = let getters = getters, diffcache_buffer_idx = diffcache_buffer_idx
846846
function _getter(valp, initprob)
847847
oldcache = parameter_values(initprob).caches
848+
tunablevals = getters[1](valp)
848849
initialvals = getters[2](valp)
849850
nonnumerics = getters[5](valp)
850851
if !iszero(diffcache_buffer_idx)
851852
@set! nonnumerics[diffcache_buffer_idx] = DiffCacheAllocatorAPIWrapper{eltype(initialvals)}.(nonnumerics[diffcache_buffer_idx])
852853
end
853-
return MTKParameters(
854-
getters[1](valp), initialvals, getters[3](valp),
855-
getters[4](valp), nonnumerics, oldcache isa Tuple{} ? () :
856-
copy.(oldcache)
854+
return promote_with_nothing(
855+
promote_type_with_nothing(eltype(tunablevals), initialvals),
856+
MTKParameters(
857+
tunablevals, initialvals, getters[3](valp),
858+
getters[4](valp), nonnumerics, oldcache isa Tuple{} ? () :
859+
copy.(oldcache)
860+
)
857861
)
858862
end
859863
end

lib/ModelingToolkitBase/test/initializationsystem.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2013,3 +2013,30 @@ end
20132013
@test !(prob.f.initialization_data.initializeprob isa SCCNonlinearProblem)
20142014
end
20152015
end
2016+
2017+
if @isdefined(ModelingToolkit)
2018+
# `SCCNonlinearProblem` uses the cache buffers
2019+
@testset "Cache buffers are correctly promoted during initialization" begin
2020+
@parameters g
2021+
@variables x(t) y(t) [state_priority = 10] λ(t) yˍt(t) xˍt(t) xˍtt(t)
2022+
@mtkcomplete pend = index_reduced_pend()
2023+
g_true = 9.81
2024+
prob = ODEProblem(
2025+
pend, [x => 1, D(y) => 0, g => g_true], (0.0, 0.5);
2026+
guesses ==> 0, y => 1, x => 1], missing_guess_value
2027+
)
2028+
@test !isempty(prob.f.initialization_data.initializeprob.p.caches)
2029+
saveat = range(prob.tspan..., length = 30)
2030+
sol = solve(prob, Rodas5P(); saveat)
2031+
@test SciMLBase.successful_retcode(sol)
2032+
setter = setp_oop(prob, [g])
2033+
2034+
function costfn(theta)
2035+
p = setter(prob, theta)
2036+
newprob = SciMLBase.remake(prob; p)
2037+
sol = solve(newprob, Rodas5P(); saveat = 0.1)
2038+
sum(abs2, sol[x])
2039+
end
2040+
@test_nowarn ForwardDiff.gradient(costfn, [1.2])
2041+
end
2042+
end

lib/ModelingToolkitBase/test/split_parameters.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,13 @@ if @isdefined(ModelingToolkit)
162162
@named d = Step(start_time = 1.0, duration = 10.0, offset = 0.0, height = 1.0) # Disturbance
163163
model_outputs = [model.inertia1.w, model.inertia2.w, model.inertia1.phi, model.inertia2.phi] # This is the state realization we want to control
164164
inputs = [model.torque.tau.u]
165-
op = [model.torque.tau.u => 0.0]
165+
op = [
166+
model.inertia1.w => 1.0
167+
model.inertia2.w => 1.0
168+
model.inertia1.phi => 1.0
169+
model.inertia2.phi => 1.0
170+
model.torque.tau.u => 0.0
171+
]
166172
matrices, ssys = ModelingToolkit.linearize(
167173
wr(model), inputs, model_outputs; op,
168174
guesses = [model.inertia2.flange_a.phi => 0.0, model.inertia1.flange_b.phi => 0.0]
@@ -197,7 +203,7 @@ if @isdefined(ModelingToolkit)
197203
connect(add.output, :u, model.torque.tau)
198204
]
199205
@named closed_loop = System(connections, t, systems = [model, state_feedback, add, d])
200-
S = get_sensitivity(closed_loop, :u)
206+
S = get_sensitivity(closed_loop, :u; op = x_costs)
201207
end
202208

203209
@testset "Indexing MTKParameters with ParameterIndex" begin

0 commit comments

Comments
 (0)