Skip to content

Commit 7f84f75

Browse files
fix: avoid generating massive functions in get_mtkparameters_reconstructor
This used to be the lion's share of the compile time in large models. The infrastructure added here might be useful in other places too.
1 parent 46f2ef2 commit 7f84f75

1 file changed

Lines changed: 219 additions & 58 deletions

File tree

lib/ModelingToolkitBase/src/systems/problem_utils.jl

Lines changed: 219 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -718,6 +718,201 @@ function (pca::PConstructorApplicator)(x::AbstractArray{<:AbstractArray})
718718
return pca.p_constructor(pca.(x))
719719
end
720720

721+
"""
722+
$TYPEDEF
723+
724+
Callable struct designed for use by `MTKParametersReconstructor`. Uses a fixed set of templates to
725+
act as a very dynamic (and limited) observed function returning an array. See `__apply_copy_template`
726+
for the supported templates.
727+
"""
728+
struct CopyParamsByTemplate{T, N}
729+
"""
730+
Whether this is a "root" object, in that it directly contains instantiable templates instead
731+
of recursing into other `CopyParamsByTemplate`.
732+
"""
733+
isroot::Bool
734+
"""
735+
List of templates.
736+
"""
737+
template::T # TODO: This field is parametric because I thought we might want to specialize it in some cases.
738+
"""
739+
Size of the returned buffer.
740+
"""
741+
size::NTuple{N, Int}
742+
end
743+
744+
function __apply_copy_template(valp, @nospecialize(template))
745+
p = parameter_values(valp)
746+
u = state_values(valp)
747+
if template isa ParameterIndex{SciMLStructures.Tunable, UnitRange{Int}}
748+
return p.tunable[template.idx]
749+
elseif template isa ParameterIndex{SciMLStructures.Initials, UnitRange{Int}}
750+
return p.initials[template.idx]
751+
elseif template isa ParameterIndex{SciMLStructures.Discrete, Tuple{Int, UnitRange{Int}}}
752+
return p.discrete[template.idx[1]][template.idx[2]]
753+
elseif template isa ParameterIndex{SciMLStructures.Constants, Tuple{Int, UnitRange{Int}}}
754+
return p.constant[template.idx[1]][template.idx[2]]
755+
elseif template isa ParameterIndex{Nonnumeric, Tuple{Int, UnitRange{Int}}}
756+
return p.nonnumeric[template.idx[1]][template.idx[2]]
757+
elseif template isa UnitRange{Int}
758+
return u[template]
759+
elseif template isa ObservedWrapper
760+
return template(valp)
761+
elseif template isa CopyParamsByTemplate
762+
return template(valp)
763+
else
764+
# MethodError because this is a manual dispatch chain
765+
throw(MethodError(__apply_copy_template, (valp, template)))
766+
end
767+
end
768+
769+
function (cp::CopyParamsByTemplate)(src)
770+
if cp.isroot
771+
reshape(mapreduce(Base.Fix1(__apply_copy_template, src), vcat, cp.template), cp.size)
772+
else
773+
reshape(map(Base.Fix1(__apply_copy_template, src), cp.template), cp.size)
774+
end
775+
end
776+
777+
function CopyParamsByTemplate(srcsys::AbstractSystem, syms::AbstractArray{SymbolicT}; kws...)
778+
template = []
779+
for sym in syms
780+
symidx = parameter_index(srcsys, sym)
781+
if symidx === nothing
782+
symidx = variable_index(srcsys, sym)
783+
if symidx === nothing
784+
if isempty(template)
785+
push!(template, SymbolicT[sym])
786+
continue
787+
end
788+
prev = template[end]
789+
if prev isa Vector{SymbolicT}
790+
push!(prev, sym)
791+
else
792+
push!(template, SymbolicT[sym])
793+
end
794+
continue
795+
end
796+
if isempty(template)
797+
push!(template, symidx:symidx)
798+
continue
799+
end
800+
prev = template[end]
801+
if prev isa UnitRange{Int} && last(prev) + 1 == symidx
802+
template[end] = first(prev):symidx
803+
else
804+
push!(template, symidx:symidx)
805+
end
806+
continue
807+
end
808+
portion = symidx.portion
809+
_bufidx = symidx.idx
810+
if portion isa SciMLStructures.Tunable || portion isa SciMLStructures.Initials
811+
812+
end
813+
bufidx::UnitRange{Int} = if _bufidx isa AbstractVector{Int}
814+
@assert isequal(vec(_bufidx), first(_bufidx):last(_bufidx))
815+
subidx = nothing
816+
first(_bufidx):last(_bufidx)
817+
elseif _bufidx isa Int
818+
subidx = nothing
819+
_bufidx:_bufidx
820+
elseif _bufidx isa NTuple{2, Int}
821+
subidx = _bufidx[1]
822+
_bufidx[2]:_bufidx[2]
823+
else
824+
# Will error due to the typeassert on `bufidx`
825+
nothing
826+
end
827+
if isempty(template)
828+
if subidx === nothing
829+
push!(template, ParameterIndex(symidx.portion, bufidx))
830+
else
831+
push!(template, ParameterIndex(symidx.portion, (subidx, bufidx)))
832+
end
833+
continue
834+
end
835+
prev = template[end]
836+
if prev isa ParameterIndex && prev.portion === symidx.portion && (subidx === nothing && last(prev.idx) + 1 == first(bufidx) || subidx == prev.idx[1] && last(prev.idx[2]) + 1 == first(bufidx))
837+
if subidx === nothing
838+
template[end] = ParameterIndex(prev.portion, first(prev.idx):last(bufidx))
839+
else
840+
template[end] = ParameterIndex(prev.portion, (subidx, first(prev.idx[2]):last(bufidx)))
841+
end
842+
elseif subidx === nothing
843+
push!(template, ParameterIndex(symidx.portion, bufidx))
844+
else
845+
push!(template, ParameterIndex(symidx.portion, (subidx, bufidx)))
846+
end
847+
end
848+
849+
for i in eachindex(template)
850+
if template[i] isa Vector{SymbolicT}
851+
template[i] = concrete_getu(srcsys, template[i]; kws...)
852+
end
853+
end
854+
return CopyParamsByTemplate(true, template, size(syms))
855+
end
856+
857+
function CopyParamsByTemplate(srcsys::AbstractSystem, syms::AbstractArray; kws...)
858+
return CopyParamsByTemplate(false, [CopyParamsByTemplate(srcsys, sym; kws...) for sym in syms], size(syms))
859+
end
860+
861+
struct MTKParametersReconstructor{T, I, D, C, N}
862+
tunables_fn::T
863+
initials_fn::I
864+
discretes_fn::D
865+
consts_fn::C
866+
nonnumerics_fn::N
867+
diffcache_buffer_idx::Int
868+
end
869+
870+
function (recon::MTKParametersReconstructor)(src, dst)
871+
src_ps = parameter_values(src)
872+
dst_ps = parameter_values(dst)
873+
oldcache = dst_ps.caches
874+
# These type-assertions allow `remake` to infer even though `CopyParamsByTemplate` doesn't
875+
# TODO: Find a solution to type-assert constants/discretes/nonnumerics
876+
# TODO: Maybe `CopyParamsByTemplate` could use a specializing tuple-form when `FullSpecialize`?
877+
if dst_ps.tunable isa SVector{0}
878+
tunablevals = recon.tunables_fn(src)
879+
else
880+
tunable_elT = promote_type(eltype(dst_ps.tunable), eltype(src_ps.tunable))
881+
if ArrayInterface.ismutable(dst_ps.tunable)
882+
tunable_T = Base.promote_op(similar, typeof(dst_ps.tunable), Type{tunable_elT})
883+
tunablevals = recon.tunables_fn(src)::tunable_T
884+
else
885+
tunable_T = StaticArraysCore.similar_type(typeof(dst_ps.tunable), tunable_elT)
886+
tunablevals = recon.tunables_fn(src)::tunable_T
887+
end
888+
end
889+
if dst_ps.initials isa SVector{0}
890+
initialvals = recon.initials_fn(src)
891+
else
892+
initial_elT = promote_type(eltype(dst_ps.initials), eltype(src_ps.initials))
893+
if ArrayInterface.ismutable(dst_ps.initials)
894+
initial_T = Base.promote_op(similar, typeof(dst_ps.initials), Type{initial_elT})
895+
initialvals = recon.initials_fn(src)::initial_T
896+
else
897+
initial_T = StaticArraysCore.similar_type(typeof(dst_ps.initials), initial_elT)
898+
initialvals = recon.initials_fn(src)::initial_T
899+
end
900+
end
901+
902+
nonnumerics = recon.nonnumerics_fn(src)
903+
(; diffcache_buffer_idx) = recon
904+
if !iszero(diffcache_buffer_idx)
905+
@set! nonnumerics[diffcache_buffer_idx] = DiffCacheAllocatorAPIWrapper{ForwardDiff.valtype(eltype(initialvals))}.(nonnumerics[diffcache_buffer_idx])
906+
end
907+
return promote_with_nothing(
908+
promote_type_with_nothing(eltype(tunablevals), initialvals),
909+
MTKParameters(
910+
tunablevals, initialvals, recon.discretes_fn(src),
911+
recon.consts_fn(src), nonnumerics, oldcache isa Tuple{} ? () : copy.(oldcache)
912+
)
913+
)
914+
end
915+
721916
"""
722917
$(TYPEDSIGNATURES)
723918
@@ -732,10 +927,10 @@ takes a value provider of `srcsys` and a value provider of `dstsys` and returns
732927
unwrapped.
733928
- `p_constructor`: The `p_constructor` argument to `process_SciMLProblem`.
734929
"""
735-
function get_mtkparameters_reconstructor(
930+
function MTKParametersReconstructor(
736931
srcsys::AbstractSystem, dstsys::AbstractSystem;
737932
initials = false, unwrap_initials = false, p_constructor = identity,
738-
eval_expression = false, eval_module = @__MODULE__, force_time_independent = false,
933+
force_time_independent = false,
739934
kwargs...
740935
)
741936
_p_constructor = p_constructor
@@ -754,32 +949,27 @@ function get_mtkparameters_reconstructor(
754949
tunable_getter = if isempty(tunable_syms)
755950
Returns(SVector{0, Float64}())
756951
else
757-
p_constructor concrete_getu(
758-
srcsys, tunable_syms; eval_expression, eval_module,
759-
force_time_independent, kwargs...
760-
)
952+
p_constructor CopyParamsByTemplate(srcsys, tunable_syms; kwargs...)
761953
end
762954
initials_getter = if initials && !isempty(syms[2])
763-
initsyms = Vector{Any}(syms[2])
764-
allsyms = Set(variable_symbols(srcsys))
955+
initsyms = syms[2]::Vector{SymbolicT}
956+
allsyms = Set{SymbolicT}(variable_symbols(srcsys))
765957
if unwrap_initials
766958
for i in eachindex(initsyms)
767959
sym = initsyms[i]
768-
innersym = if operation(sym) === getindex
769-
sym, idxs... = arguments(sym)
770-
only(arguments(sym))[idxs...]
960+
arr, isarr = split_indexed_var(sym)
961+
innersym = if isarr
962+
sidx = get_stable_index(sym)
963+
first(arguments(arr))[sidx]
771964
else
772-
only(arguments(sym))
965+
first(arguments(arr))
773966
end
774967
if innersym in allsyms
775968
initsyms[i] = innersym
776969
end
777970
end
778971
end
779-
p_constructor concrete_getu(
780-
srcsys, initsyms; eval_expression, eval_module,
781-
force_time_independent, kwargs...
782-
)
972+
p_constructor CopyParamsByTemplate(srcsys, initsyms; kwargs...)
783973
else
784974
Returns(SVector{0, Float64}())
785975
end
@@ -792,29 +982,25 @@ function get_mtkparameters_reconstructor(
792982
p_constructor(map(x -> x.length, bufsizes))
793983
end
794984
)
985+
795986
# discretes need to be blocked arrays
796987
# the `getu` returns a tuple of arrays corresponding to `p.discretes`
797988
# `Base.Fix1(...)` applies `p_constructor` to each of the arrays in the tuple
798989
# `Base.Fix2(...)` does `BlockedArray.(tuple_of_arrs, blockarrsizes)` returning a
799990
# tuple of `BlockedArray`s
800991
Base.Fix2(Broadcast.BroadcastFunction(BlockedArray), blockarrsizes)
801-
Base.Fix1(broadcast, p_constructor)
992+
Base.Fix1(broadcast, p_constructor) Tuple
802993
# This `broadcast.(collect, ...)` avoids `ReshapedArray`/`SubArray`s from
803994
# appearing in the result.
804-
concrete_getu(
805-
srcsys, Tuple(broadcast.(collect, syms[3]));
806-
eval_expression, eval_module, force_time_independent, kwargs...
807-
)
995+
CopyParamsByTemplate(srcsys, broadcast.(collect, syms[3]); kwargs...)
808996
end
809-
const_getter = if syms[4] == ()
997+
const_getter = if isempty(syms[4])
810998
Returns(())
811999
else
812-
Base.Fix1(broadcast, p_constructor) concrete_getu(
813-
srcsys, Tuple(syms[4]);
814-
eval_expression, eval_module, force_time_independent, kwargs...
815-
)
1000+
Base.Fix1(broadcast, p_constructor) Tuple CopyParamsByTemplate(srcsys, syms[4]; kwargs...)
8161001
end
817-
nonnumeric_getter = if syms[5] == ()
1002+
diffcache_buffer_idx = 0
1003+
nonnumeric_getter = if isempty(syms[5])
8181004
Returns(())
8191005
else
8201006
ic = get_index_cache(dstsys)
@@ -825,44 +1011,19 @@ function get_mtkparameters_reconstructor(
8251011
)
8261012

8271013
diffcache_params = SU.getmetadata(dstsys, DiffCacheParams, Dict{SymbolicT, Int}())::Dict{SymbolicT, Int}
828-
diffcache_buffer_idx = 0
8291014
if !isempty(diffcache_params)
8301015
representative = first(keys(diffcache_params))
8311016
diffcache_buffer_idx, _ = ic.nonnumeric_idx[representative]
8321017
@set! buftypes[diffcache_buffer_idx] = identity
1018+
for (i, sym) in enumerate(syms[5][diffcache_buffer_idx])
1019+
end
8331020
end
8341021
# nonnumerics retain the assigned buffer type without narrowing
8351022
Base.Fix1(broadcast, _p_constructor)
836-
Base.Fix1(Broadcast.BroadcastFunction(call), buftypes)
837-
concrete_getu(
838-
srcsys, Tuple(syms[5]);
839-
eval_expression, eval_module, force_time_independent, kwargs...
840-
)
841-
end
842-
getters = (
843-
tunable_getter, initials_getter, discs_getter, const_getter, nonnumeric_getter,
844-
)
845-
getter = let getters = getters, diffcache_buffer_idx = diffcache_buffer_idx
846-
function _getter(valp, initprob)
847-
oldcache = parameter_values(initprob).caches
848-
tunablevals = getters[1](valp)
849-
initialvals = getters[2](valp)
850-
nonnumerics = getters[5](valp)
851-
if !iszero(diffcache_buffer_idx)
852-
@set! nonnumerics[diffcache_buffer_idx] = DiffCacheAllocatorAPIWrapper{ForwardDiff.valtype(eltype(initialvals))}.(nonnumerics[diffcache_buffer_idx])
853-
end
854-
return promote_with_nothing(
855-
promote_type_with_nothing(eltype(tunablevals), initialvals),
856-
MTKParameters(
857-
tunablevals, initialvals, getters[3](valp),
858-
getters[4](valp), nonnumerics, oldcache isa Tuple{} ? () :
859-
copy.(oldcache)
860-
)
861-
)
862-
end
1023+
Base.Fix1(Broadcast.BroadcastFunction(call), buftypes) Tuple CopyParamsByTemplate(srcsys, syms[5]; kwargs...)
8631024
end
8641025

865-
return getter
1026+
return MTKParametersReconstructor(tunable_getter, initials_getter, discs_getter, const_getter, nonnumeric_getter, diffcache_buffer_idx)
8661027
end
8671028

8681029
function call(f, args...)
@@ -890,7 +1051,7 @@ function ReconstructInitializeprob(
8901051
kwargs...
8911052
)
8921053
if is_split(dstsys)
893-
pgetter = get_mtkparameters_reconstructor(
1054+
pgetter = MTKParametersReconstructor(
8941055
srcsys, dstsys; p_constructor, eval_expression, eval_module,
8951056
force_time_independent = is_steadystateprob, kwargs...
8961057
)
@@ -970,7 +1131,7 @@ function construct_initializeprobpmap(
9701131
)
9711132
@assert is_initializesystem(initsys)
9721133
if is_split(sys)
973-
return let getter = get_mtkparameters_reconstructor(
1134+
return let getter = MTKParametersReconstructor(
9741135
initsys, sys; initials = true, unwrap_initials = true, p_constructor,
9751136
eval_expression, eval_module, kwargs...
9761137
)

0 commit comments

Comments
 (0)