-
-
Notifications
You must be signed in to change notification settings - Fork 255
Expand file tree
/
Copy pathsccnonlinearproblem.jl
More file actions
657 lines (597 loc) · 24.6 KB
/
sccnonlinearproblem.jl
File metadata and controls
657 lines (597 loc) · 24.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
"""
    CacheWriter(fn)

Callable wrapper around a generated cache-writing function `fn`. Calling
`writer(p, sols)` forwards to `fn(p.caches, sols, p)`, i.e. the cache buffers stored in
`p` are passed as the first argument.
"""
struct CacheWriter{F}
    fn::F
end

(writer::CacheWriter)(p, sols) = writer.fn(p.caches, sols, p)
# Per-SCC mapping from a symbolic type to the cache variables (or the cached
# expressions) of that type. Used as the element type of
# `SCCDecomposition.scc_cachevars` / `scc_cacheexprs`.
const SCCCacheVarsExprsElT = Dict{TypeT, Vector{SymbolicT}}
# Symbolic placeholder for the first (output) argument of the generated cache-writing
# function; it is indexed once per buffer type in `CacheWriter`. The `CacheWriter`
# functor passes `p.caches` in this position.
const SCC_EXPLICITFUN_CACHE_OUT = unwrap(only(@parameters __outₘₜₖ::Vector{Vector{Any}}))
"""
    CacheWriter(sys, buffer_types, exprs, solsyms; eval_expression, eval_module, cse, sparse)

Generate the `CacheWriter` for one SCC. For each type `T` in `buffer_types` that has
cached expressions in `exprs`, the generated function writes those expressions, in
order, into the output buffer at the position of `T` in `buffer_types`. Types without
cached expressions produce no writer. `solsyms` are the unknowns of the previously
solved SCCs; they are passed as compressed arguments ahead of the parameter arguments.
The `sparse` keyword is accepted but not used.
"""
function CacheWriter(
        sys::AbstractSystem, buffer_types::Vector{TypeT},
        exprs::SCCCacheVarsExprsElT, solsyms;
        eval_expression = false, eval_module = @__MODULE__, cse = true, sparse = false
    )
    ps = parameters(sys; initial_parameters = true)
    rps = reorder_parameters(sys, ps)
    nsols = length(solsyms)
    cache_writes = SymbolicT[]
    for (bufidx, T) in enumerate(buffer_types)
        cached = get(exprs, T, SymbolicT[])
        isempty(cached) && continue
        ncached = length(cached)
        # `Returns(out[bufidx])` as the allocator, so the generated array code targets
        # the existing output buffer for this type.
        allocator = Symbolics.STerm(
            Returns, Symbolics.SArgsT((SCC_EXPLICITFUN_CACHE_OUT[bufidx],));
            type = SU.FnType{Tuple, Vector{T}, Any}, shape = SU.ShapeVecT((1:ncached,))
        )
        regions = SU.RegionsT()
        vals = Symbolics.SArgsT()
        for (slot, expr) in enumerate(cached)
            push!(regions, SU.ShapeVecT((slot:slot,)))
            push!(vals, Symbolics.SConst([expr]))
        end
        maker = SU.ArrayMaker{VartypeT}(regions, vals; shape = SU.ShapeVecT((1:ncached,)))
        push!(cache_writes, Code.with_allocator(allocator, maker))
    end
    body = Symbolics.STerm(
        tuple, cache_writes;
        type = Vector{Any}, shape = SU.ShapeVecT((1:length(cache_writes),))
    )
    # Argument layout: (out, solsyms..., rps...)
    fn, _ = build_function_wrapper(
        sys, body, SCC_EXPLICITFUN_CACHE_OUT, solsyms..., rps...;
        p_start = nsols + 2, p_end = nsols + length(rps) + 1,
        compress_args = [2:(nsols + 1)],
        expression = Val{true}, cse,
        iip_config = (true, false)
    )
    compiled = eval_or_rgf(fn; eval_expression, eval_module)
    wrapped = GeneratedFunctionWrapper{(3, 3, is_split(sys))}(compiled, nothing)
    return CacheWriter(wrapped)
