From 53dcd46f9e97011dd54b596c052691a80a8eb09c Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 19 May 2026 18:16:58 +0530 Subject: [PATCH 1/5] refactor: add `wrap_as_any` flag to `concrete_getu` --- lib/ModelingToolkitBase/src/systems/problem_utils.jl | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/lib/ModelingToolkitBase/src/systems/problem_utils.jl b/lib/ModelingToolkitBase/src/systems/problem_utils.jl index d6c9400fd0..9e7f99a910 100644 --- a/lib/ModelingToolkitBase/src/systems/problem_utils.jl +++ b/lib/ModelingToolkitBase/src/systems/problem_utils.jl @@ -672,7 +672,7 @@ Note that the getter ONLY works for problem-like objects, since it generates an function. It does NOT work for solutions. """ Base.@nospecializeinfer function concrete_getu( - indp, syms; + indp, syms; wrap_as_any = false, eval_expression, eval_module, force_time_independent = false, kwargs... ) @nospecialize @@ -680,6 +680,9 @@ Base.@nospecializeinfer function concrete_getu( indp, syms; wrap_delays = false, eval_expression, eval_module, force_time_independent, kwargs... ) + if wrap_as_any + return ObservedWrapper{is_time_dependent(indp) && !force_time_independent, Any}(obsfn) + end return ObservedWrapper{is_time_dependent(indp) && !force_time_independent}(obsfn) end From bbd1c3b22b9dbb189fc4896ffdc05bd12c03e943 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 19 May 2026 18:16:19 +0530 Subject: [PATCH 2/5] refactor: improve `isinitial` --- .../src/systems/abstractsystem.jl | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/lib/ModelingToolkitBase/src/systems/abstractsystem.jl b/lib/ModelingToolkitBase/src/systems/abstractsystem.jl index 491cf1e633..aa7e6928f5 100644 --- a/lib/ModelingToolkitBase/src/systems/abstractsystem.jl +++ b/lib/ModelingToolkitBase/src/systems/abstractsystem.jl @@ -626,12 +626,17 @@ end """ Returns true if the parameter `p` is of the form `Initial(x)`. """ +function isinitial(p::SymbolicT) + p, _ = split_indexed_var(p) + Moshi.Match.@match p begin + BSImpl.Term(; f) => f isa Initial + _ => false + end +end function isinitial(p) - p = unwrap(p) - return iscall(p) && ( - operation(p) isa Initial || - operation(p) === getindex && isinitial(arguments(p)[1]) - ) + up = unwrap(p) + up === p && return false + return isinitial(up) end """ From d060e81440653e421642d854e12dc7bc929cfc89 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 15 May 2026 19:14:43 +0530 Subject: [PATCH 3/5] fix: avoid generating massive functions in `get_mtkparameters_reconstructor` This used to be the lion's share of the compile time in large models. The infrastructure added here might be useful in other places too. --- .../src/systems/problem_utils.jl | 299 ++++++++++++++---- 1 file changed, 241 insertions(+), 58 deletions(-) diff --git a/lib/ModelingToolkitBase/src/systems/problem_utils.jl b/lib/ModelingToolkitBase/src/systems/problem_utils.jl index 9e7f99a910..b05cbea2dd 100644 --- a/lib/ModelingToolkitBase/src/systems/problem_utils.jl +++ b/lib/ModelingToolkitBase/src/systems/problem_utils.jl @@ -721,6 +721,223 @@ function (pca::PConstructorApplicator)(x::AbstractArray{<:AbstractArray}) return pca.p_constructor(pca.(x)) end +""" + $TYPEDEF + +Callable struct designed for use by `MTKParametersReconstructor`. Uses a fixed set of templates to +act as a very dynamic (and limited) observed function returning an array. See `__apply_copy_template` +for the supported templates. +""" +struct CopyParamsByTemplate{IsRoot, T, N} + """ + List of templates. + """ + template::T # TODO: This field is parametric because I thought we might want to specialize it in some cases. + """ + Size of the returned buffer. + """ + size::NTuple{N, Int} +end + +function CopyParamsByTemplate{IR}(temp::T, size::NTuple{N, Int}) where {IR, T, N} + return CopyParamsByTemplate{IR, T, N}(temp, size) +end + +function __apply_copy_template(valp, template) + p = parameter_values(valp) + u = state_values(valp) + if template isa ParameterIndex{SciMLStructures.Tunable, UnitRange{Int}} + if p isa MTKParameters + return p.tunable[template.idx] + else + return p[template.idx] + end + elseif template isa ParameterIndex{SciMLStructures.Initials, UnitRange{Int}} + return p.initials[template.idx] + elseif template isa ParameterIndex{SciMLStructures.Discrete, Tuple{Int, UnitRange{Int}}} + return p.discrete[template.idx[1]][template.idx[2]] + elseif template isa ParameterIndex{SciMLStructures.Constants, Tuple{Int, UnitRange{Int}}} + return p.constant[template.idx[1]][template.idx[2]] + elseif template isa ParameterIndex{Nonnumeric, Tuple{Int, UnitRange{Int}}} + return p.nonnumeric[template.idx[1]][template.idx[2]] + elseif template isa UnitRange{Int} + return u[template] + elseif template isa ObservedWrapper + return template(valp) + elseif template isa CopyParamsByTemplate + return template(valp) + elseif template isa IndepVarTemplate + return current_time(valp) + else + # MethodError because this is a manual dispatch chain + throw(MethodError(__apply_copy_template, (valp, template))) + end +end + +function (cp::CopyParamsByTemplate{IsRoot})(src) where {IsRoot} + if IsRoot + reshape(mapreduce(Base.Fix1(__apply_copy_template, src), vcat, cp.template), cp.size) + else + buffers = map(Base.Fix1(__apply_copy_template, src), cp.template) + if cp.template isa Tuple + buffers = collect(buffers) + end + reshape(buffers, cp.size) + end +end + +struct IndepVarTemplate end +const IV_TEMPLATE = IndepVarTemplate() + +Base.@nospecializeinfer function __specialize_templates(template::Vector{Any}, elem_types::Set{DataType}) + if length(template) <= 4 + return Tuple(template) + elseif length(elem_types) <= 4 + return Vector{Union{collect(elem_types)...}}(template) + else + return template + end +end + +function CopyParamsByTemplate(srcsys::AbstractSystem, syms::AbstractArray{SymbolicT}; kws...) + template = [] + elem_types = Set{DataType}() + iv = get_iv(srcsys) + for sym in syms + if iv isa SymbolicT && isequal(iv, sym) + push!(template, IV_TEMPLATE) + push!(elem_types, IndepVarTemplate) + continue + end + symidx = parameter_index(srcsys, sym) + if symidx === nothing + symidx = variable_index(srcsys, sym) + if symidx === nothing + if isempty(template) + push!(elem_types, Vector{SymbolicT}) + push!(template, SymbolicT[sym]) + continue + end + prev = template[end] + if prev isa Vector{SymbolicT} + push!(prev, sym) + else + push!(elem_types, Vector{SymbolicT}) + push!(template, SymbolicT[sym]) + end + continue + end + if isempty(template) + push!(elem_types, UnitRange{Int}) + push!(template, symidx:symidx) + continue + end + prev = template[end] + if prev isa UnitRange{Int} && last(prev) + 1 == symidx + template[end] = first(prev):symidx + else + push!(elem_types, UnitRange{Int}) + push!(template, symidx:symidx) + end + continue + elseif symidx isa Int + symidx = ParameterIndex(SciMLStructures.Tunable(), symidx) + end + portion = symidx.portion + _bufidx = symidx.idx + bufidx::UnitRange{Int} = if _bufidx isa AbstractVector{Int} + @assert isequal(vec(_bufidx), first(_bufidx):last(_bufidx)) + subidx = nothing + first(_bufidx):last(_bufidx) + elseif _bufidx isa Int + subidx = nothing + _bufidx:_bufidx + elseif _bufidx isa NTuple{2, Int} + subidx = _bufidx[1] + _bufidx[2]:_bufidx[2] + else + # Will error due to the typeassert on `bufidx` + nothing + end + if isempty(template) + pidx = if subidx === nothing + ParameterIndex(symidx.portion, bufidx) + else + ParameterIndex(symidx.portion, (subidx, bufidx)) + end + push!(template, pidx) + push!(elem_types, typeof(pidx)) + continue + end + prev = template[end] + if prev isa ParameterIndex && prev.portion === symidx.portion && (subidx === nothing && last(prev.idx) + 1 == first(bufidx) || subidx == prev.idx[1] && last(prev.idx[2]) + 1 == first(bufidx)) + if subidx === nothing + template[end] = ParameterIndex(prev.portion, first(prev.idx):last(bufidx)) + else + template[end] = ParameterIndex(prev.portion, (subidx, first(prev.idx[2]):last(bufidx))) + end + elseif subidx === nothing + push!(template, ParameterIndex(symidx.portion, bufidx)) + push!(elem_types, typeof(template[end])) + else + push!(template, ParameterIndex(symidx.portion, (subidx, bufidx))) + push!(elem_types, typeof(template[end])) + end + end + + for i in eachindex(template) + if template[i] isa Vector{SymbolicT} + template[i] = concrete_getu(srcsys, template[i]; wrap_as_any = true, kws...) + delete!(elem_types, Vector{SymbolicT}) + push!(elem_types, typeof(template[i])) + end + end + + return CopyParamsByTemplate{true}(__specialize_templates(template, elem_types), size(syms)) +end + +function CopyParamsByTemplate(srcsys::AbstractSystem, syms::AbstractArray; kws...) + template = [] + elem_types = Set{DataType}() + for sym in syms + push!(template, CopyParamsByTemplate(srcsys, sym; kws...)) + push!(elem_types, typeof(template[end])) + end + return CopyParamsByTemplate{false}(__specialize_templates(template, elem_types), size(syms)) +end + +struct MTKParametersReconstructor{T, I, D, C, N} + tunables_fn::T + initials_fn::I + discretes_fn::D + consts_fn::C + nonnumerics_fn::N + diffcache_buffer_idx::Int +end + +# TODO: make this infer when the nonnumerics are non-trivial +function (recon::MTKParametersReconstructor)(src, dst) + src_ps = parameter_values(src) + dst_ps = parameter_values(dst) + oldcache = dst_ps.caches + # I don't know why but this makes it infer properly + if recon.tunables_fn isa ComposedFunction + tunablevals = recon.tunables_fn.outer(recon.tunables_fn.inner(src)) + else + tunablevals = recon.tunables_fn(src) + end + initialvals = recon.initials_fn(src) + nonnumerics = recon.nonnumerics_fn(src)::typeof(dst_ps.nonnumeric) + (; diffcache_buffer_idx) = recon + if !iszero(diffcache_buffer_idx) + @set! nonnumerics[diffcache_buffer_idx] = DiffCacheAllocatorAPIWrapper{ForwardDiff.valtype(eltype(initialvals))}.(nonnumerics[diffcache_buffer_idx]) + end + return MTKParameters( + tunablevals, initialvals, recon.discretes_fn(src), + recon.consts_fn(src), nonnumerics, oldcache isa Tuple{} ? () : copy.(oldcache) + ) +end + """ $(TYPEDSIGNATURES) @@ -735,10 +952,10 @@ takes a value provider of `srcsys` and a value provider of `dstsys` and returns unwrapped. - `p_constructor`: The `p_constructor` argument to `process_SciMLProblem`. """ -function get_mtkparameters_reconstructor( +function MTKParametersReconstructor( srcsys::AbstractSystem, dstsys::AbstractSystem; initials = false, unwrap_initials = false, p_constructor = identity, - eval_expression = false, eval_module = @__MODULE__, force_time_independent = false, + force_time_independent = false, kwargs... ) _p_constructor = p_constructor @@ -757,32 +974,27 @@ function get_mtkparameters_reconstructor( tunable_getter = if isempty(tunable_syms) Returns(SVector{0, Float64}()) else - p_constructor ∘ concrete_getu( - srcsys, tunable_syms; eval_expression, eval_module, - force_time_independent, kwargs... - ) + p_constructor ∘ CopyParamsByTemplate(srcsys, tunable_syms; kwargs...) end initials_getter = if initials && !isempty(syms[2]) - initsyms = Vector{Any}(syms[2]) - allsyms = Set(variable_symbols(srcsys)) + initsyms = syms[2]::Vector{SymbolicT} + allsyms = Set{SymbolicT}(variable_symbols(srcsys)) if unwrap_initials for i in eachindex(initsyms) sym = initsyms[i] - innersym = if operation(sym) === getindex - sym, idxs... = arguments(sym) - only(arguments(sym))[idxs...] + arr, isarr = split_indexed_var(sym) + innersym = if isarr + sidx = get_stable_index(sym) + first(arguments(arr))[sidx] else - only(arguments(sym)) + first(arguments(arr)) end if innersym in allsyms initsyms[i] = innersym end end end - p_constructor ∘ concrete_getu( - srcsys, initsyms; eval_expression, eval_module, - force_time_independent, kwargs... - ) + p_constructor ∘ CopyParamsByTemplate(srcsys, initsyms; kwargs...) else Returns(SVector{0, Float64}()) end @@ -795,29 +1007,25 @@ function get_mtkparameters_reconstructor( p_constructor(map(x -> x.length, bufsizes)) end ) + # discretes need to be blocked arrays # the `getu` returns a tuple of arrays corresponding to `p.discretes` # `Base.Fix1(...)` applies `p_constructor` to each of the arrays in the tuple # `Base.Fix2(...)` does `BlockedArray.(tuple_of_arrs, blockarrsizes)` returning a # tuple of `BlockedArray`s Base.Fix2(Broadcast.BroadcastFunction(BlockedArray), blockarrsizes) ∘ - Base.Fix1(broadcast, p_constructor) ∘ + Base.Fix1(broadcast, p_constructor) ∘ Tuple ∘ # This `broadcast.(collect, ...)` avoids `ReshapedArray`/`SubArray`s from # appearing in the result. - concrete_getu( - srcsys, Tuple(broadcast.(collect, syms[3])); - eval_expression, eval_module, force_time_independent, kwargs... - ) + CopyParamsByTemplate(srcsys, broadcast.(collect, syms[3]); kwargs...) end - const_getter = if syms[4] == () + const_getter = if isempty(syms[4]) Returns(()) else - Base.Fix1(broadcast, p_constructor) ∘ concrete_getu( - srcsys, Tuple(syms[4]); - eval_expression, eval_module, force_time_independent, kwargs... - ) + Base.Fix1(broadcast, p_constructor) ∘ Tuple ∘ CopyParamsByTemplate(srcsys, syms[4]; kwargs...) end - nonnumeric_getter = if syms[5] == () + diffcache_buffer_idx = 0 + nonnumeric_getter = if isempty(syms[5]) Returns(()) else ic = get_index_cache(dstsys) @@ -828,44 +1036,19 @@ function get_mtkparameters_reconstructor( ) diffcache_params = SU.getmetadata(dstsys, DiffCacheParams, Dict{SymbolicT, Int}())::Dict{SymbolicT, Int} - diffcache_buffer_idx = 0 if !isempty(diffcache_params) representative = first(keys(diffcache_params)) diffcache_buffer_idx, _ = ic.nonnumeric_idx[representative] @set! buftypes[diffcache_buffer_idx] = identity + for (i, sym) in enumerate(syms[5][diffcache_buffer_idx]) + end end # nonnumerics retain the assigned buffer type without narrowing Base.Fix1(broadcast, _p_constructor) ∘ - Base.Fix1(Broadcast.BroadcastFunction(call), buftypes) ∘ - concrete_getu( - srcsys, Tuple(syms[5]); - eval_expression, eval_module, force_time_independent, kwargs... - ) - end - getters = ( - tunable_getter, initials_getter, discs_getter, const_getter, nonnumeric_getter, - ) - getter = let getters = getters, diffcache_buffer_idx = diffcache_buffer_idx - function _getter(valp, initprob) - oldcache = parameter_values(initprob).caches - tunablevals = getters[1](valp) - initialvals = getters[2](valp) - nonnumerics = getters[5](valp) - if !iszero(diffcache_buffer_idx) - @set! nonnumerics[diffcache_buffer_idx] = DiffCacheAllocatorAPIWrapper{ForwardDiff.valtype(eltype(initialvals))}.(nonnumerics[diffcache_buffer_idx]) - end - return promote_with_nothing( - promote_type_with_nothing(eltype(tunablevals), initialvals), - MTKParameters( - tunablevals, initialvals, getters[3](valp), - getters[4](valp), nonnumerics, oldcache isa Tuple{} ? () : - copy.(oldcache) - ) - ) - end + Base.Fix1(Broadcast.BroadcastFunction(call), buftypes) ∘ Tuple ∘ CopyParamsByTemplate(srcsys, syms[5]; kwargs...) end - return getter + return MTKParametersReconstructor(tunable_getter, initials_getter, discs_getter, const_getter, nonnumeric_getter, diffcache_buffer_idx) end function call(f, args...) @@ -893,7 +1076,7 @@ function ReconstructInitializeprob( kwargs... ) if is_split(dstsys) - pgetter = get_mtkparameters_reconstructor( + pgetter = MTKParametersReconstructor( srcsys, dstsys; p_constructor, eval_expression, eval_module, force_time_independent = is_steadystateprob, kwargs... ) @@ -973,7 +1156,7 @@ function construct_initializeprobpmap( ) @assert is_initializesystem(initsys) if is_split(sys) - return let getter = get_mtkparameters_reconstructor( + return let getter = MTKParametersReconstructor( initsys, sys; initials = true, unwrap_initials = true, p_constructor, eval_expression, eval_module, kwargs... ) From 53a488f6c40afcee27458b8fdf2c49f644013540 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 20 May 2026 00:03:27 +0530 Subject: [PATCH 4/5] refactor: improve `get_updated_symbolic_problem` Enables supporting ReverseDiff AD --- .../src/systems/nonlinear/initializesystem.jl | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/lib/ModelingToolkitBase/src/systems/nonlinear/initializesystem.jl b/lib/ModelingToolkitBase/src/systems/nonlinear/initializesystem.jl index 2c86663f63..292a3a6634 100644 --- a/lib/ModelingToolkitBase/src/systems/nonlinear/initializesystem.jl +++ b/lib/ModelingToolkitBase/src/systems/nonlinear/initializesystem.jl @@ -812,14 +812,9 @@ function DiffEqBase.get_updated_symbolic_problem( end u0 = DiffEqBase.promote_u0(u0, buffer, t0) + u0 = ArrayInterface.restructure(u0, meta.get_updated_u0(prob, initdata.initializeprob)) - if ArrayInterface.ismutable(u0) - T = typeof(u0) - else - T = StaticArrays.similar_type(u0) - end - - return remake(prob; u0 = T(meta.get_updated_u0(prob, initdata.initializeprob)), p) + return remake(prob; u0, p) end """ From 1f026bdb459726c2e7405ed12b8a46f7fd7f18be Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 20 May 2026 00:02:47 +0530 Subject: [PATCH 5/5] feat: add MTKTrackerExt --- lib/ModelingToolkitBase/Project.toml | 2 ++ lib/ModelingToolkitBase/ext/MTKTrackerExt.jl | 14 ++++++++++++++ 2 files changed, 16 insertions(+) create mode 100644 lib/ModelingToolkitBase/ext/MTKTrackerExt.jl diff --git a/lib/ModelingToolkitBase/Project.toml b/lib/ModelingToolkitBase/Project.toml index 5f6e9eb742..6781128148 100644 --- a/lib/ModelingToolkitBase/Project.toml +++ b/lib/ModelingToolkitBase/Project.toml @@ -73,6 +73,7 @@ LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800" Latexify = "23fbe1c1-3f47-55db-b15f-69d7ec21a316" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" Pyomo = "0e8e1daf-01b5-4eba-a626-3897743a3816" +Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" [extensions] MTKBifurcationKitExt = "BifurcationKit" @@ -86,6 +87,7 @@ MTKLabelledArraysExt = "LabelledArrays" MTKLatexifyExt = "Latexify" MTKMooncakeExt = "Mooncake" MTKPyomoDynamicOptExt = "Pyomo" +MTKTrackerExt = "Tracker" [compat] ADTypes = "1.14.0" diff --git a/lib/ModelingToolkitBase/ext/MTKTrackerExt.jl b/lib/ModelingToolkitBase/ext/MTKTrackerExt.jl new file mode 100644 index 0000000000..fd97a518df --- /dev/null +++ b/lib/ModelingToolkitBase/ext/MTKTrackerExt.jl @@ -0,0 +1,14 @@ +module MTKTrackerExt + +import ModelingToolkitBase as MTKBase +import Tracker + +function MTKBase.promote_type_with_nothing(::Type{Tracker.TrackedReal{T}}, x::Tracker.TrackedArray{T}) where {T} + return Tracker.TrackedReal{T} +end + +function MTKBase.promote_with_nothing(::Type{Tracker.TrackedReal{T}}, x::Tracker.TrackedArray{T}) where {T} + return x +end + +end