Skip to content

Commit 282f1fb

Browse files
Merge pull request #4511 from SciML/as/better-scc-caching
refactor: use `ArrayMaker` for `CacheWriter` codegen
2 parents 5db718d + 31cd3bd commit 282f1fb

4 files changed

Lines changed: 47 additions & 24 deletions

File tree

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ StateSelection = "1.9.1"
118118
StaticArrays = "1.9.14"
119119
StochasticDiffEq = "6.82.0, 7"
120120
SymbolicIndexingInterface = "0.3.39"
121-
SymbolicUtils = "4.13"
121+
SymbolicUtils = "4.28"
122122
Symbolics = "7"
123123
UnPack = "0.1, 1.0"
124124
julia = "1.9"

lib/ModelingToolkitBase/src/systems/codegen_utils.jl

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,11 @@ generated functions, and `args` are the arguments.
256256
`MTKParameters` object are present. These are collapsed into a single argument and
257257
destructured inside the function. `p_start` must also be provided for non-split systems
258258
since it is used by `wrap_delays`.
259+
- `compress_args`: A list of argument ranges that end before `p_start`.
260+
Each range will be compressed into a single argument to the function. For example,
261+
If there are 5 elements in `args` and `compress_args = [2:3]`, then the generated function
262+
will take 4 arguments, where the second should be an indexable collection of the second
263+
and third elements in `args`.
259264
- `wrap_delays`: Whether to transform delayed unknowns of `sys` present in `expr` into
260265
calls to a history function. The history function is added to the list of arguments
261266
right before parameters, at the index `p_start`.
@@ -290,7 +295,7 @@ All other keyword arguments are forwarded to `build_function`.
290295
"""
291296
Base.@nospecializeinfer function build_function_wrapper(
292297
sys::AbstractSystem, @nospecialize(expr), @nospecialize(args...); p_start = 2,
293-
p_end = is_time_dependent(sys) ? length(args) - 1 : length(args),
298+
p_end = is_time_dependent(sys) ? length(args) - 1 : length(args), compress_args = UnitRange{Int}[],
294299
non_standard_param_layout = false,
295300
wrap_delays = is_dde(sys), histfn = DDE_HISTORY_FUN, histfn_symbolic = histfn, wrap_code = identity,
296301
add_observed = true, obsidxs_to_use = nothing,
@@ -335,11 +340,6 @@ Base.@nospecializeinfer function build_function_wrapper(
335340
required_arrvars = Set{SymbolicT}()
336341
search_buffer = SU.IRStructureSearchBuffer(ir, required_arrvars)
337342
SU.search_variables!(search_buffer, expr; is_atomic = find_arrvars_is_atomic, recurse = !SU.default_is_atomic)
338-
# TODO: This is only required because `CacheWriter` has its body in `extra_assignments`. Rewrite
339-
# that to use `ArrayMaker` and remove this.
340-
for assign in extra_assignments
341-
SU.search_variables!(search_buffer, assign; is_atomic = find_arrvars_is_atomic, recurse = !SU.default_is_atomic)
342-
end
343343

344344
# assignments for reconstructing scalarized array symbolics
345345
if non_standard_param_layout
@@ -382,6 +382,14 @@ Base.@nospecializeinfer function build_function_wrapper(
382382
end
383383
end
384384

385+
sort!(compress_args; by = first)
386+
reverse!(compress_args)
387+
for (i, range) in enumerate(compress_args)
388+
compressed = DestructuredArgs(args[range], Symbol(:__compressed, i))
389+
deleteat!(args, range)
390+
insert!(args, first(range), compressed)
391+
end
392+
385393
# add preface assignments
386394
if has_preface(sys) && (pref = preface(sys)) !== nothing
387395
append!(assignments, pref)

src/ModelingToolkit.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,7 @@ function __init__()
182182
SU.hashcons(unwrap(ODE_GAMMA[2]), true)
183183
SU.hashcons(unwrap(ODE_GAMMA[3]), true)
184184
SU.hashcons(unwrap(ODE_C), true)
185+
SU.hashcons(SCC_EXPLICITFUN_CACHE_OUT, true)
185186
end
186187

187188
end # module

src/problems/sccnonlinearproblem.jl

Lines changed: 31 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -8,31 +8,45 @@ end
88

99
const SCCCacheVarsExprsElT = Dict{TypeT, Vector{SymbolicT}}
1010

11+
const SCC_EXPLICITFUN_CACHE_OUT = unwrap(only(@parameters __outₘₜₖ::Vector{Vector{Any}}))
12+
1113
function CacheWriter(
1214
sys::AbstractSystem, buffer_types::Vector{TypeT},
13-
exprs::SCCCacheVarsExprsElT, solsyms, obseqs::Vector{Equation};
15+
exprs::SCCCacheVarsExprsElT, solsyms;
1416
eval_expression = false, eval_module = @__MODULE__, cse = true, sparse = false
1517
)
1618
ps = parameters(sys; initial_parameters = true)
1719
rps = reorder_parameters(sys, ps)
18-
obs_assigns = [eq.lhs eq.rhs for eq in obseqs]
19-
body = map(eachindex(buffer_types), buffer_types) do i, T
20-
Symbol(:tmp, i) SetArray(true, :(out[$i]), get(exprs, T, []))
21-
end
22-
23-
function argument_name(i::Int)
24-
if i <= length(solsyms)
25-
return :($(generated_argument_name(1))[$i])
20+
cache_writes = SymbolicT[]
21+
for (i, T) in enumerate(buffer_types)
22+
regions = SU.RegionsT()
23+
values = Symbolics.SArgsT()
24+
output = SCC_EXPLICITFUN_CACHE_OUT[i]
25+
cacheexprs = get(exprs, T, SymbolicT[])
26+
isempty(cacheexprs) && continue
27+
N = length(cacheexprs)
28+
allocator = Symbolics.STerm(
29+
Returns, Symbolics.SArgsT((output,));
30+
type = SU.FnType{Tuple, Vector{T}, Any}, shape = SU.ShapeVecT((1:N,))
31+
)
32+
for (j, expr) in enumerate(cacheexprs)
33+
push!(regions, SU.ShapeVecT((j:j,)))
34+
push!(values, Symbolics.SConst([expr]))
2635
end
27-
return generated_argument_name(i - length(solsyms))
36+
maker = SU.ArrayMaker{VartypeT}(regions, values; shape = SU.ShapeVecT((1:N,)))
37+
writer = Code.with_allocator(allocator, maker)
38+
push!(cache_writes, writer)
2839
end
29-
array_assignments = array_variable_assignments(solsyms...; argument_name)
40+
body = Symbolics.STerm(
41+
tuple, cache_writes;
42+
type = Vector{Any}, shape = SU.ShapeVecT((1:length(cache_writes),))
43+
)
44+
3045
fn, _ = build_function_wrapper(
31-
sys, nothing, :out,
32-
DestructuredArgs(DestructuredArgs.(solsyms), generated_argument_name(1)),
33-
rps...; p_start = 3, p_end = length(rps) + 2,
34-
expression = Val{true}, add_observed = false, cse,
35-
extra_assignments = [array_assignments; obs_assigns; body],
46+
sys, body, SCC_EXPLICITFUN_CACHE_OUT, solsyms..., rps...;
47+
p_start = length(solsyms) + 2, p_end = length(rps) + length(solsyms) + 1,
48+
compress_args = [2:(length(solsyms) + 1)],
49+
expression = Val{true}, cse,
3650
iip_config = (true, false)
3751
)
3852
fn = eval_or_rgf(fn; eval_expression, eval_module)
@@ -553,7 +567,7 @@ function SciMLBase.SCCNonlinearProblem{iip}(
553567
push!(
554568
explicitfuns,
555569
CacheWriter(
556-
sys, decomposition.cachetypes, cacheexprs, solsyms, obs[_prevobsidxs];
570+
sys, decomposition.cachetypes, cacheexprs, solsyms;
557571
eval_expression, eval_module, cse
558572
)
559573
)

0 commit comments

Comments
 (0)