end
"""
$TYPEDSIGNATURES

Subset a system to have only the unknowns at indices `vscc` and the equations (taken
from `full_equations(sys)`) at indices `escc`. The observed equations of the returned
system are emptied. Requires that `sys` is complete and flattened.

# Keyword arguments

- `available_vars`: A set of variables that the subset system should assume are
  precomputed or already available. It is mutated: the unknowns of the subset are added
  to it. This is useful for SCC decomposition.
"""
function subset_system(
        sys::System, vscc::Vector{Int}, escc::Vector{Int};
        available_vars = Set{SymbolicT}()
    )
    check_complete(sys, "subset_system")
    @assert isempty(get_systems(sys)) "`subset_system` requires a flattened system"
    # subset unknowns and equations
    _dvs = unknowns(sys)[vscc]
    _eqs = full_equations(sys)[escc]
    union!(available_vars, _dvs)
    subsys = ConstructionBase.setproperties(
        sys; unknowns = _dvs, eqs = _eqs, observed = Equation[],
        parameter_bindings_graph = get_parameter_bindings_graph(sys), complete = true
    )
    if get_index_cache(sys) !== nothing
        @set! subsys.index_cache = subset_unknowns_observed(
            get_index_cache(sys), sys, _dvs, SymbolicT[],
        )
    end
    # Carry over the parent's cached parameter array assignments, if any
    cached_param_arr_assigns = check_mutable_cache(
        sys, MTKBase.ParameterArrayAssignments, MTKBase.ParameterArrayAssignments, nothing
    )
    if cached_param_arr_assigns isa MTKBase.ParameterArrayAssignments
        store_to_mutable_cache!(
            subsys, MTKBase.ParameterArrayAssignments, cached_param_arr_assigns
        )
    end
    return subsys
end
# Concrete type of a `BlockVector{Int}` built with `undef_blocks`, used as a type alias
# for block-indexed integer vectors. NOTE(review): not referenced in the visible code.
const BlockIdxsT = typeof(BlockVector{Int}(undef_blocks, Int[]))
"""
Decomposition of a system into strongly connected components (SCCs), each represented
by its own subsystem, plus the information needed to generate per-SCC cache buffers.

The per-SCC vectors (`subsystems`, `var_sccs`, `eq_sccs`, `islinear`, `hints`,
`scc_cachevars`, `scc_cacheexprs`) share the same indexing.

- `subsystems`: subsystem of each SCC.
- `var_sccs` / `eq_sccs`: indices into the parent system's unknowns / equations that
  belong to each SCC.
- `islinear`: whether `calculate_A_b` succeeded for the SCC's subsystem.
- `hints`: structural hint for each SCC (`Diagonal` / `Banded` after merging).
- `cachetypes` / `cachesizes`: cache buffer types and the maximum size needed for each.
- `scc_cachevars` / `scc_cacheexprs`: per-SCC cache variables and their expressions,
  grouped by symbolic type.
"""
struct SCCDecomposition
    subsystems::Vector{System}
    var_sccs::Vector{Vector{Int}}
    eq_sccs::Vector{Vector{Int}}
    islinear::BitVector
    hints::Vector{StructuralHint.Type}
    # Cache buffer types and corresponding sizes. Stored as a pair of arrays instead of a
    # dict to maintain a consistent order of buffers across SCCs
    cachetypes::Vector{TypeT}
    cachesizes::Vector{Int}
    # explicitfun! related information for each SCC
    # We need to compute buffer sizes before doing any codegen
    scc_cachevars::Vector{SCCCacheVarsExprsElT}
    scc_cacheexprs::Vector{SCCCacheVarsExprsElT}
end
"""
    SCCDecomposition()

Construct an empty `SCCDecomposition`: no SCCs and no cache buffers.
"""
function SCCDecomposition()
    subsystems = System[]
    var_sccs = Vector{Int}[]
    eq_sccs = Vector{Int}[]
    islinear = BitVector()
    hints = StructuralHint.Type[]
    cachetypes = TypeT[]
    cachesizes = Int[]
    scc_cachevars = SCCCacheVarsExprsElT[]
    scc_cacheexprs = SCCCacheVarsExprsElT[]
    return SCCDecomposition(
        subsystems, var_sccs, eq_sccs, islinear, hints,
        cachetypes, cachesizes, scc_cachevars, scc_cacheexprs
    )
