|
8 | 8 |
|
9 | 9 | const SCCCacheVarsExprsElT = Dict{TypeT, Vector{SymbolicT}} |
10 | 10 |
|
| 11 | +const SCC_EXPLICITFUN_CACHE_OUT = unwrap(only(@parameters __outₘₜₖ::Vector{Vector{Any}})) |
| 12 | + |
11 | 13 | function CacheWriter( |
12 | 14 | sys::AbstractSystem, buffer_types::Vector{TypeT}, |
13 | | - exprs::SCCCacheVarsExprsElT, solsyms, obseqs::Vector{Equation}; |
| 15 | + exprs::SCCCacheVarsExprsElT, solsyms; |
14 | 16 | eval_expression = false, eval_module = @__MODULE__, cse = true, sparse = false |
15 | 17 | ) |
16 | 18 | ps = parameters(sys; initial_parameters = true) |
17 | 19 | rps = reorder_parameters(sys, ps) |
18 | | - obs_assigns = [eq.lhs ← eq.rhs for eq in obseqs] |
19 | | - body = map(eachindex(buffer_types), buffer_types) do i, T |
20 | | - Symbol(:tmp, i) ← SetArray(true, :(out[$i]), get(exprs, T, [])) |
21 | | - end |
22 | | - |
23 | | - function argument_name(i::Int) |
24 | | - if i <= length(solsyms) |
25 | | - return :($(generated_argument_name(1))[$i]) |
| 20 | + cache_writes = SymbolicT[] |
| 21 | + for (i, T) in enumerate(buffer_types) |
| 22 | + regions = SU.RegionsT() |
| 23 | + values = Symbolics.SArgsT() |
| 24 | + output = SCC_EXPLICITFUN_CACHE_OUT[i] |
| 25 | + cacheexprs = get(exprs, T, SymbolicT[]) |
| 26 | + isempty(cacheexprs) && continue |
| 27 | + N = length(cacheexprs) |
| 28 | + allocator = Symbolics.STerm( |
| 29 | + Returns, Symbolics.SArgsT((output,)); |
| 30 | + type = SU.FnType{Tuple, Vector{T}, Any}, shape = SU.ShapeVecT((1:N,)) |
| 31 | + ) |
| 32 | + for (j, expr) in enumerate(cacheexprs) |
| 33 | + push!(regions, SU.ShapeVecT((j:j,))) |
| 34 | + push!(values, Symbolics.SConst([expr])) |
26 | 35 | end |
27 | | - return generated_argument_name(i - length(solsyms)) |
| 36 | + maker = SU.ArrayMaker{VartypeT}(regions, values; shape = SU.ShapeVecT((1:N,))) |
| 37 | + writer = Code.with_allocator(allocator, maker) |
| 38 | + push!(cache_writes, writer) |
28 | 39 | end |
29 | | - array_assignments = array_variable_assignments(solsyms...; argument_name) |
| 40 | + body = Symbolics.STerm( |
| 41 | + tuple, cache_writes; |
| 42 | + type = Vector{Any}, shape = SU.ShapeVecT((1:length(cache_writes),)) |
| 43 | + ) |
| 44 | + |
30 | 45 | fn, _ = build_function_wrapper( |
31 | | - sys, nothing, :out, |
32 | | - DestructuredArgs(DestructuredArgs.(solsyms), generated_argument_name(1)), |
33 | | - rps...; p_start = 3, p_end = length(rps) + 2, |
34 | | - expression = Val{true}, add_observed = false, cse, |
35 | | - extra_assignments = [array_assignments; obs_assigns; body], |
| 46 | + sys, body, SCC_EXPLICITFUN_CACHE_OUT, solsyms..., rps...; |
| 47 | + p_start = length(solsyms) + 2, p_end = length(rps) + length(solsyms) + 1, |
| 48 | + compress_args = [2:(length(solsyms) + 1)], |
| 49 | + expression = Val{true}, cse, |
36 | 50 | iip_config = (true, false) |
37 | 51 | ) |
38 | 52 | fn = eval_or_rgf(fn; eval_expression, eval_module) |
@@ -553,7 +567,7 @@ function SciMLBase.SCCNonlinearProblem{iip}( |
553 | 567 | push!( |
554 | 568 | explicitfuns, |
555 | 569 | CacheWriter( |
556 | | - sys, decomposition.cachetypes, cacheexprs, solsyms, obs[_prevobsidxs]; |
| 570 | + sys, decomposition.cachetypes, cacheexprs, solsyms; |
557 | 571 | eval_expression, eval_module, cse |
558 | 572 | ) |
559 | 573 | ) |
|
0 commit comments