Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ StateSelection = "1.9.1"
StaticArrays = "1.9.14"
StochasticDiffEq = "6.82.0, 7"
SymbolicIndexingInterface = "0.3.39"
SymbolicUtils = "4.13"
SymbolicUtils = "4.28"
Symbolics = "7"
UnPack = "0.1, 1.0"
julia = "1.9"
Expand Down
20 changes: 14 additions & 6 deletions lib/ModelingToolkitBase/src/systems/codegen_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,11 @@ generated functions, and `args` are the arguments.
`MTKParameters` object are present. These are collapsed into a single argument and
destructured inside the function. `p_start` must also be provided for non-split systems
since it is used by `wrap_delays`.
- `compress_args`: A list of argument ranges that end before `p_start`.
Each range will be compressed into a single argument to the function. For example,
If there are 5 elements in `args` and `compress_args = [2:3]`, then the generated function
will take 4 arguments, where the second should be an indexable collection of the second
and third elements in `args`.
- `wrap_delays`: Whether to transform delayed unknowns of `sys` present in `expr` into
calls to a history function. The history function is added to the list of arguments
right before parameters, at the index `p_start`.
Expand Down Expand Up @@ -290,7 +295,7 @@ All other keyword arguments are forwarded to `build_function`.
"""
Base.@nospecializeinfer function build_function_wrapper(
sys::AbstractSystem, @nospecialize(expr), @nospecialize(args...); p_start = 2,
p_end = is_time_dependent(sys) ? length(args) - 1 : length(args),
p_end = is_time_dependent(sys) ? length(args) - 1 : length(args), compress_args = UnitRange{Int}[],
non_standard_param_layout = false,
wrap_delays = is_dde(sys), histfn = DDE_HISTORY_FUN, histfn_symbolic = histfn, wrap_code = identity,
add_observed = true, obsidxs_to_use = nothing,
Expand Down Expand Up @@ -335,11 +340,6 @@ Base.@nospecializeinfer function build_function_wrapper(
required_arrvars = Set{SymbolicT}()
search_buffer = SU.IRStructureSearchBuffer(ir, required_arrvars)
SU.search_variables!(search_buffer, expr; is_atomic = find_arrvars_is_atomic, recurse = !SU.default_is_atomic)
# TODO: This is only required because `CacheWriter` has its body in `extra_assignments`. Rewrite
# that to use `ArrayMaker` and remove this.
for assign in extra_assignments
SU.search_variables!(search_buffer, assign; is_atomic = find_arrvars_is_atomic, recurse = !SU.default_is_atomic)
end

# assignments for reconstructing scalarized array symbolics
if non_standard_param_layout
Expand Down Expand Up @@ -382,6 +382,14 @@ Base.@nospecializeinfer function build_function_wrapper(
end
end

sort!(compress_args; by = first)
reverse!(compress_args)
for (i, range) in enumerate(compress_args)
compressed = DestructuredArgs(args[range], Symbol(:__compressed, i))
deleteat!(args, range)
insert!(args, first(range), compressed)
end

# add preface assignments
if has_preface(sys) && (pref = preface(sys)) !== nothing
append!(assignments, pref)
Expand Down
1 change: 1 addition & 0 deletions src/ModelingToolkit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ function __init__()
SU.hashcons(unwrap(ODE_GAMMA[2]), true)
SU.hashcons(unwrap(ODE_GAMMA[3]), true)
SU.hashcons(unwrap(ODE_C), true)
SU.hashcons(SCC_EXPLICITFUN_CACHE_OUT, true)
end

end # module
48 changes: 31 additions & 17 deletions src/problems/sccnonlinearproblem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,31 +8,45 @@ end

const SCCCacheVarsExprsElT = Dict{TypeT, Vector{SymbolicT}}

const SCC_EXPLICITFUN_CACHE_OUT = unwrap(only(@parameters __outₘₜₖ::Vector{Vector{Any}}))

function CacheWriter(
sys::AbstractSystem, buffer_types::Vector{TypeT},
exprs::SCCCacheVarsExprsElT, solsyms, obseqs::Vector{Equation};
exprs::SCCCacheVarsExprsElT, solsyms;
eval_expression = false, eval_module = @__MODULE__, cse = true, sparse = false
)
ps = parameters(sys; initial_parameters = true)
rps = reorder_parameters(sys, ps)
obs_assigns = [eq.lhs ← eq.rhs for eq in obseqs]
body = map(eachindex(buffer_types), buffer_types) do i, T
Symbol(:tmp, i) ← SetArray(true, :(out[$i]), get(exprs, T, []))
end

function argument_name(i::Int)
if i <= length(solsyms)
return :($(generated_argument_name(1))[$i])
cache_writes = SymbolicT[]
for (i, T) in enumerate(buffer_types)
regions = SU.RegionsT()
values = Symbolics.SArgsT()
output = SCC_EXPLICITFUN_CACHE_OUT[i]
cacheexprs = get(exprs, T, SymbolicT[])
isempty(cacheexprs) && continue
N = length(cacheexprs)
allocator = Symbolics.STerm(
Returns, Symbolics.SArgsT((output,));
type = SU.FnType{Tuple, Vector{T}, Any}, shape = SU.ShapeVecT((1:N,))
)
for (j, expr) in enumerate(cacheexprs)
push!(regions, SU.ShapeVecT((j:j,)))
push!(values, Symbolics.SConst([expr]))
end
return generated_argument_name(i - length(solsyms))
maker = SU.ArrayMaker{VartypeT}(regions, values; shape = SU.ShapeVecT((1:N,)))
writer = Code.with_allocator(allocator, maker)
push!(cache_writes, writer)
end
array_assignments = array_variable_assignments(solsyms...; argument_name)
body = Symbolics.STerm(
tuple, cache_writes;
type = Vector{Any}, shape = SU.ShapeVecT((1:length(cache_writes),))
)

fn, _ = build_function_wrapper(
sys, nothing, :out,
DestructuredArgs(DestructuredArgs.(solsyms), generated_argument_name(1)),
rps...; p_start = 3, p_end = length(rps) + 2,
expression = Val{true}, add_observed = false, cse,
extra_assignments = [array_assignments; obs_assigns; body],
sys, body, SCC_EXPLICITFUN_CACHE_OUT, solsyms..., rps...;
p_start = length(solsyms) + 2, p_end = length(rps) + length(solsyms) + 1,
compress_args = [2:(length(solsyms) + 1)],
expression = Val{true}, cse,
iip_config = (true, false)
)
fn = eval_or_rgf(fn; eval_expression, eval_module)
Expand Down Expand Up @@ -553,7 +567,7 @@ function SciMLBase.SCCNonlinearProblem{iip}(
push!(
explicitfuns,
CacheWriter(
sys, decomposition.cachetypes, cacheexprs, solsyms, obs[_prevobsidxs];
sys, decomposition.cachetypes, cacheexprs, solsyms;
eval_expression, eval_module, cse
)
)
Expand Down
Loading