end
"""
    SCCDecomposition(sys, var_sccs, eq_sccs; combine_sccs = true)

Build an `SCCDecomposition` of `sys` from the topologically ordered variable SCCs
`var_sccs` and their equation SCCs `eq_sccs`.

SCCs are visited in order. Before an SCC is added to the pending ("active") set, every
pending SCC it depends on is finalized with `finalize_scc!`, which may merge compatible
linear SCCs. When `combine_sccs = false`, each SCC is finalized right after it is added,
which effectively disables merging.
"""
function SCCDecomposition(
        sys::System, var_sccs::Vector{Vector{Int}}, eq_sccs::Vector{Vector{Int}};
        combine_sccs = true
    )
    pending = SCCDecomposition()
    finalized = SCCDecomposition()
    available_vars = Set{SymbolicT}()
    ts = get_tearing_state(sys)::TearingState
    icg = BipartiteGraphs.InducedCondensationGraph(ts.structure.graph, var_sccs)
    active = Set{Int}()
    for (idx, (escc, vscc)) in enumerate(zip(eq_sccs, var_sccs))
        # Pending SCCs that this one depends on must be finalized first
        deps = Set{Int}(collect(Graphs.inneighbors(icg, idx)))
        intersect!(deps, active)
        while !isempty(deps)
            finalize_scc!(finalized, pending, first(deps), active, deps)
        end
        subsys = subset_system(sys, vscc, escc; available_vars)
        push!(pending.subsystems, subsys)
        push!(pending.var_sccs, vscc)
        push!(pending.eq_sccs, escc)
        push!(pending.hints, StructuralHint.NoHint())
        push!(pending.islinear, calculate_A_b(subsys; throw = false) !== nothing)
        push!(active, idx)
        combine_sccs || finalize_scc!(finalized, pending, first(active), active, active)
    end
    # Flush everything that is still pending
    while !isempty(active)
        finalize_scc!(finalized, pending, first(active), active, active)
    end
    return finalized
end
"""
    copy_scc!(dst, src, tgt)

Append SCC `tgt` of `src` to `dst`: its subsystem, variable/equation indices, structural
hint and linearity flag. Cache-related fields are not copied. Returns `dst`.
"""
function copy_scc!(dst::SCCDecomposition, src::SCCDecomposition, tgt::Int)
    subsys, vscc, escc = src.subsystems[tgt], src.var_sccs[tgt], src.eq_sccs[tgt]
    hint, islin = src.hints[tgt], src.islinear[tgt]
    push!(dst.subsystems, subsys)
    push!(dst.var_sccs, vscc)
    push!(dst.eq_sccs, escc)
    push!(dst.hints, hint)
    push!(dst.islinear, islin)
    return dst
