Skip to content

Commit bd2fb43

Browse files
fix: avoid generating massive functions in get_mtkparameters_reconstructor
This used to be the lion's share of the compile time in large models. The infrastructure added here might be useful in other places too.
1 parent 411ae9f commit bd2fb43

1 file changed

Lines changed: 244 additions & 58 deletions

File tree

lib/ModelingToolkitBase/src/systems/problem_utils.jl

Lines changed: 244 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -718,6 +718,226 @@ function (pca::PConstructorApplicator)(x::AbstractArray{<:AbstractArray})
718718
return pca.p_constructor(pca.(x))
719719
end
720720

721+
"""
722+
$TYPEDEF
723+
724+
Callable struct designed for use by `MTKParametersReconstructor`. Uses a fixed set of templates to
725+
act as a very dynamic (and limited) observed function returning an array. See `__apply_copy_template`
726+
for the supported templates.
727+
"""
728+
struct CopyParamsByTemplate{T, N}
    """
    Marks a "root" object. A root holds templates that can be instantiated directly, and
    calling it concatenates their results. A non-root holds further
    `CopyParamsByTemplate` objects and recurses into them.
    """
    isroot::Bool
    """
    The templates evaluated, in order, when the object is called.
    """
    template::T # TODO: This field is parametric because I thought we might want to specialize it in some cases.
    """
    Dimensions the result is reshaped to before it is returned.
    """
    size::NTuple{N, Int}
end
743+
744+
"""
    __apply_copy_template(valp, template)

Evaluate one `template` entry of a `CopyParamsByTemplate` against the value provider
`valp` and return the selected values.

Supported templates:

  - `ParameterIndex` into the tunable or initials portion, with a `UnitRange{Int}` index.
  - `ParameterIndex` into the discrete, constant or nonnumeric portion, with a
    `(buffer, range)` index.
  - `UnitRange{Int}`, indexing `state_values(valp)`.
  - `ObservedWrapper` or a nested `CopyParamsByTemplate`, which are called with `valp`.

`template` is `@nospecialize`d and matched through a manual `isa` chain, so one compiled
method serves every template type. Any other template throws a `MethodError`.
"""
function __apply_copy_template(valp, @nospecialize(template))
    # Both are fetched up front, regardless of which branch ends up being taken.
    ps = parameter_values(valp)
    states = state_values(valp)

    if template isa ParameterIndex{SciMLStructures.Tunable, UnitRange{Int}}
        idxs = template.idx
        # A parameter object that is not `MTKParameters` is indexed directly.
        return ps isa MTKParameters ? ps.tunable[idxs] : ps[idxs]
    end
    if template isa ParameterIndex{SciMLStructures.Initials, UnitRange{Int}}
        return ps.initials[template.idx]
    end
    if template isa ParameterIndex{SciMLStructures.Discrete, Tuple{Int, UnitRange{Int}}}
        bufidx, idxs = template.idx
        return ps.discrete[bufidx][idxs]
    end
    if template isa ParameterIndex{SciMLStructures.Constants, Tuple{Int, UnitRange{Int}}}
        bufidx, idxs = template.idx
        return ps.constant[bufidx][idxs]
    end
    if template isa ParameterIndex{Nonnumeric, Tuple{Int, UnitRange{Int}}}
        bufidx, idxs = template.idx
        return ps.nonnumeric[bufidx][idxs]
    end
    if template isa UnitRange{Int}
        return states[template]
    end
    if template isa ObservedWrapper || template isa CopyParamsByTemplate
        return template(valp)
    end
    # MethodError because this is a manual dispatch chain
    throw(MethodError(__apply_copy_template, (valp, template)))
end
772+
773+
# Evaluate every template of `cp` against the value provider `src` and return the result
# reshaped to `cp.size`.
#
# A root object concatenates the per-template results with `vcat`. A non-root object
# keeps one evaluated entry per template.
function (cp::CopyParamsByTemplate)(src)
    evaluate = Base.Fix1(__apply_copy_template, src)
    buffer = if cp.isroot
        mapreduce(evaluate, vcat, cp.template)
    else
        map(evaluate, cp.template)
    end
    return reshape(buffer, cp.size)
