Skip to content

Commit 0120442

Browse files
fix: avoid generating massive functions in get_mtkparameters_reconstructor
This used to be the lion's share of the compile time in large models. The infrastructure added here might be useful in other places too.
1 parent 7cd38e2 commit 0120442

1 file changed

Lines changed: 191 additions & 58 deletions

File tree

lib/ModelingToolkitBase/src/systems/problem_utils.jl

Lines changed: 191 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -718,6 +718,173 @@ function (pca::PConstructorApplicator)(x::AbstractArray{<:AbstractArray})
718718
return pca.p_constructor(pca.(x))
719719
end
720720

721+
"""
722+
$TYPEDEF
723+
724+
Callable struct designed for use by `MTKParametersReconstructor`. Uses a fixed set of templates to
725+
act as a very dynamic (and limited) observed function returning an array. See `__apply_copy_template`
726+
for the supported templates.
727+
"""
728+
struct CopyParamsByTemplate{T, N}
729+
"""
730+
Whether this is a "root" object, in that it directly contains instantiable templates instead
731+
of recursing into other `CopyParamsByTemplate`.
732+
"""
733+
isroot::Bool
734+
"""
735+
List of templates.
736+
"""
737+
template::T # TODO: This field is parametric because I thought we might want to specialize it in some cases.
738+
"""
739+
Size of the returned buffer.
740+
"""
741+
size::NTuple{N, Int}
742+
end
743+
744+
function __apply_copy_template(valp, @nospecialize(template))
745+
p = parameter_values(valp)
746+
u = state_values(valp)
747+
if template isa ParameterIndex{SciMLStructures.Tunable, UnitRange{Int}}
748+
return p.tunable[template.idx]
749+
elseif template isa ParameterIndex{SciMLStructures.Initials, UnitRange{Int}}
750+
return p.initials[template.idx]
751+
elseif template isa ParameterIndex{SciMLStructures.Discrete, Tuple{Int, UnitRange{Int}}}
752+
return p.discrete[template.idx[1]][template.idx[2]]
753+
elseif template isa ParameterIndex{SciMLStructures.Constants, Tuple{Int, UnitRange{Int}}}
754+
return p.constant[template.idx[1]][template.idx[2]]
755+
elseif template isa ParameterIndex{Nonnumeric, Tuple{Int, UnitRange{Int}}}
756+
return p.nonnumeric[template.idx[1]][template.idx[2]]
757+
elseif template isa UnitRange{Int}
758+
return u[template]
759+
elseif template isa ObservedWrapper
760+
return template(valp)
761+
elseif template isa CopyParamsByTemplate
762+
return template(valp)
763+
else
764+
# MethodError because this is a manual dispatch chain
765+
throw(MethodError(__apply_copy_template, (valp, template)))
766+
end
767+
end
768+
769+
function (cp::CopyParamsByTemplate)(src)
770+
if cp.isroot
771+
reshape(mapreduce(Base.Fix1(__apply_copy_template, src), vcat, cp.template), cp.size)
772+
else
773+
reshape(map(Base.Fix1(__apply_copy_template, src), cp.template), cp.size)
774+
end
775+
end
776+
777+
function CopyParamsByTemplate(srcsys::AbstractSystem, syms::AbstractArray{SymbolicT}; kws...)
778+
template = []
779+
for sym in syms
780+
symidx = parameter_index(srcsys, sym)
781+
if symidx === nothing
782+
symidx = variable_index(srcsys, sym)
783+
if symidx === nothing
784+
if isempty(template)
785+
push!(template, SymbolicT[sym])
786+
continue
787+
end
788+
prev = template[end]
789+
if prev isa Vector{SymbolicT}
790+
push!(prev, sym)
791+
else
792+
push!(template, SymbolicT[sym])
793+
end
794+
continue
795+
end
796+
if isempty(template)
797+
push!(template, symidx:symidx)
798+
continue
799+
end
800+
prev = template[end]
801+
if prev isa UnitRange{Int} && last(prev) + 1 == symidx
802+
template[end] = first(prev):symidx
803+
else
804+
push!(template, symidx:symidx)
805+
end
806+
continue
807+
end
808+
portion = symidx.portion
809+
_bufidx = symidx.idx
810+
if portion isa SciMLStructures.Tunable || portion isa SciMLStructures.Initials
811+
812+
end
813+
bufidx::UnitRange{Int} = if _bufidx isa AbstractVector{Int}
814+
@assert isequal(vec(_bufidx), first(_bufidx):last(_bufidx))
815+
subidx = nothing
816+
first(_bufidx):last(_bufidx)
817+
elseif _bufidx isa Int
818+
subidx = nothing
819+
_bufidx:_bufidx
820+
elseif _bufidx isa NTuple{2, Int}
821+
subidx = _bufidx[1]
822+
_bufidx[2]:_bufidx[2]
823+
else
824+
# Will error due to the typeassert on `bufidx`
825+
nothing
826+
end
827+
if isempty(template)
828+
if subidx === nothing
829+
push!(template, ParameterIndex(symidx.portion, bufidx))
830+
else
831+
push!(template, ParameterIndex(symidx.portion, (subidx, bufidx)))
832+
end
833+
continue
834+
end
835+
prev = template[end]
836+
if prev isa ParameterIndex && prev.portion === symidx.portion && (subidx === nothing && last(prev.idx) + 1 == first(bufidx) || subidx == prev.idx[1] && last(prev.idx[2]) + 1 == first(bufidx))
837+
if subidx === nothing
838+
template[end] = ParameterIndex(prev.portion, first(prev.idx):last(bufidx))
839+
else
840+
template[end] = ParameterIndex(prev.portion, (subidx, first(prev.idx[2]):last(bufidx)))
841+
end
842+
elseif subidx === nothing
843+
push!(template, ParameterIndex(symidx.portion, bufidx))
844+
else
845+
push!(template, ParameterIndex(symidx.portion, (subidx, bufidx)))
846+
end
847+
end
848+
849+
for i in eachindex(template)
850+
if template[i] isa Vector{SymbolicT}
851+
template[i] = concrete_getu(srcsys, template[i]; kws...)
852+
end
853+
end
854+
return CopyParamsByTemplate(true, template, size(syms))
855+
end
856+
857+
function CopyParamsByTemplate(srcsys::AbstractSystem, syms::AbstractArray; kws...)
858+
return CopyParamsByTemplate(false, [CopyParamsByTemplate(srcsys, sym; kws...) for sym in syms], size(syms))
859+
end
860+
861+
struct MTKParametersReconstructor{T, I, D, C, N}
862+
tunables_fn::T
863+
initials_fn::I
864+
discretes_fn::D
865+
consts_fn::C
866+
nonnumerics_fn::N
867+
diffcache_buffer_idx::Int
868+
end
869+
870+
function (recon::MTKParametersReconstructor)(src, dst)
871+
oldcache = parameter_values(dst).caches
872+
tunablevals = recon.tunables_fn(src)
873+
initialvals = recon.initials_fn(src)
874+
nonnumerics = recon.nonnumerics_fn(src)
875+
(; diffcache_buffer_idx) = recon
876+
if !iszero(diffcache_buffer_idx)
877+
@set! nonnumerics[diffcache_buffer_idx] = DiffCacheAllocatorAPIWrapper{ForwardDiff.valtype(eltype(initialvals))}.(nonnumerics[diffcache_buffer_idx])
878+
end
879+
return promote_with_nothing(
880+
promote_type_with_nothing(eltype(tunablevals), initialvals),
881+
MTKParameters(
882+
tunablevals, initialvals, recon.discretes_fn(src),
883+
recon.consts_fn(src), nonnumerics, oldcache isa Tuple{} ? () : copy.(oldcache)
884+
)
885+
)
886+
end
887+
721888
"""
722889
$(TYPEDSIGNATURES)
723890
@@ -732,10 +899,10 @@ takes a value provider of `srcsys` and a value provider of `dstsys` and returns
732899
unwrapped.
733900
- `p_constructor`: The `p_constructor` argument to `process_SciMLProblem`.
734901
"""
735-
function get_mtkparameters_reconstructor(
902+
function MTKParametersReconstructor(
736903
srcsys::AbstractSystem, dstsys::AbstractSystem;
737904
initials = false, unwrap_initials = false, p_constructor = identity,
738-
eval_expression = false, eval_module = @__MODULE__, force_time_independent = false,
905+
force_time_independent = false,
739906
kwargs...
740907
)
741908
_p_constructor = p_constructor
@@ -754,32 +921,27 @@ function get_mtkparameters_reconstructor(
754921
tunable_getter = if isempty(tunable_syms)
755922
Returns(SVector{0, Float64}())
756923
else
757-
p_constructor concrete_getu(
758-
srcsys, tunable_syms; eval_expression, eval_module,
759-
force_time_independent, kwargs...
760-
)
924+
p_constructor CopyParamsByTemplate(srcsys, tunable_syms; kwargs...)
761925
end
762926
initials_getter = if initials && !isempty(syms[2])
763-
initsyms = Vector{Any}(syms[2])
764-
allsyms = Set(variable_symbols(srcsys))
927+
initsyms = syms[2]::Vector{SymbolicT}
928+
allsyms = Set{SymbolicT}(variable_symbols(srcsys))
765929
if unwrap_initials
766930
for i in eachindex(initsyms)
767931
sym = initsyms[i]
768-
innersym = if operation(sym) === getindex
769-
sym, idxs... = arguments(sym)
770-
only(arguments(sym))[idxs...]
932+
arr, isarr = split_indexed_var(sym)
933+
innersym = if isarr
934+
sidx = get_stable_index(sym)
935+
first(arguments(arr))[sidx]
771936
else
772-
only(arguments(sym))
937+
first(arguments(arr))
773938
end
774939
if innersym in allsyms
775940
initsyms[i] = innersym
776941
end
777942
end
778943
end
779-
p_constructor concrete_getu(
780-
srcsys, initsyms; eval_expression, eval_module,
781-
force_time_independent, kwargs...
782-
)
944+
p_constructor CopyParamsByTemplate(srcsys, initsyms; kwargs...)
783945
else
784946
Returns(SVector{0, Float64}())
785947
end
@@ -792,29 +954,25 @@ function get_mtkparameters_reconstructor(
792954
p_constructor(map(x -> x.length, bufsizes))
793955
end
794956
)
957+
795958
# discretes need to be blocked arrays
796959
# the `getu` returns a tuple of arrays corresponding to `p.discretes`
797960
# `Base.Fix1(...)` applies `p_constructor` to each of the arrays in the tuple
798961
# `Base.Fix2(...)` does `BlockedArray.(tuple_of_arrs, blockarrsizes)` returning a
799962
# tuple of `BlockedArray`s
800963
Base.Fix2(Broadcast.BroadcastFunction(BlockedArray), blockarrsizes)
801-
Base.Fix1(broadcast, p_constructor)
964+
Base.Fix1(broadcast, p_constructor) Tuple
802965
# This `broadcast.(collect, ...)` avoids `ReshapedArray`/`SubArray`s from
803966
# appearing in the result.
804-
concrete_getu(
805-
srcsys, Tuple(broadcast.(collect, syms[3]));
806-
eval_expression, eval_module, force_time_independent, kwargs...
807-
)
967+
CopyParamsByTemplate(srcsys, broadcast.(collect(syms[3])); kwargs...)
808968
end
809-
const_getter = if syms[4] == ()
969+
const_getter = if isempty(syms[4])
810970
Returns(())
811971
else
812-
Base.Fix1(broadcast, p_constructor) concrete_getu(
813-
srcsys, Tuple(syms[4]);
814-
eval_expression, eval_module, force_time_independent, kwargs...
815-
)
972+
Base.Fix1(broadcast, p_constructor) Tuple CopyParamsByTemplate(srcsys, syms[4]; kwargs...)
816973
end
817-
nonnumeric_getter = if syms[5] == ()
974+
diffcache_buffer_idx = 0
975+
nonnumeric_getter = if isempty(syms[5])
818976
Returns(())
819977
else
820978
ic = get_index_cache(dstsys)
@@ -825,44 +983,19 @@ function get_mtkparameters_reconstructor(
825983
)
826984

827985
diffcache_params = SU.getmetadata(dstsys, DiffCacheParams, Dict{SymbolicT, Int}())::Dict{SymbolicT, Int}
828-
diffcache_buffer_idx = 0
829986
if !isempty(diffcache_params)
830987
representative = first(keys(diffcache_params))
831988
diffcache_buffer_idx, _ = ic.nonnumeric_idx[representative]
832989
@set! buftypes[diffcache_buffer_idx] = identity
990+
for (i, sym) in enumerate(syms[5][diffcache_buffer_idx])
991+
end
833992
end
834993
# nonnumerics retain the assigned buffer type without narrowing
835994
Base.Fix1(broadcast, _p_constructor)
836-
Base.Fix1(Broadcast.BroadcastFunction(call), buftypes)
837-
concrete_getu(
838-
srcsys, Tuple(syms[5]);
839-
eval_expression, eval_module, force_time_independent, kwargs...
840-
)
841-
end
842-
getters = (
843-
tunable_getter, initials_getter, discs_getter, const_getter, nonnumeric_getter,
844-
)
845-
getter = let getters = getters, diffcache_buffer_idx = diffcache_buffer_idx
846-
function _getter(valp, initprob)
847-
oldcache = parameter_values(initprob).caches
848-
tunablevals = getters[1](valp)
849-
initialvals = getters[2](valp)
850-
nonnumerics = getters[5](valp)
851-
if !iszero(diffcache_buffer_idx)
852-
@set! nonnumerics[diffcache_buffer_idx] = DiffCacheAllocatorAPIWrapper{ForwardDiff.valtype(eltype(initialvals))}.(nonnumerics[diffcache_buffer_idx])
853-
end
854-
return promote_with_nothing(
855-
promote_type_with_nothing(eltype(tunablevals), initialvals),
856-
MTKParameters(
857-
tunablevals, initialvals, getters[3](valp),
858-
getters[4](valp), nonnumerics, oldcache isa Tuple{} ? () :
859-
copy.(oldcache)
860-
)
861-
)
862-
end
995+
Base.Fix1(Broadcast.BroadcastFunction(call), buftypes) Tuple CopyParamsByTemplate(srcsys, syms[5]; kwargs...)
863996
end
864997

865-
return getter
998+
return MTKParametersReconstructor(tunable_getter, initials_getter, discs_getter, const_getter, nonnumeric_getter, diffcache_buffer_idx)
866999
end
8671000

8681001
function call(f, args...)
@@ -890,7 +1023,7 @@ function ReconstructInitializeprob(
8901023
kwargs...
8911024
)
8921025
if is_split(dstsys)
893-
pgetter = get_mtkparameters_reconstructor(
1026+
pgetter = MTKParametersReconstructor(
8941027
srcsys, dstsys; p_constructor, eval_expression, eval_module,
8951028
force_time_independent = is_steadystateprob, kwargs...
8961029
)
@@ -970,7 +1103,7 @@ function construct_initializeprobpmap(
9701103
)
9711104
@assert is_initializesystem(initsys)
9721105
if is_split(sys)
973-
return let getter = get_mtkparameters_reconstructor(
1106+
return let getter = MTKParametersReconstructor(
9741107
initsys, sys; initials = true, unwrap_initials = true, p_constructor,
9751108
eval_expression, eval_module, kwargs...
9761109
)

0 commit comments

Comments
 (0)