end
"""
    finalize_scc!(final_decomposition, active_decomposition, i, active, nbors;
                  linear_scc_combine_range = 2)

Move SCC `i` of `active_decomposition`, possibly merged with other active SCCs, into
`final_decomposition`. Every SCC moved out is removed from `active` and `nbors`.

- A nonlinear SCC is copied as-is.
- A linear single-equation SCC is merged with every other active linear single-equation
  SCC. If more than one SCC is merged, the result gets a `Diagonal` hint.
- A linear multi-equation SCC is grouped with other active linear multi-equation SCCs
  whose sizes differ by at most `linear_scc_combine_range`. Each group is merged into one
  SCC with a `Banded` hint. Groups are taken greedily (longest window first) until `i`
  itself has been moved.

Merging assumes that the SCCs in `active` are mutually independent (as maintained by the
`SCCDecomposition` constructor).
"""
function finalize_scc!(
        final_decomposition::SCCDecomposition, active_decomposition::SCCDecomposition, i::Int,
        active::Set{Int}, nbors::Set{Int}; linear_scc_combine_range::Int = 2
    )
    if !active_decomposition.islinear[i]
        copy_scc!(final_decomposition, active_decomposition, i)
        delete!(active, i)
        delete!(nbors, i)
        return
    end
    subsys = active_decomposition.subsystems[i]
    neqs = length(equations(subsys))
    if neqs == 1
        to_merge = Int[]
        for j in active
            if active_decomposition.islinear[j] && length(equations(active_decomposition.subsystems[j])) == 1
                push!(to_merge, j)
            end
        end
        # The way we run `subset_system`, if a later SCC (later index in `.subsystems`)
        # needs an observed from a previous SCC, that equation won't be present in `obsidxs`
        # and it will rely on `explicitfun!` for the required value. This means that if we
        # don't merge SCCs in sorted order, the observed equations of the resultant bigger
        # system won't be topologically sorted.
        sort!(to_merge)
        if length(to_merge) > 1
            active_decomposition.hints[to_merge[1]] = StructuralHint.Diagonal()
        end
        # Don't remove old to avoid affecting ordering
        _collapse_into!(active_decomposition, to_merge[1], @view(to_merge[2:end]))
        copy_scc!(final_decomposition, active_decomposition, to_merge[1])
        setdiff!(active, to_merge)
        setdiff!(nbors, to_merge)
        # TODO: Merge this SCC again if it is small
        return
    end
    merge_candidates = Int[]
    for j in active
        active_decomposition.islinear[j] || continue
        neqs_j = length(equations(active_decomposition.subsystems[j]))
        neqs_j == 1 && continue
        push!(merge_candidates, j)
    end
    # Number of equations of the SCC with the given index into `active_decomposition`.
    # Note: it takes an SCC index, NOT a position in `merge_candidates`.
    comparator = length ∘ equations ∘ Base.Fix1(getindex, active_decomposition.subsystems)
    sort!(merge_candidates; by = comparator)
    # mapping from size of SCC to number of times it occurs in the range `low..high`.
    # Effectively a sorted multiset.
    sizes_in_range = DataStructures.SortedDict{Int, Int}()
    while i in active
        empty!(sizes_in_range)
        low = 1
        high = 1
        best_low = 1
        best_high = 1
        # Sliding window over `merge_candidates` (sorted by size): find the longest window
        # whose largest and smallest SCC sizes differ by at most `linear_scc_combine_range`.
        # `low`/`high` are positions in `merge_candidates` and must be mapped to SCC indices.
        while checkbounds(Bool, merge_candidates, high)
            neqs_high = comparator(merge_candidates[high])
            sizes_in_range[neqs_high] = get(sizes_in_range, neqs_high, 0) + 1
            absdiff = first(last(sizes_in_range)) - first(first(sizes_in_range))
            while absdiff > linear_scc_combine_range
                neqs_low = comparator(merge_candidates[low])
                low += 1
                n_low = sizes_in_range[neqs_low] -= 1
                if iszero(n_low)
                    delete!(sizes_in_range, neqs_low)
                end
                absdiff = first(last(sizes_in_range)) - first(first(sizes_in_range))
            end
            if (high - low) > (best_high - best_low)
                best_high = high
                best_low = low
            end
            high += 1
        end
        if best_low == best_high
            copy_scc!(final_decomposition, active_decomposition, merge_candidates[best_low])
            delete!(active, merge_candidates[best_low])
            delete!(nbors, merge_candidates[best_low])
            deleteat!(merge_candidates, best_low)
            continue
        end
        # See the note in the diagonal SCC case above to know why sorting is necessary.
        # The whole window (including the merge target) is sorted so that the target is
        # the SCC with the smallest index and the others are appended in ascending order.
        window = view(merge_candidates, best_low:best_high)
        sort!(window)
        merge_target = first(window)
        to_merge = @view window[2:end]
        # Must be computed before collapsing, which changes the target's size
        largest_collapsed_scc = maximum(comparator, window)
        band_size = largest_collapsed_scc - 1
        active_decomposition.hints[merge_target] = StructuralHint.Banded(band_size, band_size)
        _collapse_into!(active_decomposition, merge_target, to_merge)
        copy_scc!(final_decomposition, active_decomposition, merge_target)
        setdiff!(active, window)
        setdiff!(nbors, window)
        deleteat!(merge_candidates, best_low:best_high)
    end
    return nothing