end
780+
781+
# Append `sym`, which is neither a parameter nor a variable of the source system, to
# `template`. Consecutive such symbols share one `Vector{SymbolicT}` entry.
function __push_observed_template!(template::Vector{Any}, sym)
    if !isempty(template)
        prev = template[end]
        if prev isa Vector{SymbolicT}
            push!(prev, sym)
            return template
        end
    end
    push!(template, SymbolicT[sym])
    return template
end

# Append the state index `idx` to `template`, extending the trailing `UnitRange{Int}`
# entry when `idx` directly follows it.
function __push_state_template!(template::Vector{Any}, idx)
    if !isempty(template)
        prev = template[end]
        if prev isa UnitRange{Int} && last(prev) + 1 == idx
            template[end] = first(prev):idx
            return template
        end
    end
    push!(template, idx:idx)
    return template
end

# Append the parameter index `symidx` (a `ParameterIndex`) to `template`, extending the
# trailing entry when it addresses the same portion (and the same sub-buffer, if any) and
# the new range directly follows it.
function __push_parameter_template!(template::Vector{Any}, symidx)
    portion = symidx.portion
    rawidx = symidx.idx
    # `subidx` stays `nothing` for flat buffers; it is the buffer number for `(i, j)` indices.
    subidx = nothing
    bufidx::UnitRange{Int} = if rawidx isa AbstractVector{Int}
        # Only contiguous index vectors can be represented as a range.
        @assert isequal(vec(rawidx), first(rawidx):last(rawidx))
        first(rawidx):last(rawidx)
    elseif rawidx isa Int
        rawidx:rawidx
    elseif rawidx isa NTuple{2, Int}
        subidx = rawidx[1]
        rawidx[2]:rawidx[2]
    else
        # Will error due to the typeassert on `bufidx`
        nothing
    end

    if !isempty(template)
        prev = template[end]
        mergeable = prev isa ParameterIndex && prev.portion === portion && (
            (subidx === nothing && last(prev.idx) + 1 == first(bufidx)) ||
            (subidx == prev.idx[1] && last(prev.idx[2]) + 1 == first(bufidx))
        )
        if mergeable
            template[end] = if subidx === nothing
                ParameterIndex(prev.portion, first(prev.idx):last(bufidx))
            else
                ParameterIndex(prev.portion, (subidx, first(prev.idx[2]):last(bufidx)))
            end
            return template
        end
    end

    if subidx === nothing
        push!(template, ParameterIndex(portion, bufidx))
    else
        push!(template, ParameterIndex(portion, (subidx, bufidx)))
    end
    return template
end

"""
    CopyParamsByTemplate(srcsys::AbstractSystem, syms::AbstractArray{SymbolicT}; kws...)

Build a root `CopyParamsByTemplate` that reads the values of `syms` from a value
provider of `srcsys`.

Each symbol is classified in order:

  - a parameter of `srcsys` (via `parameter_index`) becomes a `ParameterIndex` entry; a
    plain `Int` index is treated as an index into the tunable portion;
  - a variable of `srcsys` (via `variable_index`) becomes a `UnitRange{Int}` entry;
  - anything else is collected into a `Vector{SymbolicT}`, which is replaced at the end
    by `concrete_getu(srcsys, symbols; kws...)`.

Adjacent symbols of the same kind with consecutive indices are merged into one ranged
entry, so the template length depends on the number of contiguous runs rather than the
number of symbols. The returned object has `size == size(syms)`.
"""
function CopyParamsByTemplate(srcsys::AbstractSystem, syms::AbstractArray{SymbolicT}; kws...)
    # Deliberately heterogeneous: entries are ranges, `ParameterIndex`es, symbol vectors
    # and (after the final pass) observed getters.
    template = Any[]
    for sym in syms
        symidx = parameter_index(srcsys, sym)
        if symidx === nothing
            symidx = variable_index(srcsys, sym)
            if symidx === nothing
                __push_observed_template!(template, sym)
            else
                __push_state_template!(template, symidx)
            end
            continue
        end
        if symidx isa Int
            symidx = ParameterIndex(SciMLStructures.Tunable(), symidx)
        end
        __push_parameter_template!(template, symidx)
    end

    # Compile one getter per run of symbols that have no direct index.
    for i in eachindex(template)
        if template[i] isa Vector{SymbolicT}
            template[i] = concrete_getu(srcsys, template[i]; kws...)
        end
    end
    return CopyParamsByTemplate(true, template, size(syms))
