Skip to content

Commit 21f7b9e

Browse files
fix: avoid generating massive functions in get_mtkparameters_reconstructor
This used to be the lion's share of the compile time in large models. The infrastructure added here might be useful in other places too.
1 parent 46f2ef2 commit 21f7b9e

1 file changed

Lines changed: 235 additions & 58 deletions

File tree

lib/ModelingToolkitBase/src/systems/problem_utils.jl

Lines changed: 235 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -718,6 +718,217 @@ function (pca::PConstructorApplicator)(x::AbstractArray{<:AbstractArray})
718718
return pca.p_constructor(pca.(x))
719719
end
720720

721+
"""
722+
$TYPEDEF
723+
724+
Callable struct designed for use by `MTKParametersReconstructor`. Uses a fixed set of templates to
725+
act as a very dynamic (and limited) observed function returning an array. See `__apply_copy_template`
726+
for the supported templates.
727+
"""
728+
struct CopyParamsByTemplate{T, N}
729+
"""
730+
Whether this is a "root" object, in that it directly contains instantiable templates instead
731+
of recursing into other `CopyParamsByTemplate`.
732+
"""
733+
isroot::Bool
734+
"""
735+
List of templates.
736+
"""
737+
template::T # TODO: This field is parametric because I thought we might want to specialize it in some cases.
738+
"""
739+
Size of the returned buffer.
740+
"""
741+
size::NTuple{N, Int}
742+
end
743+
744+
function __apply_copy_template(valp, @nospecialize(template))
745+
p = parameter_values(valp)
746+
u = state_values(valp)
747+
if template isa ParameterIndex{SciMLStructures.Tunable, UnitRange{Int}}
748+
return p.tunable[template.idx]
749+
elseif template isa ParameterIndex{SciMLStructures.Initials, UnitRange{Int}}
750+
return p.initials[template.idx]
751+
elseif template isa ParameterIndex{SciMLStructures.Discrete, Tuple{Int, UnitRange{Int}}}
752+
return p.discrete[template.idx[1]][template.idx[2]]
753+
elseif template isa ParameterIndex{SciMLStructures.Constants, Tuple{Int, UnitRange{Int}}}
754+
return p.constant[template.idx[1]][template.idx[2]]
755+
elseif template isa ParameterIndex{Nonnumeric, Tuple{Int, UnitRange{Int}}}
756+
return p.nonnumeric[template.idx[1]][template.idx[2]]
757+
elseif template isa UnitRange{Int}
758+
return u[template]
759+
elseif template isa ObservedWrapper
760+
return template(valp)
761+
elseif template isa CopyParamsByTemplate
762+
return template(valp)
763+
else
764+
# MethodError because this is a manual dispatch chain
765+
throw(MethodError(__apply_copy_template, (valp, template)))
766+
end
767+
end
768+
769+
function (cp::CopyParamsByTemplate)(src)
770+
if cp.isroot
771+
reshape(mapreduce(Base.Fix1(__apply_copy_template, src), vcat, cp.template), cp.size)
772+
else
773+
reshape(map(Base.Fix1(__apply_copy_template, src), cp.template), cp.size)
774+
end
775+
end
776+
777+
function CopyParamsByTemplate(srcsys::AbstractSystem, syms::AbstractArray{SymbolicT}; kws...)
778+
template = []
779+
for sym in syms
780+
symidx = parameter_index(srcsys, sym)
781+
if symidx === nothing
782+
symidx = variable_index(srcsys, sym)
783+
if symidx === nothing
784+
if isempty(template)
785+
push!(template, SymbolicT[sym])
786+
continue
787+
end
788+
prev = template[end]
789+
if prev isa Vector{SymbolicT}
790+
push!(prev, sym)
791+
else
792+
push!(template, SymbolicT[sym])
793+
end
794+
continue
795+
end
796+
if isempty(template)
797+
push!(template, symidx:symidx)
798+
continue
799+
end
800+
prev = template[end]
801+
if prev isa UnitRange{Int} && last(prev) + 1 == symidx
802+
template[end] = first(prev):symidx
803+
else
804+
push!(template, symidx:symidx)
805+
end
806+
continue
807+
end
808+
portion = symidx.portion
809+
_bufidx = symidx.idx
810+
if portion isa SciMLStructures.Tunable || portion isa SciMLStructures.Initials
811+
812+
end
813+
bufidx::UnitRange{Int} = if _bufidx isa AbstractVector{Int}
814+
@assert isequal(vec(_bufidx), first(_bufidx):last(_bufidx))
815+
subidx = nothing
816+
first(_bufidx):last(_bufidx)
817+
elseif _bufidx isa Int
818+
subidx = nothing
819+
_bufidx:_bufidx
820+
elseif _bufidx isa NTuple{2, Int}
821+
subidx = _bufidx[1]
822+
_bufidx[2]:_bufidx[2]
823+
else
824+
# Will error due to the typeassert on `bufidx`
825+
nothing
826+
end
827+
if isempty(template)
828+
if subidx === nothing
829+
push!(template, ParameterIndex(symidx.portion, bufidx))
830+
else
831+
push!(template, ParameterIndex(symidx.portion, (subidx, bufidx)))
832+
end
833+
continue
834+
end
835+
prev = template[end]
836+
if prev isa ParameterIndex && prev.portion === symidx.portion && (subidx === nothing && last(prev.idx) + 1 == first(bufidx) || subidx == prev.idx[1] && last(prev.idx[2]) + 1 == first(bufidx))
837+
if subidx === nothing
838+
template[end] = ParameterIndex(prev.portion, first(prev.idx):last(bufidx))
839+
else
840+
template[end] = ParameterIndex(prev.portion, (subidx, first(prev.idx[2]):last(bufidx)))
841+
end
842+
elseif subidx === nothing
843+
push!(template, ParameterIndex(symidx.portion, bufidx))
844+
else
845+
push!(template, ParameterIndex(symidx.portion, (subidx, bufidx)))
846+
end
847+
end
848+
849+
for i in eachindex(template)
850+
if template[i] isa Vector{SymbolicT}
851+
template[i] = concrete_getu(srcsys, template[i]; kws...)
852+
end
853+
end
854+
return CopyParamsByTemplate(true, template, size(syms))
855+
end
856+
857+
function CopyParamsByTemplate(srcsys::AbstractSystem, syms::AbstractArray; kws...)
858+
return CopyParamsByTemplate(false, [CopyParamsByTemplate(srcsys, sym; kws...) for sym in syms], size(syms))
859+
end
860+
861+
struct MTKParametersReconstructor{T, I, D, C, N}
862+
tunables_fn::T
863+
initials_fn::I
864+
discretes_fn::D
865+
consts_fn::C
866+
nonnumerics_fn::N
867+
diffcache_buffer_idx::Int
868+
end
869+
870+
function (recon::MTKParametersReconstructor)(src, dst)
871+
if state_values(src) === nothing && !(current_time(src) isa Real && isfinite(current_time(src)))
872+
baseT = Nothing
873+
elseif state_values(src) === nothing
874+
baseT = typeof(current_time(src))
875+
elseif !(current_time(src) isa Real && isfinite(current_time(src)))
876+
baseT = eltype(state_values(src))
877+
else
878+
baseT = promote_type(typeof(current_time(src)), eltype(state_values(src)))
879+
end
880+
src_ps = parameter_values(src)
881+
dst_ps = parameter_values(dst)
882+
oldcache = dst_ps.caches
883+
# These type-assertions allow `remake` to infer even though `CopyParamsByTemplate` doesn't
884+
# TODO: Find a solution to type-assert constants/discretes/nonnumerics
885+
# TODO: Maybe `CopyParamsByTemplate` could use a specializing tuple-form when `FullSpecialize`?
886+
if dst_ps.tunable isa SVector{0}
887+
tunablevals = recon.tunables_fn(src)
888+
else
889+
tunable_elT = promote_type(eltype(dst_ps.tunable), eltype(src_ps.tunable))
890+
if baseT !== Nothing
891+
tunable_elT = promote_type(tunable_elT, baseT)
892+
end
893+
baseT = tunable_elT
894+
if ArrayInterface.ismutable(dst_ps.tunable)
895+
tunable_T = Base.promote_op(similar, typeof(dst_ps.tunable), Type{tunable_elT})
896+
tunablevals = recon.tunables_fn(src)::tunable_T
897+
else
898+
tunable_T = StaticArraysCore.similar_type(typeof(dst_ps.tunable), tunable_elT)
899+
tunablevals = recon.tunables_fn(src)::tunable_T
900+
end
901+
end
902+
if dst_ps.initials isa SVector{0}
903+
initialvals = recon.initials_fn(src)
904+
else
905+
initial_elT = promote_type(eltype(dst_ps.initials), eltype(src_ps.initials))
906+
if baseT !== Nothing
907+
initial_elT = promote_type(initial_elT, baseT)
908+
end
909+
if ArrayInterface.ismutable(dst_ps.initials)
910+
initial_T = Base.promote_op(similar, typeof(dst_ps.initials), Type{initial_elT})
911+
initialvals = recon.initials_fn(src)::initial_T
912+
else
913+
initial_T = StaticArraysCore.similar_type(typeof(dst_ps.initials), initial_elT)
914+
initialvals = recon.initials_fn(src)::initial_T
915+
end
916+
end
917+
918+
nonnumerics = recon.nonnumerics_fn(src)::typeof(dst_ps.nonnumeric)
919+
(; diffcache_buffer_idx) = recon
920+
if !iszero(diffcache_buffer_idx)
921+
@set! nonnumerics[diffcache_buffer_idx] = DiffCacheAllocatorAPIWrapper{ForwardDiff.valtype(eltype(initialvals))}.(nonnumerics[diffcache_buffer_idx])
922+
end
923+
return promote_with_nothing(
924+
promote_type_with_nothing(eltype(tunablevals), initialvals),
925+
MTKParameters(
926+
tunablevals, initialvals, recon.discretes_fn(src),
927+
recon.consts_fn(src), nonnumerics, oldcache isa Tuple{} ? () : copy.(oldcache)
928+
)
929+
)
930+
end
931+
721932
"""
722933
$(TYPEDSIGNATURES)
723934
@@ -732,10 +943,10 @@ takes a value provider of `srcsys` and a value provider of `dstsys` and returns
732943
unwrapped.
733944
- `p_constructor`: The `p_constructor` argument to `process_SciMLProblem`.
734945
"""
735-
function get_mtkparameters_reconstructor(
946+
function MTKParametersReconstructor(
736947
srcsys::AbstractSystem, dstsys::AbstractSystem;
737948
initials = false, unwrap_initials = false, p_constructor = identity,
738-
eval_expression = false, eval_module = @__MODULE__, force_time_independent = false,
949+
force_time_independent = false,
739950
kwargs...
740951
)
741952
_p_constructor = p_constructor
@@ -754,32 +965,27 @@ function get_mtkparameters_reconstructor(
754965
tunable_getter = if isempty(tunable_syms)
755966
Returns(SVector{0, Float64}())
756967
else
757-
p_constructor concrete_getu(
758-
srcsys, tunable_syms; eval_expression, eval_module,
759-
force_time_independent, kwargs...
760-
)
968+
p_constructor CopyParamsByTemplate(srcsys, tunable_syms; kwargs...)
761969
end
762970
initials_getter = if initials && !isempty(syms[2])
763-
initsyms = Vector{Any}(syms[2])
764-
allsyms = Set(variable_symbols(srcsys))
971+
initsyms = syms[2]::Vector{SymbolicT}
972+
allsyms = Set{SymbolicT}(variable_symbols(srcsys))
765973
if unwrap_initials
766974
for i in eachindex(initsyms)
767975
sym = initsyms[i]
768-
innersym = if operation(sym) === getindex
769-
sym, idxs... = arguments(sym)
770-
only(arguments(sym))[idxs...]
976+
arr, isarr = split_indexed_var(sym)
977+
innersym = if isarr
978+
sidx = get_stable_index(sym)
979+
first(arguments(arr))[sidx]
771980
else
772-
only(arguments(sym))
981+
first(arguments(arr))
773982
end
774983
if innersym in allsyms
775984
initsyms[i] = innersym
776985
end
777986
end
778987
end
779-
p_constructor concrete_getu(
780-
srcsys, initsyms; eval_expression, eval_module,
781-
force_time_independent, kwargs...
782-
)
988+
p_constructor CopyParamsByTemplate(srcsys, initsyms; kwargs...)
783989
else
784990
Returns(SVector{0, Float64}())
785991
end
@@ -792,29 +998,25 @@ function get_mtkparameters_reconstructor(
792998
p_constructor(map(x -> x.length, bufsizes))
793999
end
7941000
)
1001+
7951002
# discretes need to be blocked arrays
7961003
# the `getu` returns a tuple of arrays corresponding to `p.discretes`
7971004
# `Base.Fix1(...)` applies `p_constructor` to each of the arrays in the tuple
7981005
# `Base.Fix2(...)` does `BlockedArray.(tuple_of_arrs, blockarrsizes)` returning a
7991006
# tuple of `BlockedArray`s
8001007
Base.Fix2(Broadcast.BroadcastFunction(BlockedArray), blockarrsizes)
801-
Base.Fix1(broadcast, p_constructor)
1008+
Base.Fix1(broadcast, p_constructor) Tuple
8021009
# This `broadcast.(collect, ...)` avoids `ReshapedArray`/`SubArray`s from
8031010
# appearing in the result.
804-
concrete_getu(
805-
srcsys, Tuple(broadcast.(collect, syms[3]));
806-
eval_expression, eval_module, force_time_independent, kwargs...
807-
)
1011+
CopyParamsByTemplate(srcsys, broadcast.(collect, syms[3]); kwargs...)
8081012
end
809-
const_getter = if syms[4] == ()
1013+
const_getter = if isempty(syms[4])
8101014
Returns(())
8111015
else
812-
Base.Fix1(broadcast, p_constructor) concrete_getu(
813-
srcsys, Tuple(syms[4]);
814-
eval_expression, eval_module, force_time_independent, kwargs...
815-
)
1016+
Base.Fix1(broadcast, p_constructor) Tuple CopyParamsByTemplate(srcsys, syms[4]; kwargs...)
8161017
end
817-
nonnumeric_getter = if syms[5] == ()
1018+
diffcache_buffer_idx = 0
1019+
nonnumeric_getter = if isempty(syms[5])
8181020
Returns(())
8191021
else
8201022
ic = get_index_cache(dstsys)
@@ -825,44 +1027,19 @@ function get_mtkparameters_reconstructor(
8251027
)
8261028

8271029
diffcache_params = SU.getmetadata(dstsys, DiffCacheParams, Dict{SymbolicT, Int}())::Dict{SymbolicT, Int}
828-
diffcache_buffer_idx = 0
8291030
if !isempty(diffcache_params)
8301031
representative = first(keys(diffcache_params))
8311032
diffcache_buffer_idx, _ = ic.nonnumeric_idx[representative]
8321033
@set! buftypes[diffcache_buffer_idx] = identity
1034+
for (i, sym) in enumerate(syms[5][diffcache_buffer_idx])
1035+
end
8331036
end
8341037
# nonnumerics retain the assigned buffer type without narrowing
8351038
Base.Fix1(broadcast, _p_constructor)
836-
Base.Fix1(Broadcast.BroadcastFunction(call), buftypes)
837-
concrete_getu(
838-
srcsys, Tuple(syms[5]);
839-
eval_expression, eval_module, force_time_independent, kwargs...
840-
)
841-
end
842-
getters = (
843-
tunable_getter, initials_getter, discs_getter, const_getter, nonnumeric_getter,
844-
)
845-
getter = let getters = getters, diffcache_buffer_idx = diffcache_buffer_idx
846-
function _getter(valp, initprob)
847-
oldcache = parameter_values(initprob).caches
848-
tunablevals = getters[1](valp)
849-
initialvals = getters[2](valp)
850-
nonnumerics = getters[5](valp)
851-
if !iszero(diffcache_buffer_idx)
852-
@set! nonnumerics[diffcache_buffer_idx] = DiffCacheAllocatorAPIWrapper{ForwardDiff.valtype(eltype(initialvals))}.(nonnumerics[diffcache_buffer_idx])
853-
end
854-
return promote_with_nothing(
855-
promote_type_with_nothing(eltype(tunablevals), initialvals),
856-
MTKParameters(
857-
tunablevals, initialvals, getters[3](valp),
858-
getters[4](valp), nonnumerics, oldcache isa Tuple{} ? () :
859-
copy.(oldcache)
860-
)
861-
)
862-
end
1039+
Base.Fix1(Broadcast.BroadcastFunction(call), buftypes) Tuple CopyParamsByTemplate(srcsys, syms[5]; kwargs...)
8631040
end
8641041

865-
return getter
1042+
return MTKParametersReconstructor(tunable_getter, initials_getter, discs_getter, const_getter, nonnumeric_getter, diffcache_buffer_idx)
8661043
end
8671044

8681045
function call(f, args...)
@@ -890,7 +1067,7 @@ function ReconstructInitializeprob(
8901067
kwargs...
8911068
)
8921069
if is_split(dstsys)
893-
pgetter = get_mtkparameters_reconstructor(
1070+
pgetter = MTKParametersReconstructor(
8941071
srcsys, dstsys; p_constructor, eval_expression, eval_module,
8951072
force_time_independent = is_steadystateprob, kwargs...
8961073
)
@@ -970,7 +1147,7 @@ function construct_initializeprobpmap(
9701147
)
9711148
@assert is_initializesystem(initsys)
9721149
if is_split(sys)
973-
return let getter = get_mtkparameters_reconstructor(
1150+
return let getter = MTKParametersReconstructor(
9741151
initsys, sys; initials = true, unwrap_initials = true, p_constructor,
9751152
eval_expression, eval_module, kwargs...
9761153
)

0 commit comments

Comments
 (0)