end
"""
    build_caches!(sys, decomposition)

For each SCC subsystem in `decomposition`:

1. Collect the subexpressions of the equation right-hand sides that do not involve the
   SCC's own unknowns, using `subexpressions_not_involving_vars!`.
2. Substitute them in the equations.
3. Record the (cache variable, expression) pairs, grouped by symbolic type, in
   `decomposition.scc_cachevars` / `decomposition.scc_cacheexprs`.

It also grows `decomposition.cachetypes` / `cachesizes` so that each buffer type is
large enough for every SCC. Mutates `decomposition`. The per-SCC cache entries are
appended, so this is expected to be called once per decomposition.
"""
function build_caches!(sys::System, decomposition::SCCDecomposition)
    banned_vars = Set{SymbolicT}()
    state = Dict{SymbolicT, SymbolicT}()
    ir = get_irstructure(sys)
    for i in eachindex(decomposition.subsystems)
        empty!(banned_vars)
        empty!(state)
        subsys = decomposition.subsystems[i]
        # Subexpressions involving the SCC's unknowns (or their array parents) cannot be
        # precomputed
        union!(banned_vars, unknowns(subsys))
        for u in unknowns(subsys)
            push!(banned_vars, split_indexed_var(u)[1])
        end
        # While we own the system and so mutation _should_ be safe, `IRInfo` exists.
        # It stores the indices corresponding to equations in the `IRStructure`, which
        # would be incorrect if we mutate.
        _eqs = copy(get_eqs(subsys))
        exprs_to_search = SymbolicT[eq.rhs for eq in _eqs]
        subexpressions_not_involving_vars!(ir, exprs_to_search, banned_vars, state)
        subber = SU.IRSubstituter{false}(ir, state; filterer = !SU.default_is_atomic)
        for k in eachindex(_eqs)
            _eqs[k] = _eqs[k].lhs ~ subber(_eqs[k].rhs)
        end
        subsys = decomposition.subsystems[i] = ConstructionBase.setproperties(subsys; eqs = _eqs)
        if decomposition.islinear[i]
            # The equations changed, so any `CachedLinearAb` computed from the old ones
            # is stale; store `nothing` to invalidate it. (Substituting `state` into the
            # cached A/b in place would be an alternative.)
            store_to_mutable_cache!(subsys, CachedLinearAb, nothing)
        end
        # map from symtype to cached variables and their expressions
        cachevars = SCCCacheVarsExprsElT()
        cacheexprs = SCCCacheVarsExprsElT()
        push!(decomposition.scc_cachevars, cachevars)
        push!(decomposition.scc_cacheexprs, cacheexprs)
        for (k, v) in state
            expr = unwrap(k)
            var = unwrap(v)
            T = symtype(expr)
            push!(get!(() -> SymbolicT[], cachevars, T), var)
            push!(get!(() -> SymbolicT[], cacheexprs, T), expr)
        end
        # update the sizes of cache buffers
        for (T, buf) in cachevars
            idx = findfirst(isequal(T), decomposition.cachetypes)
            if idx === nothing
                push!(decomposition.cachetypes, T)
                push!(decomposition.cachesizes, 0)
                idx = lastindex(decomposition.cachetypes)
            end
            decomposition.cachesizes[idx] = max(decomposition.cachesizes[idx], length(buf))
        end
    end
    return nothing
end
"""
$TYPEDSIGNATURES

Make SCC `i` a combination of SCC `i` and SCCs in `js`, where `js` is an iterable of
integers. The equations, unknowns, and variable/equation indices of each SCC in `js` are
appended, in iteration order, to those of SCC `i`.

SCC `i` remains linear only if it and every SCC in `js` are linear. In that case its
cached `A`/`b` is the block-diagonal combination of the individual ones. This assumes
the merged SCCs do not involve each other's unknowns. Otherwise any cached `A`/`b` on
the merged system is invalidated.

Only modifies SCC `i`. Does not change the number of SCCs stored.
"""
function _collapse_into!(decomposition::SCCDecomposition, i::Int, js)
    parent = decomposition.subsystems[i]
    new_eqs = copy(equations(parent))
    new_dvs = copy(unknowns(parent))
    was_linear = decomposition.islinear[i]
    cached_ab::Union{CachedLinearAb, Nothing} = if was_linear
        calculate_A_b(parent)
        check_mutable_cache(parent, CachedLinearAb, CachedLinearAb, nothing)::CachedLinearAb
    else
        nothing
    end
    for j in js
        cur = decomposition.subsystems[j]
        append!(new_eqs, equations(cur))
        append!(new_dvs, unknowns(cur))
        if !decomposition.islinear[j]
            # The merged system is nonlinear; a partially assembled A/b would be missing
            # this SCC's rows/columns, so it must not be kept.
            decomposition.islinear[i] = false
            cached_ab = nothing
        elseif cached_ab isa CachedLinearAb
            calculate_A_b(cur)
            jcache = check_mutable_cache(
                cur, CachedLinearAb, CachedLinearAb, nothing
            )::CachedLinearAb
            cached_ab = CachedLinearAb(
                blockdiag(cached_ab.A, jcache.A), vcat(cached_ab.b, jcache.b)
            )
        end
        append!(decomposition.var_sccs[i], decomposition.var_sccs[j])
        append!(decomposition.eq_sccs[i], decomposition.eq_sccs[j])
    end
    new_parent = decomposition.subsystems[i] = ConstructionBase.setproperties(
        parent; eqs = new_eqs, unknowns = new_dvs,
    )
    if cached_ab isa CachedLinearAb
        store_to_mutable_cache!(new_parent, CachedLinearAb, cached_ab)
    elseif was_linear
        # `parent` had a cached A/b that may have been carried over by `setproperties`;
        # it does not describe the merged system, so invalidate it.
        store_to_mutable_cache!(new_parent, CachedLinearAb, nothing)
    end
    return nothing
