Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
293 changes: 235 additions & 58 deletions lib/ModelingToolkitBase/src/systems/problem_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -718,6 +718,217 @@ function (pca::PConstructorApplicator)(x::AbstractArray{<:AbstractArray})
return pca.p_constructor(pca.(x))
end

"""
$TYPEDEF

Callable struct designed for use by `MTKParametersReconstructor`. Uses a fixed set of templates to
act as a very dynamic (and limited) observed function returning an array. See `__apply_copy_template`
for the supported templates.
"""
struct CopyParamsByTemplate{T, N}
"""
Whether this is a "root" object, in that it directly contains instantiable templates instead
of recursing into other `CopyParamsByTemplate`.
"""
isroot::Bool
"""
List of templates.
"""
template::T # TODO: This field is parametric because I thought we might want to specialize it in some cases.
"""
Size of the returned buffer.
"""
size::NTuple{N, Int}
end

function __apply_copy_template(valp, @nospecialize(template))
p = parameter_values(valp)
u = state_values(valp)
if template isa ParameterIndex{SciMLStructures.Tunable, UnitRange{Int}}
return p.tunable[template.idx]
elseif template isa ParameterIndex{SciMLStructures.Initials, UnitRange{Int}}
return p.initials[template.idx]
elseif template isa ParameterIndex{SciMLStructures.Discrete, Tuple{Int, UnitRange{Int}}}
return p.discrete[template.idx[1]][template.idx[2]]
elseif template isa ParameterIndex{SciMLStructures.Constants, Tuple{Int, UnitRange{Int}}}
return p.constant[template.idx[1]][template.idx[2]]
elseif template isa ParameterIndex{Nonnumeric, Tuple{Int, UnitRange{Int}}}
return p.nonnumeric[template.idx[1]][template.idx[2]]
elseif template isa UnitRange{Int}
return u[template]
elseif template isa ObservedWrapper
return template(valp)
elseif template isa CopyParamsByTemplate
return template(valp)
else
# MethodError because this is a manual dispatch chain
throw(MethodError(__apply_copy_template, (valp, template)))
end
end

function (cp::CopyParamsByTemplate)(src)
if cp.isroot
reshape(mapreduce(Base.Fix1(__apply_copy_template, src), vcat, cp.template), cp.size)
else
reshape(map(Base.Fix1(__apply_copy_template, src), cp.template), cp.size)
end
end

function CopyParamsByTemplate(srcsys::AbstractSystem, syms::AbstractArray{SymbolicT}; kws...)
template = []
for sym in syms
symidx = parameter_index(srcsys, sym)
if symidx === nothing
symidx = variable_index(srcsys, sym)
if symidx === nothing
if isempty(template)
push!(template, SymbolicT[sym])
continue
end
prev = template[end]
if prev isa Vector{SymbolicT}
push!(prev, sym)
else
push!(template, SymbolicT[sym])
end
continue
end
if isempty(template)
push!(template, symidx:symidx)
continue
end
prev = template[end]
if prev isa UnitRange{Int} && last(prev) + 1 == symidx
template[end] = first(prev):symidx
else
push!(template, symidx:symidx)
end
continue
end
portion = symidx.portion
_bufidx = symidx.idx
if portion isa SciMLStructures.Tunable || portion isa SciMLStructures.Initials

end
bufidx::UnitRange{Int} = if _bufidx isa AbstractVector{Int}
@assert isequal(vec(_bufidx), first(_bufidx):last(_bufidx))
subidx = nothing
first(_bufidx):last(_bufidx)
elseif _bufidx isa Int
subidx = nothing
_bufidx:_bufidx
elseif _bufidx isa NTuple{2, Int}
subidx = _bufidx[1]
_bufidx[2]:_bufidx[2]
else
# Will error due to the typeassert on `bufidx`
nothing
end
if isempty(template)
if subidx === nothing
push!(template, ParameterIndex(symidx.portion, bufidx))
else
push!(template, ParameterIndex(symidx.portion, (subidx, bufidx)))
end
continue
end
prev = template[end]
if prev isa ParameterIndex && prev.portion === symidx.portion && (subidx === nothing && last(prev.idx) + 1 == first(bufidx) || subidx == prev.idx[1] && last(prev.idx[2]) + 1 == first(bufidx))
if subidx === nothing
template[end] = ParameterIndex(prev.portion, first(prev.idx):last(bufidx))
else
template[end] = ParameterIndex(prev.portion, (subidx, first(prev.idx[2]):last(bufidx)))
end
elseif subidx === nothing
push!(template, ParameterIndex(symidx.portion, bufidx))
else
push!(template, ParameterIndex(symidx.portion, (subidx, bufidx)))
end
end

