diff --git a/Project.toml b/Project.toml index 6be683ce47..e295e20a65 100644 --- a/Project.toml +++ b/Project.toml @@ -118,7 +118,7 @@ StateSelection = "1.9.1" StaticArrays = "1.9.14" StochasticDiffEq = "6.82.0, 7" SymbolicIndexingInterface = "0.3.39" -SymbolicUtils = "4.13" +SymbolicUtils = "4.28" Symbolics = "7" UnPack = "0.1, 1.0" julia = "1.9" diff --git a/lib/ModelingToolkitBase/src/systems/codegen_utils.jl b/lib/ModelingToolkitBase/src/systems/codegen_utils.jl index 9513b939cd..95ea45da94 100644 --- a/lib/ModelingToolkitBase/src/systems/codegen_utils.jl +++ b/lib/ModelingToolkitBase/src/systems/codegen_utils.jl @@ -256,6 +256,11 @@ generated functions, and `args` are the arguments. `MTKParameters` object are present. These are collapsed into a single argument and destructured inside the function. `p_start` must also be provided for non-split systems since it is used by `wrap_delays`. +- `compress_args`: A list of argument ranges that end before `p_start`. + Each range will be compressed into a single argument to the function. For example, + If there are 5 elements in `args` and `compress_args = [2:3]`, then the generated function + will take 4 arguments, where the second should be an indexable collection of the second + and third elements in `args`. - `wrap_delays`: Whether to transform delayed unknowns of `sys` present in `expr` into calls to a history function. The history function is added to the list of arguments right before parameters, at the index `p_start`. @@ -290,7 +295,7 @@ All other keyword arguments are forwarded to `build_function`. """ Base.@nospecializeinfer function build_function_wrapper( sys::AbstractSystem, @nospecialize(expr), @nospecialize(args...); p_start = 2, - p_end = is_time_dependent(sys) ? length(args) - 1 : length(args), + p_end = is_time_dependent(sys) ? length(args) - 1 : length(args), compress_args = UnitRange{Int}[], non_standard_param_layout = false, wrap_delays = is_dde(sys), histfn = DDE_HISTORY_FUN, histfn_symbolic = histfn, wrap_code = identity, add_observed = true, obsidxs_to_use = nothing, @@ -335,11 +340,6 @@ Base.@nospecializeinfer function build_function_wrapper( required_arrvars = Set{SymbolicT}() search_buffer = SU.IRStructureSearchBuffer(ir, required_arrvars) SU.search_variables!(search_buffer, expr; is_atomic = find_arrvars_is_atomic, recurse = !SU.default_is_atomic) - # TODO: This is only required because `CacheWriter` has its body in `extra_assignments`. Rewrite - # that to use `ArrayMaker` and remove this. - for assign in extra_assignments - SU.search_variables!(search_buffer, assign; is_atomic = find_arrvars_is_atomic, recurse = !SU.default_is_atomic) - end # assignments for reconstructing scalarized array symbolics if non_standard_param_layout @@ -382,6 +382,14 @@ Base.@nospecializeinfer function build_function_wrapper( end end + sort!(compress_args; by = first) + reverse!(compress_args) + for (i, range) in enumerate(compress_args) + compressed = DestructuredArgs(args[range], Symbol(:__compressed, i)) + deleteat!(args, range) + insert!(args, first(range), compressed) + end + # add preface assignments if has_preface(sys) && (pref = preface(sys)) !== nothing append!(assignments, pref) diff --git a/src/ModelingToolkit.jl b/src/ModelingToolkit.jl index 19bc700463..4812820c04 100644 --- a/src/ModelingToolkit.jl +++ b/src/ModelingToolkit.jl @@ -182,6 +182,7 @@ function __init__() SU.hashcons(unwrap(ODE_GAMMA[2]), true) SU.hashcons(unwrap(ODE_GAMMA[3]), true) SU.hashcons(unwrap(ODE_C), true) + SU.hashcons(SCC_EXPLICITFUN_CACHE_OUT, true) end end # module diff --git a/src/problems/sccnonlinearproblem.jl b/src/problems/sccnonlinearproblem.jl index 0cd48749e1..d1be37ecf3 100644 --- a/src/problems/sccnonlinearproblem.jl +++ b/src/problems/sccnonlinearproblem.jl @@ -8,31 +8,45 @@ end const SCCCacheVarsExprsElT = Dict{TypeT, Vector{SymbolicT}} +const SCC_EXPLICITFUN_CACHE_OUT = unwrap(only(@parameters __outₘₜₖ::Vector{Vector{Any}})) + function CacheWriter( sys::AbstractSystem, buffer_types::Vector{TypeT}, - exprs::SCCCacheVarsExprsElT, solsyms, obseqs::Vector{Equation}; + exprs::SCCCacheVarsExprsElT, solsyms; eval_expression = false, eval_module = @__MODULE__, cse = true, sparse = false ) ps = parameters(sys; initial_parameters = true) rps = reorder_parameters(sys, ps) - obs_assigns = [eq.lhs ← eq.rhs for eq in obseqs] - body = map(eachindex(buffer_types), buffer_types) do i, T - Symbol(:tmp, i) ← SetArray(true, :(out[$i]), get(exprs, T, [])) - end - - function argument_name(i::Int) - if i <= length(solsyms) - return :($(generated_argument_name(1))[$i]) + cache_writes = SymbolicT[] + for (i, T) in enumerate(buffer_types) + regions = SU.RegionsT() + values = Symbolics.SArgsT() + output = SCC_EXPLICITFUN_CACHE_OUT[i] + cacheexprs = get(exprs, T, SymbolicT[]) + isempty(cacheexprs) && continue + N = length(cacheexprs) + allocator = Symbolics.STerm( + Returns, Symbolics.SArgsT((output,)); + type = SU.FnType{Tuple, Vector{T}, Any}, shape = SU.ShapeVecT((1:N,)) + ) + for (j, expr) in enumerate(cacheexprs) + push!(regions, SU.ShapeVecT((j:j,))) + push!(values, Symbolics.SConst([expr])) end - return generated_argument_name(i - length(solsyms)) + maker = SU.ArrayMaker{VartypeT}(regions, values; shape = SU.ShapeVecT((1:N,))) + writer = Code.with_allocator(allocator, maker) + push!(cache_writes, writer) end - array_assignments = array_variable_assignments(solsyms...; argument_name) + body = Symbolics.STerm( + tuple, cache_writes; + type = Vector{Any}, shape = SU.ShapeVecT((1:length(cache_writes),)) + ) + fn, _ = build_function_wrapper( - sys, nothing, :out, - DestructuredArgs(DestructuredArgs.(solsyms), generated_argument_name(1)), - rps...; p_start = 3, p_end = length(rps) + 2, - expression = Val{true}, add_observed = false, cse, - extra_assignments = [array_assignments; obs_assigns; body], + sys, body, SCC_EXPLICITFUN_CACHE_OUT, solsyms..., rps...; + p_start = length(solsyms) + 2, p_end = length(rps) + length(solsyms) + 1, + compress_args = [2:(length(solsyms) + 1)], + expression = Val{true}, cse, iip_config = (true, false) ) fn = eval_or_rgf(fn; eval_expression, eval_module) @@ -553,7 +567,7 @@ function SciMLBase.SCCNonlinearProblem{iip}( push!( explicitfuns, CacheWriter( - sys, decomposition.cachetypes, cacheexprs, solsyms, obs[_prevobsidxs]; + sys, decomposition.cachetypes, cacheexprs, solsyms; eval_expression, eval_module, cse ) )