end
# Fieldless type used only to host the `SCCNonlinearFunction{iip}(...)` constructor, which
# builds the function for a single SCC. That constructor returns a `LinearFunction` or a
# `NonlinearFunction`, never an instance of this type.
struct SCCNonlinearFunction{iip} end
"""
    SCCNonlinearFunction{iip}(decomposition, i, cachesyms, op; eval_expression, eval_module, cse, kwargs...)

Build the function for SCC `i` of `decomposition`.

- A linear SCC produces a `LinearFunction{iip}` carrying the SCC's structural hint.
- Any other SCC produces a `NonlinearFunction{iip}` from `generate_rhs`, with
  `cachesyms` passed through as the cache buffer symbols.

`op` is currently unused. Extra `kwargs` are forwarded only in the linear case.
"""
function SCCNonlinearFunction{iip}(
        decomposition::SCCDecomposition, i::Int, cachesyms, op; eval_expression = false,
        eval_module = @__MODULE__, cse = true, kwargs...
    ) where {iip}
    subsys = decomposition.subsystems[i]
    # generate linear problem instead
    if decomposition.islinear[i]
        return LinearFunction{iip}(
            subsys; eval_expression, eval_module, cse, cachesyms,
            structural_hint = decomposition.hints[i], kwargs...
        )
    end
    # Forward `cse` so `cse = false` is honored for nonlinear SCCs as well
    f = generate_rhs(
        subsys; expression = Val{false}, wrap_gfw = Val{true}, cachesyms,
        eval_expression, eval_module, cse,
    )
    return NonlinearFunction{iip}(f; sys = subsys)
end
# Without an explicit `iip` parameter, build the in-place (`iip = true`) problem.
SciMLBase.SCCNonlinearProblem(sys::System, args...; kwargs...) =
    SCCNonlinearProblem{true}(sys, args...; kwargs...)