for i in eachindex(template)
if template[i] isa Vector{SymbolicT}
template[i] = concrete_getu(srcsys, template[i]; kws...)
end
end
return CopyParamsByTemplate(true, template, size(syms))
end

function CopyParamsByTemplate(srcsys::AbstractSystem, syms::AbstractArray; kws...)
return CopyParamsByTemplate(false, [CopyParamsByTemplate(srcsys, sym; kws...) for sym in syms], size(syms))
end

struct MTKParametersReconstructor{T, I, D, C, N}
tunables_fn::T
initials_fn::I
discretes_fn::D
consts_fn::C
nonnumerics_fn::N
diffcache_buffer_idx::Int
end

function (recon::MTKParametersReconstructor)(src, dst)
if state_values(src) === nothing && !(applicable(current_time, src) && current_time(src) isa Real && isfinite(current_time(src)))
baseT = Nothing
elseif state_values(src) === nothing
baseT = typeof(current_time(src))
elseif !(applicable(current_time, src) && current_time(src) isa Real && isfinite(current_time(src)))
baseT = eltype(state_values(src))
else
baseT = promote_type(typeof(current_time(src)), eltype(state_values(src)))
end
src_ps = parameter_values(src)
dst_ps = parameter_values(dst)
oldcache = dst_ps.caches
# These type-assertions allow `remake` to infer even though `CopyParamsByTemplate` doesn't
# TODO: Find a solution to type-assert constants/discretes/nonnumerics
# TODO: Maybe `CopyParamsByTemplate` could use a specializing tuple-form when `FullSpecialize`?
if dst_ps.tunable isa SVector{0}
tunablevals = recon.tunables_fn(src)
else
tunable_elT = promote_type(eltype(dst_ps.tunable), eltype(src_ps.tunable))
if baseT !== Nothing
tunable_elT = promote_type(tunable_elT, baseT)
end
baseT = tunable_elT
if ArrayInterface.ismutable(dst_ps.tunable)
tunable_T = Base.promote_op(similar, typeof(dst_ps.tunable), Type{tunable_elT})
tunablevals = recon.tunables_fn(src)::tunable_T
else
tunable_T = StaticArraysCore.similar_type(typeof(dst_ps.tunable), tunable_elT)
tunablevals = recon.tunables_fn(src)::tunable_T
end
end
if dst_ps.initials isa SVector{0}
initialvals = recon.initials_fn(src)
else
initial_elT = promote_type(eltype(dst_ps.initials), eltype(src_ps.initials))
if baseT !== Nothing
initial_elT = promote_type(initial_elT, baseT)
end
if ArrayInterface.ismutable(dst_ps.initials)
initial_T = Base.promote_op(similar, typeof(dst_ps.initials), Type{initial_elT})
initialvals = recon.initials_fn(src)::initial_T
else
initial_T = StaticArraysCore.similar_type(typeof(dst_ps.initials), initial_elT)
initialvals = recon.initials_fn(src)::initial_T
end
end