end
859+
860+
"""
    CopyParamsByTemplate(srcsys::AbstractSystem, syms::AbstractArray; kws...)

Build a non-root `CopyParamsByTemplate` for a nested collection of symbols. One
`CopyParamsByTemplate` is constructed per element of `syms`, with `kws` forwarded, and
the result has `size == size(syms)`.
"""
function CopyParamsByTemplate(srcsys::AbstractSystem, syms::AbstractArray; kws...)
    children = [CopyParamsByTemplate(srcsys, group; kws...) for group in syms]
    return CopyParamsByTemplate(false, children, size(syms))
end
863+
864+
"""
$TYPEDEF

Callable object that builds an `MTKParameters` from a source value provider. Each `*_fn`
field is called with the source value provider and supplies one portion of the result.
"""
struct MTKParametersReconstructor{TF, IF, DF, CF, NF}
    # Produces the tunable portion.
    tunables_fn::TF
    # Produces the initials portion.
    initials_fn::IF
    # Produces the discrete portion.
    discretes_fn::DF
    # Produces the constants portion.
    consts_fn::CF
    # Produces the nonnumeric portion.
    nonnumerics_fn::NF
    # Index of the nonnumeric buffer that is re-wrapped in `DiffCacheAllocatorAPIWrapper`
    # on reconstruction; `0` disables the re-wrapping.
    diffcache_buffer_idx::Int
end
872+
873+
# Build a new `MTKParameters` from the value provider `src`, using the parameter object
# of `dst` for the expected buffer types and for the caches, which are copied over.
function (recon::MTKParametersReconstructor)(src, dst)
    # `baseT` is the numeric type contributed by the state and the current time of `src`.
    # It is `Nothing` when `src` has no state and no finite real time.
    src_u = state_values(src)
    has_time = applicable(current_time, src) && current_time(src) isa Real &&
        isfinite(current_time(src))
    baseT = if src_u === nothing
        has_time ? typeof(current_time(src)) : Nothing
    elseif has_time
        promote_type(typeof(current_time(src)), eltype(src_u))
    else
        eltype(src_u)
    end

    src_ps = parameter_values(src)
    dst_ps = parameter_values(dst)
    oldcache = dst_ps.caches
    # These type-assertions allow `remake` to infer even though `CopyParamsByTemplate` doesn't
    # TODO: Find a solution to type-assert constants/discretes/nonnumerics
    # TODO: Maybe `CopyParamsByTemplate` could use a specializing tuple-form when `FullSpecialize`?
    # NOTE(review): `src_ps.tunable` / `src_ps.initials` are read unconditionally below —
    # confirm `src` always carries an `MTKParameters`.
    if dst_ps.tunable isa SVector{0}
        tunablevals = recon.tunables_fn(src)
    else
        tunable_elT = eltype(dst_ps.tunable)
        src_ps.tunable isa SVector{0} ||
            (tunable_elT = promote_type(tunable_elT, eltype(src_ps.tunable)))
        baseT === Nothing || (tunable_elT = promote_type(tunable_elT, baseT))
        # The initials below are promoted against the tunable element type as well.
        baseT = tunable_elT
        tunable_T = if ArrayInterface.ismutable(dst_ps.tunable)
            Base.promote_op(similar, typeof(dst_ps.tunable), Type{tunable_elT})
        else
            StaticArraysCore.similar_type(typeof(dst_ps.tunable), tunable_elT)
        end
        tunablevals = recon.tunables_fn(src)::tunable_T
    end

    if dst_ps.initials isa SVector{0}
        initialvals = recon.initials_fn(src)
    else
        initial_elT = eltype(dst_ps.initials)
        src_ps.initials isa SVector{0} ||
            (initial_elT = promote_type(initial_elT, eltype(src_ps.initials)))
        baseT === Nothing || (initial_elT = promote_type(initial_elT, baseT))
        initial_T = if ArrayInterface.ismutable(dst_ps.initials)
            Base.promote_op(similar, typeof(dst_ps.initials), Type{initial_elT})
        else
            StaticArraysCore.similar_type(typeof(dst_ps.initials), initial_elT)
        end
        initialvals = recon.initials_fn(src)::initial_T
    end

    nonnumerics = recon.nonnumerics_fn(src)::typeof(dst_ps.nonnumeric)
    diffcache_buffer_idx = recon.diffcache_buffer_idx
    if !iszero(diffcache_buffer_idx)
        # Re-wrap the selected buffer using the value type of the new initials.
        @set! nonnumerics[diffcache_buffer_idx] = DiffCacheAllocatorAPIWrapper{ForwardDiff.valtype(eltype(initialvals))}.(nonnumerics[diffcache_buffer_idx])
    end

    targetT = promote_type_with_nothing(eltype(tunablevals), initialvals)
    discretes = recon.discretes_fn(src)
    consts = recon.consts_fn(src)
    newcaches = oldcache isa Tuple{} ? () : copy.(oldcache)
    return promote_with_nothing(
        targetT,
        MTKParameters(tunablevals, initialvals, discretes, consts, nonnumerics, newcaches)
    )