# Build an `SCCNonlinearProblem` from a simplified (`mtkcompile`d, `split = true`) system.
# The unknowns are partitioned into topologically ordered SCCs (from the system's
# schedule when available). Compatible linear SCCs are merged (see `SCCDecomposition`),
# subexpressions not involving an SCC's unknowns are moved into cache buffers, and one
# `LinearProblem` or `NonlinearProblem` is created per SCC. Missing initial guesses are
# filled according to `missing_guess_value`.
# A system with a single SCC becomes either:
#   - a one-element `SCCNonlinearProblem` wrapping a `LinearProblem`, when linear, or
#   - a plain `NonlinearProblem` otherwise.
function SciMLBase.SCCNonlinearProblem{iip}(
        sys::System, op; eval_expression = false,
        eval_module = @__MODULE__, cse = true, u0_constructor = identity,
        missing_guess_value = default_missing_guess_value(), combine_sccs = true, kwargs...
    ) where {iip}
    if !iscomplete(sys) || get_tearing_state(sys) === nothing
        error("A simplified `System` is required. Call `mtkcompile` on the system before creating an `SCCNonlinearProblem`.")
    end
    if !is_split(sys)
        error("The system has been simplified with `split = false`. `SCCNonlinearProblem` is not compatible with this system. Pass `split = true` to `mtkcompile` to use `SCCNonlinearProblem`.")
    end
    ts = get_tearing_state(sys)
    sched = get_schedule(sys)
    if sched === nothing
        @warn "System is simplified but does not have a schedule. This should not happen."
        var_eq_matching, var_sccs = StructuralTransformations.algebraic_variables_scc(ts)
        condensed_graph = MatchedCondensationGraph(
            DiCMOBiGraph{true}(
                complete(ts.structure.graph),
                complete(var_eq_matching)
            ),
            var_sccs
        )
        toporder = topological_sort_by_dfs(condensed_graph)
        var_sccs = var_sccs[toporder]
        eq_sccs = map(Base.Fix1(getindex, var_eq_matching), var_sccs)
    else
        # Copy: merging SCCs appends to these index vectors in place
        var_sccs = map(copy, sched.var_sccs)
        # Equations are already in the order of SCCs
        eq_sccs = length.(var_sccs)
        cumsum!(eq_sccs, eq_sccs)
        eq_sccs = map(enumerate(eq_sccs)) do (i, lasti)
            i == 1 ? collect(1:lasti) : collect((eq_sccs[i - 1] + 1):lasti)
        end
    end
    if length(var_sccs) == 1
        if calculate_A_b(sys; throw = false) !== nothing
            linprob = LinearProblem{iip}(
                sys, op; eval_expression, eval_module,
                u0_constructor, cse, kwargs...
            )
            # Required for filling missing parameter values when this is an initialization
            # problem
            if state_values(linprob) === nothing
                linprob = remake(
                    linprob;
                    u0 = u0_constructor(ones(eltype(linprob.A), size(linprob.A, 2)))
                )
            end
            return SCCNonlinearProblem((linprob,), (Returns(nothing),), parameter_values(linprob), true; sys)
        else
            return NonlinearProblem{iip}(
                sys, op; eval_expression, eval_module, u0_constructor, cse, missing_guess_value, kwargs...
            )
        end
    end
    dvs = unknowns(sys)
    eqs = equations(sys)
    _, u0, p = process_SciMLProblem(
        EmptySciMLFunction{iip}, sys, op; eval_expression, eval_module, symbolic_u0 = true,
        missing_guess_value, kwargs...
    )
    op = calculate_op_from_u0_p(sys, u0, p)
    explicitfuns = []
    nlfuns = []
    decomposition = SCCDecomposition(sys, var_sccs, eq_sccs; combine_sccs)
    # Invalidate the SCC information - `decomposition` is the source of truth now
    var_sccs = nothing
    eq_sccs = nothing
    build_caches!(sys, decomposition)
    for i in eachindex(decomposition.subsystems)
        cachevars = decomposition.scc_cachevars[i]
        cacheexprs = decomposition.scc_cacheexprs[i]
        if isempty(cachevars)
            push!(explicitfuns, Returns(nothing))
        else
            # Unknowns of all previously solved SCCs
            solsyms = view.((dvs,), view(decomposition.var_sccs, 1:(i - 1)))
            push!(
                explicitfuns,
                CacheWriter(
                    sys, decomposition.cachetypes, cacheexprs, solsyms;
                    eval_expression, eval_module, cse
                )
            )
        end
        # One buffer symbol list per cache type, in the global buffer order
        cachebufsyms = Vector{SymbolicT}[]
        for T in decomposition.cachetypes
            push!(cachebufsyms, get(cachevars, T, SymbolicT[]))
        end
        f = SCCNonlinearFunction{iip}(
            decomposition, i, cachebufsyms, op;
            eval_expression, eval_module, cse, kwargs...
        )
        push!(nlfuns, f)
    end
    # Element type of the first concrete (non-`nothing`, non-symbolic) entry of `u0`;
    # missing entries must not decide the element type.
    u0_eltype = Union{}
    for x in u0
        x === nothing && continue
        symbolic_type(x) == NotSymbolic() || continue
        u0_eltype = typeof(x)
        break
    end
    if u0_eltype === Union{}
        u0_eltype = Float64
    end
    u0_eltype = float(u0_eltype)
    if !isempty(decomposition.cachetypes)
        templates = map(decomposition.cachetypes, decomposition.cachesizes) do T, n
            # Real refers to `eltype(u0)`
            if T == Real
                T = u0_eltype
            elseif T <: Array && eltype(T) == Real
                T = Array{u0_eltype, ndims(T)}
            end
            BufferTemplate(T, n)
        end
        p = rebuild_with_caches(p, templates...)
    end
    # yes, `get_p_constructor` since this is only used for `LinearProblem` and
    # will retain the shape of `A`
    u0_constructor = get_p_constructor(u0_constructor, typeof(u0), u0_eltype)
    subprobs = []
    subber = Symbolics.FixpointSubstituter{true}(AtomicArrayDictSubstitutionWrapper(op))
    for (i, (f, vscc)) in enumerate(zip(nlfuns, decomposition.var_sccs))
        _u0 = SymbolicUtils.Code.create_array(
            typeof(u0), eltype(u0), Val(1), Val(length(vscc)), u0[vscc]...
        )
        symbolic_idxs = findall(x -> x === nothing || symbolic_type(x) !== NotSymbolic(), _u0)
        if f isa LinearFunction
            _u0 = isempty(symbolic_idxs) ? _u0 : zeros(u0_eltype, length(_u0))
            _u0 = u0_constructor(u0_eltype.(_u0))
            # Make the cached expressions of this SCC available in `op` so that `A`/`b`
            # can be evaluated
            cachevars = decomposition.scc_cachevars[i]
            cacheexprs = decomposition.scc_cacheexprs[i]
            for T in keys(cachevars)
                for (var, expr) in zip(cachevars[T], cacheexprs[T])
                    isequal(var, expr) && continue
                    has_possibly_indexed_key(op, var) && continue
                    write_possibly_indexed_array!(op, var, expr, COMMON_NOTHING)
                end
            end
            symbolic_interface = f.interface
            A, b = get_A_b_from_LinearFunction(
                sys, f, subber; eval_expression, eval_module, u0_constructor, u0_eltype
            )
            for (j, val) in zip(vscc, _u0)
                write_possibly_indexed_array!(op, dvs[j], Symbolics.SConst(val), COMMON_NOTHING)
            end
            prob = LinearProblem{iip}(A, b, p; f = symbolic_interface, u0 = _u0)
        else
            if !isempty(symbolic_idxs)
                Moshi.Match.@match missing_guess_value begin
                    MissingGuessValue.Constant(val) => begin
                        _u0[symbolic_idxs] .= val
                        _u0 = unwrap_const.(_u0)
                        cval = Symbolics.SConst(val)
                        for j in symbolic_idxs
                            write_possibly_indexed_array!(op, dvs[vscc[j]], cval, COMMON_NOTHING)
                        end
                    end
                    MissingGuessValue.Random(rng) => begin
                        newval = rand(rng, length(symbolic_idxs))
                        _u0[symbolic_idxs] .= newval
                        for (idx, j) in enumerate(symbolic_idxs)
                            write_possibly_indexed_array!(
                                op, dvs[vscc[j]], Symbolics.SConst(newval[idx]), COMMON_NOTHING
                            )
                        end
                    end
                    MissingGuessValue.Error() => throw(MissingGuessError(dvs[vscc], _u0))
                end
            end
            _u0 = u0_constructor(u0_eltype.(_u0))
            prob = NonlinearProblem(f, _u0, p)
        end
        push!(subprobs, prob)
    end
    new_dvs = dvs[reduce(vcat, decomposition.var_sccs)]
    new_eqs = eqs[reduce(vcat, decomposition.eq_sccs)]
    sys = ConstructionBase.setproperties(
        sys; unknowns = new_dvs, eqs = new_eqs, index_cache = subset_unknowns_observed(
            get_index_cache(sys), sys, new_dvs, SymbolicT[]
        )
    )
    # Small numbers of subproblems are stored as tuples, larger ones as vectors
    if length(subprobs) <= 5
        return SCCNonlinearProblem(Tuple(subprobs), Tuple(explicitfuns), p, true; sys)
    else
        return SCCNonlinearProblem(subprobs, explicitfuns, p, true; sys)
    end