nonnumerics = recon.nonnumerics_fn(src)::typeof(dst_ps.nonnumeric)
(; diffcache_buffer_idx) = recon
if !iszero(diffcache_buffer_idx)
@set! nonnumerics[diffcache_buffer_idx] = DiffCacheAllocatorAPIWrapper{ForwardDiff.valtype(eltype(initialvals))}.(nonnumerics[diffcache_buffer_idx])
end
return promote_with_nothing(
promote_type_with_nothing(eltype(tunablevals), initialvals),
MTKParameters(
tunablevals, initialvals, recon.discretes_fn(src),
recon.consts_fn(src), nonnumerics, oldcache isa Tuple{} ? () : copy.(oldcache)
)
)
end

"""
$(TYPEDSIGNATURES)

Expand All @@ -732,10 +943,10 @@ takes a value provider of `srcsys` and a value provider of `dstsys` and returns
unwrapped.
- `p_constructor`: The `p_constructor` argument to `process_SciMLProblem`.
"""
function get_mtkparameters_reconstructor(
function MTKParametersReconstructor(
srcsys::AbstractSystem, dstsys::AbstractSystem;
initials = false, unwrap_initials = false, p_constructor = identity,
eval_expression = false, eval_module = @__MODULE__, force_time_independent = false,
force_time_independent = false,
kwargs...
)
_p_constructor = p_constructor
Expand All @@ -754,32 +965,27 @@ function get_mtkparameters_reconstructor(
tunable_getter = if isempty(tunable_syms)
Returns(SVector{0, Float64}())
else
p_constructor ∘ concrete_getu(
srcsys, tunable_syms; eval_expression, eval_module,
force_time_independent, kwargs...
)
p_constructor ∘ CopyParamsByTemplate(srcsys, tunable_syms; kwargs...)
end
initials_getter = if initials && !isempty(syms[2])
initsyms = Vector{Any}(syms[2])
allsyms = Set(variable_symbols(srcsys))
initsyms = syms[2]::Vector{SymbolicT}
allsyms = Set{SymbolicT}(variable_symbols(srcsys))
if unwrap_initials
for i in eachindex(initsyms)
sym = initsyms[i]
innersym = if operation(sym) === getindex
sym, idxs... = arguments(sym)
only(arguments(sym))[idxs...]
arr, isarr = split_indexed_var(sym)
innersym = if isarr
sidx = get_stable_index(sym)
first(arguments(arr))[sidx]
else
only(arguments(sym))
first(arguments(arr))
end
if innersym in allsyms
initsyms[i] = innersym
end
end
end
p_constructor ∘ concrete_getu(
srcsys, initsyms; eval_expression, eval_module,
force_time_independent, kwargs...
)
p_constructor ∘ CopyParamsByTemplate(srcsys, initsyms; kwargs...)
else
Returns(SVector{0, Float64}())
end
Expand All @@ -792,29 +998,25 @@ function get_mtkparameters_reconstructor(
p_constructor(map(x -> x.length, bufsizes))
end
)

# discretes need to be blocked arrays
# the `getu` returns a tuple of arrays corresponding to `p.discretes`
# `Base.Fix1(...)` applies `p_constructor` to each of the arrays in the tuple
# `Base.Fix2(...)` does `BlockedArray.(tuple_of_arrs, blockarrsizes)` returning a
# tuple of `BlockedArray`s
Base.Fix2(Broadcast.BroadcastFunction(BlockedArray), blockarrsizes) ∘
Base.Fix1(broadcast, p_constructor) ∘
Base.Fix1(broadcast, p_constructor) ∘ Tuple ∘
# This `broadcast.(collect, ...)` avoids `ReshapedArray`/`SubArray`s from
# appearing in the result.
concrete_getu(
srcsys, Tuple(broadcast.(collect, syms[3]));
eval_expression, eval_module, force_time_independent, kwargs...
)
CopyParamsByTemplate(srcsys, broadcast.(collect, syms[3]); kwargs...)
end
const_getter = if syms[4] == ()
const_getter = if isempty(syms[4])
Returns(())
else
Base.Fix1(broadcast, p_constructor) ∘ concrete_getu(
srcsys, Tuple(syms[4]);
eval_expression, eval_module, force_time_independent, kwargs...
)
Base.Fix1(broadcast, p_constructor) ∘ Tuple ∘ CopyParamsByTemplate(srcsys, syms[4]; kwargs...)
end
nonnumeric_getter = if syms[5] == ()
diffcache_buffer_idx = 0
nonnumeric_getter = if isempty(syms[5])
Returns(())
else
ic = get_index_cache(dstsys)
Expand All @@ -825,44 +1027,19 @@ function get_mtkparameters_reconstructor(
)

diffcache_params = SU.getmetadata(dstsys, DiffCacheParams, Dict{SymbolicT, Int}())::Dict{SymbolicT, Int}
diffcache_buffer_idx = 0
if !isempty(diffcache_params)
representative = first(keys(diffcache_params))
diffcache_buffer_idx, _ = ic.nonnumeric_idx[representative]
@set! buftypes[diffcache_buffer_idx] = identity
for (i, sym) in enumerate(syms[5][diffcache_buffer_idx])
end
end
# nonnumerics retain the assigned buffer type without narrowing
Base.Fix1(broadcast, _p_constructor) ∘
Base.Fix1(Broadcast.BroadcastFunction(call), buftypes) ∘
concrete_getu(
srcsys, Tuple(syms[5]);
eval_expression, eval_module, force_time_independent, kwargs...
)
end
getters = (
tunable_getter, initials_getter, discs_getter, const_getter, nonnumeric_getter,
)
getter = let getters = getters, diffcache_buffer_idx = diffcache_buffer_idx
function _getter(valp, initprob)
oldcache = parameter_values(initprob).caches
tunablevals = getters[1](valp)
initialvals = getters[2](valp)
nonnumerics = getters[5](valp)
if !iszero(diffcache_buffer_idx)
@set! nonnumerics[diffcache_buffer_idx] = DiffCacheAllocatorAPIWrapper{ForwardDiff.valtype(eltype(initialvals))}.(nonnumerics[diffcache_buffer_idx])
end
return promote_with_nothing(
promote_type_with_nothing(eltype(tunablevals), initialvals),
MTKParameters(
tunablevals, initialvals, getters[3](valp),
getters[4](valp), nonnumerics, oldcache isa Tuple{} ? () :
copy.(oldcache)
)
)
end
Base.Fix1(Broadcast.BroadcastFunction(call), buftypes) ∘ Tuple ∘ CopyParamsByTemplate(srcsys, syms[5]; kwargs...)
end

return getter
return MTKParametersReconstructor(tunable_getter, initials_getter, discs_getter, const_getter, nonnumeric_getter, diffcache_buffer_idx)
end

function call(f, args...)
Expand Down Expand Up @@ -890,7 +1067,7 @@ function ReconstructInitializeprob(
kwargs...
)
if is_split(dstsys)
pgetter = get_mtkparameters_reconstructor(
pgetter = MTKParametersReconstructor(
srcsys, dstsys; p_constructor, eval_expression, eval_module,
force_time_independent = is_steadystateprob, kwargs...
)
Expand Down Expand Up @@ -970,7 +1147,7 @@ function construct_initializeprobpmap(
)
@assert is_initializesystem(initsys)
if is_split(sys)
return let getter = get_mtkparameters_reconstructor(
return let getter = MTKParametersReconstructor(
initsys, sys; initials = true, unwrap_initials = true, p_constructor,
eval_expression, eval_module, kwargs...
)
Expand Down
5 changes: 4 additions & 1 deletion lib/ModelingToolkitBase/test/initializationsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1659,7 +1659,10 @@ end
initialization_eqs = [
Initial(X2) ~ Γ[1] - Initial(X1),
]
@mtkcompile nlsys = System(eqs, [X1, X2], [k1, k2, Γ]; initialization_eqs)
@mtkcomplete nlsys = System(
[0 ~ k1 * (Γ[1] - X1) - k2 * X1], [X1], [k1, k2, Γ];
initialization_eqs, observed = [X2 ~ Γ[1] - X1]
)

@testset "solves initialization" begin
u0 = [X1 => 1.0, X2 => 2.0]
Expand Down
Loading