end
940+
721941
"""
722942
$(TYPEDSIGNATURES)
723943
@@ -732,10 +952,10 @@ takes a value provider of `srcsys` and a value provider of `dstsys` and returns
732952
unwrapped.
733953
- `p_constructor`: The `p_constructor` argument to `process_SciMLProblem`.
734954
"""
735-
function get_mtkparameters_reconstructor(
955+
function MTKParametersReconstructor(
736956
srcsys::AbstractSystem, dstsys::AbstractSystem;
737957
initials = false, unwrap_initials = false, p_constructor = identity,
738-
eval_expression = false, eval_module = @__MODULE__, force_time_independent = false,
958+
force_time_independent = false,
739959
kwargs...
740960
)
741961
_p_constructor = p_constructor
@@ -754,32 +974,27 @@ function get_mtkparameters_reconstructor(
754974
tunable_getter = if isempty(tunable_syms)
755975
Returns(SVector{0, Float64}())
756976
else
757-
p_constructor concrete_getu(
758-
srcsys, tunable_syms; eval_expression, eval_module,
759-
force_time_independent, kwargs...
760-
)
977+
p_constructor CopyParamsByTemplate(srcsys, tunable_syms; kwargs...)
761978
end
762979
initials_getter = if initials && !isempty(syms[2])
763-
initsyms = Vector{Any}(syms[2])
764-
allsyms = Set(variable_symbols(srcsys))
980+
initsyms = syms[2]::Vector{SymbolicT}
981+
allsyms = Set{SymbolicT}(variable_symbols(srcsys))
765982
if unwrap_initials
766983
for i in eachindex(initsyms)
767984
sym = initsyms[i]
768-
innersym = if operation(sym) === getindex
769-
sym, idxs... = arguments(sym)
770-
only(arguments(sym))[idxs...]
985+
arr, isarr = split_indexed_var(sym)
986+
innersym = if isarr
987+
sidx = get_stable_index(sym)
988+
first(arguments(arr))[sidx]
771989
else
772-
only(arguments(sym))
990+
first(arguments(arr))
773991
end
774992
if innersym in allsyms
775993
initsyms[i] = innersym
776994
end
777995
end
778996
end
779-
p_constructor concrete_getu(
780-
srcsys, initsyms; eval_expression, eval_module,
781-
force_time_independent, kwargs...
782-
)
997+
p_constructor CopyParamsByTemplate(srcsys, initsyms; kwargs...)
783998
else
784999
Returns(SVector{0, Float64}())
7851000
end
@@ -792,29 +1007,25 @@ function get_mtkparameters_reconstructor(
7921007
p_constructor(map(x -> x.length, bufsizes))
7931008
end
7941009
)
1010+
7951011
# discretes need to be blocked arrays
7961012
# the `getu` returns a tuple of arrays corresponding to `p.discretes`
7971013
# `Base.Fix1(...)` applies `p_constructor` to each of the arrays in the tuple
7981014
# `Base.Fix2(...)` does `BlockedArray.(tuple_of_arrs, blockarrsizes)` returning a
7991015
# tuple of `BlockedArray`s
8001016
Base.Fix2(Broadcast.BroadcastFunction(BlockedArray), blockarrsizes)
801-
Base.Fix1(broadcast, p_constructor)
1017+
Base.Fix1(broadcast, p_constructor) Tuple
8021018
# This `broadcast.(collect, ...)` avoids `ReshapedArray`/`SubArray`s from
8031019
# appearing in the result.
804-
concrete_getu(
805-
srcsys, Tuple(broadcast.(collect, syms[3]));
806-
eval_expression, eval_module, force_time_independent, kwargs...
807-
)
1020+
CopyParamsByTemplate(srcsys, broadcast.(collect, syms[3]); kwargs...)
8081021
end
809-
const_getter = if syms[4] == ()
1022+
const_getter = if isempty(syms[4])
8101023
Returns(())
8111024
else
812-
Base.Fix1(broadcast, p_constructor) concrete_getu(
813-
srcsys, Tuple(syms[4]);
814-
eval_expression, eval_module, force_time_independent, kwargs...
815-
)
1025+
Base.Fix1(broadcast, p_constructor) Tuple CopyParamsByTemplate(srcsys, syms[4]; kwargs...)
8161026
end
817-
nonnumeric_getter = if syms[5] == ()
1027+
diffcache_buffer_idx = 0
1028+
nonnumeric_getter = if isempty(syms[5])
8181029
Returns(())
8191030
else
8201031
ic = get_index_cache(dstsys)
@@ -825,44 +1036,19 @@ function get_mtkparameters_reconstructor(
8251036
)
8261037

8271038
diffcache_params = SU.getmetadata(dstsys, DiffCacheParams, Dict{SymbolicT, Int}())::Dict{SymbolicT, Int}
828-
diffcache_buffer_idx = 0
8291039
if !isempty(diffcache_params)
8301040
representative = first(keys(diffcache_params))
8311041
diffcache_buffer_idx, _ = ic.nonnumeric_idx[representative]
8321042
@set! buftypes[diffcache_buffer_idx] = identity
1043+
for (i, sym) in enumerate(syms[5][diffcache_buffer_idx])
1044+
end
8331045
end
8341046
# nonnumerics retain the assigned buffer type without narrowing
8351047
Base.Fix1(broadcast, _p_constructor)
836-
Base.Fix1(Broadcast.BroadcastFunction(call), buftypes)
837-
concrete_getu(
838-
srcsys, Tuple(syms[5]);
839-
eval_expression, eval_module, force_time_independent, kwargs...
840-
)
841-
end
842-
getters = (
843-
tunable_getter, initials_getter, discs_getter, const_getter, nonnumeric_getter,
844-
)
845-
getter = let getters = getters, diffcache_buffer_idx = diffcache_buffer_idx
846-
function _getter(valp, initprob)
847-
oldcache = parameter_values(initprob).caches
848-
tunablevals = getters[1](valp)
849-
initialvals = getters[2](valp)
850-
nonnumerics = getters[5](valp)
851-
if !iszero(diffcache_buffer_idx)
852-
@set! nonnumerics[diffcache_buffer_idx] = DiffCacheAllocatorAPIWrapper{ForwardDiff.valtype(eltype(initialvals))}.(nonnumerics[diffcache_buffer_idx])
853-
end
854-
return promote_with_nothing(
855-
promote_type_with_nothing(eltype(tunablevals), initialvals),
856-
MTKParameters(
857-
tunablevals, initialvals, getters[3](valp),
858-
getters[4](valp), nonnumerics, oldcache isa Tuple{} ? () :
859-
copy.(oldcache)
860-
)
861-
)
862-
end
1048+
Base.Fix1(Broadcast.BroadcastFunction(call), buftypes) Tuple CopyParamsByTemplate(srcsys, syms[5]; kwargs...)
8631049
end
8641050

865-
return getter
1051+
return MTKParametersReconstructor(tunable_getter, initials_getter, discs_getter, const_getter, nonnumeric_getter, diffcache_buffer_idx)
8661052
end
8671053

8681054
function call(f, args...)
@@ -890,7 +1076,7 @@ function ReconstructInitializeprob(
8901076
kwargs...
8911077
)
8921078
if is_split(dstsys)
893-
pgetter = get_mtkparameters_reconstructor(
1079+
pgetter = MTKParametersReconstructor(
8941080
srcsys, dstsys; p_constructor, eval_expression, eval_module,
8951081
force_time_independent = is_steadystateprob, kwargs...
8961082
)
@@ -970,7 +1156,7 @@ function construct_initializeprobpmap(
9701156
)
9711157
@assert is_initializesystem(initsys)
9721158
if is_split(sys)
973-
return let getter = get_mtkparameters_reconstructor(
1159+
return let getter = MTKParametersReconstructor(
9741160
initsys, sys; initials = true, unwrap_initials = true, p_constructor,
9751161
eval_expression, eval_module, kwargs...
9761162
)

0 commit comments

Comments
 (0)