end
"""
    calculate_op_from_u0_p(sys, u0, p)

Build an operating-point map for `sys`, in this order:

1. Each unknown's value from `u0`; entries that are `nothing` are skipped.
2. Each parameter's value from `p`, following `reorder_parameters(sys)`.
3. The observed equations of `unhack_system(sys)`, as symbolic entries.
4. `bindings(sys)`, merged last so its entries overwrite earlier ones with the same key.
"""
function calculate_op_from_u0_p(sys::System, u0::Union{Nothing, AbstractVector}, p::MTKParameters)
    op = SymmapT()
    if !isnothing(u0)
        for (var, val) in zip(unknowns(sys), u0)
            isnothing(val) || write_possibly_indexed_array!(
                op, var, Symbolics.SConst(val), COMMON_NOTHING
            )
        end
    end
    param_groups = reorder_parameters(sys)
    @assert length(param_groups) == length(p)
    for k in eachindex(param_groups)
        for (var, val) in zip(param_groups[k], p[k])
            write_possibly_indexed_array!(op, var, Symbolics.SConst(val), COMMON_NOTHING)
        end
    end
    foreach(observed(unhack_system(sys))) do eq
        write_possibly_indexed_array!(op, eq.lhs, eq.rhs, COMMON_NOTHING)
    end
    merge!(op, bindings(sys))
    return op
end