-
Notifications
You must be signed in to change notification settings - Fork 6
Expand file tree
/
Copy pathModelingToolkitTearing.jl
More file actions
219 lines (190 loc) · 7.16 KB
/
ModelingToolkitTearing.jl
File metadata and controls
219 lines (190 loc) · 7.16 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
module ModelingToolkitTearing
using DocStringExtensions
import StateSelection
using ModelingToolkitBase
import ModelingToolkitBase as MTKBase
using BipartiteGraphs
using Symbolics
using SymbolicUtils
import SymbolicUtils as SU
using Moshi.Match: @match
using Moshi.Data: @data
import Moshi
using OrderedCollections
using Setfield: @set!, @set
using Graphs
import SciMLBase
import SymbolicIndexingInterface as SII
using SymbolicIndexingInterface
import CommonSolve
import SparseArrays
using StateSelection: SelectedState, CLIL
using ModelingToolkitBase: VariableType, VARIABLE, BROWNIAN, complete, observed
using Symbolics: SymbolicT, VartypeT
using SymbolicUtils: BSImpl, unwrap
using SciMLBase: LinearProblem
using SparseArrays: nonzeros
import LinearAlgebra
import UUIDs: UUID, uuid4
# Module-level alias for `SciMLBase.AbstractClock`.
const TimeDomain = SciMLBase.AbstractClock
include("tearingstate.jl")
include("utils.jl")
include("stateselection_interface.jl")
include("trivial_tearing_interface.jl")
"""
$TYPEDEF
Supertype for all reassembling algorithms. A reassembling algorithm takes as input the
`TearingState`, the `TearingResult` and the integer incidence matrix
`mm::StateSelection.CLIL.SparseMatrixCLIL`. The matrix `mm` may be `nothing`. The
algorithm must also accept arbitrary keyword arguments. The following keyword arguments
will always be provided:

- `fully_determined::Bool`: flag indicating whether the system is fully determined.

The output of a reassembling algorithm must be the torn system.

A reassembling algorithm must also implement `with_fully_determined`.
"""
abstract type ReassembleAlgorithm end
include("reassemble.jl")
# Key type under which `MTKBase.unhack_system` looks up and stores its result in the
# mutable cache of a system.
struct UnhackSystemCacheKey end
# The entry stored under `UnhackSystemCacheKey` is reported as needing invalidation
# for every `patch`, regardless of the patch contents.
MTKBase.should_invalidate_mutable_cache_entry(
    ::Type{UnhackSystemCacheKey}, patch::NamedTuple
) = true
# Return a version of `sys` in which inlined linear SCC solves are turned back into
# explicit algebraic equations. The result is memoized in the mutable cache of `sys`
# under `UnhackSystemCacheKey`; a cached `System` is returned directly when present.
#
# Steps performed:
# 1. Drop observed equations whose RHS is a `change_origin` call, or an
#    `array_literal` call whose elements are exactly the elements of the LHS.
# 2. Find every equation (observed, or in `equations(sys)`) whose RHS is an element of
#    an inlined linear solve (see `maybe_extract_inline_linsolve`).
# 3. Replace each such group with the residual equations `0 ~ A * x - b` and add the
#    solved-for variables to the unknowns.
# 4. Substitute each `linsolve[i]` term with its variable in the remaining equations,
#    the observed equations and the schedule's `dummy_sub`.
function MTKBase.unhack_system(sys::System)
    cached_sys = MTKBase.check_mutable_cache(sys, UnhackSystemCacheKey, System, nothing)
    if cached_sys isa System
        return cached_sys
    end
    # Observed are copied by the masking operation
    obseqs = observed(sys)
    eqs = copy(equations(sys))
    obs_mask = trues(length(obseqs))
    for (i, eq) in enumerate(obseqs)
        obs_mask[i] = @match eq.rhs begin
            BSImpl.Term(; f, args) => if f === change_origin
                false
            elseif f === SU.array_literal
                # Compare every element of the LHS against the literal's elements
                # (all arguments after the first). The equation is dropped only if
                # all of them are identical.
                result = true
                for (si, ai) in zip(
                        SU.stable_eachindex(eq.lhs), Iterators.drop(eachindex(args), 1)
                    )
                    result &= isequal(eq.lhs[si], args[ai])
                    result || break
                end
                !result
            else
                true
            end
            _ => true
        end
    end
    obseqs = obseqs[obs_mask]
    # Map from ldiv operation to index of the equations where it is the RHS. A
    # positive index is into `obseqs`, a negative index is into `eqs`. The variable
    # order for the ldiv comes from the LHS of the corresponding equations.
    inline_linear_scc_map = Dict{SymbolicT, Vector{Int}}()
    for (i, eq) in enumerate(obseqs)
        populate_inline_scc_map!(inline_linear_scc_map, eq, i, false)
    end
    for (i, eq) in enumerate(eqs)
        populate_inline_scc_map!(inline_linear_scc_map, eq, i, true)
    end
    # Now, we want to turn all inlined linear SCCs into algebraic equations. If an element
    # of the SCC is a differential variable, we'll introduce the `toterm` as a new algebraic.
    # Otherwise, the observed equation is removed.
    # Reuse the mask for the (already filtered) observed equations.
    resize!(obs_mask, length(obseqs))
    fill!(obs_mask, true)
    additional_eqs = Equation[]
    additional_vars = SymbolicT[]
    additional_subs = Dict{SymbolicT, SymbolicT}()
    # Also need to update schedule
    sched = MTKBase.get_schedule(sys)
    if sched isa MTKBase.Schedule
        sched = copy(sched)
    end
    for (linsolve, idxs) in inline_linear_scc_map
        # The first two arguments of the linear-solve term are the matrix and the RHS.
        A, b = @match linsolve begin
            BSImpl.Term(; args) => args
        end
        A = collect(A)::Matrix{SymbolicT}
        b = collect(b)::Vector{SymbolicT}
        x = Vector{SymbolicT}(undef, length(b))
        for (i, idx) in enumerate(idxs)
            is_obs = idx > 0
            idx = abs(idx)
            if is_obs
                var = obseqs[idx].lhs
                x[i] = var
                obs_mask[idx] = false
            else
                var = MTKBase.default_toterm(eqs[idx].lhs)
                # `x[i]` must be set whether or not a schedule exists: it is read
                # below for `additional_subs` and for the residual `A * x - b`.
                x[i] = var
                if sched isa MTKBase.Schedule
                    sched.dummy_sub[eqs[idx].lhs] = var
                end
                eqs[idx] = eqs[idx].lhs ~ var
            end
            push!(additional_vars, var)
            additional_subs[linsolve[i]] = x[i]
        end
        resid = A * x - b
        for res in resid
            push!(additional_eqs, Symbolics.COMMON_ZERO ~ res)
        end
    end
    @assert length(additional_eqs) == length(additional_vars)
    # If a linear SCC contains both `D(w)` and `w_t`, it'll contain the equation `D(w) ~ w_t`.
    # When unhacking it, `D(w)` will be totermed into `w_t`. Thus, `additional_vars` contains
    # two `w_t` and an equation that is `0 ~ 0`. Find the `0 ~ 0` equations, and remove them
    # along with the duplicate variables.
    # See https://github.com/SciML/ModelingToolkit.jl/issues/4196 for further details.
    additional_eqs_mask = trues(length(additional_eqs))
    for (i, eq) in enumerate(additional_eqs)
        additional_eqs_mask[i] = !SU._iszero(eq.rhs)
    end
    additional_eqs = additional_eqs[additional_eqs_mask]
    additional_vars = additional_vars[additional_eqs_mask]
    # Replace every remaining `linsolve[i]` reference with the variable it solves for.
    subst = SU.Substituter{false}(additional_subs, SU.default_substitute_filter)
    obseqs = obseqs[obs_mask]
    map!(subst, obseqs, obseqs)
    append!(eqs, additional_eqs)
    map!(subst, eqs, eqs)
    if sched isa MTKBase.Schedule
        map!(subst, values(sched.dummy_sub))
    end
    dvs = [unknowns(sys); additional_vars]
    newsys = @set sys.observed = obseqs
    @set! newsys.eqs = eqs
    @set! newsys.unknowns = dvs
    @set! newsys.schedule = sched
    MTKBase.store_to_mutable_cache!(sys, UnhackSystemCacheKey, newsys)
    return newsys
end
"""
    populate_inline_scc_map!(inline_linear_scc_map, eq, eq_i, is_diffeq)

Record equation `eq` in `inline_linear_scc_map` if its RHS is an element of an inlined
linear solve, as recognised by `maybe_extract_inline_linsolve`.

The map is keyed by the linear-solve term. Each value is a vector with one slot per
element of that term, initialised to zero. The slot for the element indexed by `eq.rhs`
is set to `eq_i` when `is_diffeq` is `false` and to `-eq_i` when it is `true`. A slot
that is already non-zero is left untouched, so the first equation recorded for a given
element wins.

When `is_diffeq` is `true`, equations whose LHS is a constant are skipped.

Always returns `nothing`; the function is called for its effect on
`inline_linear_scc_map`.
"""
function populate_inline_scc_map!(
        inline_linear_scc_map::Dict{SymbolicT, Vector{Int}},
        eq::Equation, eq_i::Int,
        is_diffeq::Bool
    )
    is_diffeq && SU.isconst(eq.lhs) && return nothing
    ldiv, idx, is_ldiv = maybe_extract_inline_linsolve(eq.rhs)
    is_ldiv || return nothing
    buffer = get!(() -> zeros(Int, length(ldiv)), inline_linear_scc_map, ldiv)
    # Zero means "not yet assigned"; never overwrite an earlier match.
    iszero(buffer[idx]) || return nothing
    buffer[idx] = is_diffeq ? -eq_i : eq_i
    return nothing
end
"""
    maybe_extract_inline_linsolve(rhs::SymbolicT)

Check whether `rhs` has the form `getindex(op, i)` where `op` is a term whose operation
is `INLINE_LINEAR_SCC_OP`.

Returns the 3-tuple `(op, i::Int, true)` on a match, and
`(Symbolics.COMMON_ZERO, 0, false)` otherwise.
"""
function maybe_extract_inline_linsolve(rhs::SymbolicT)
    no_match = (Symbolics.COMMON_ZERO, 0, false)
    return @match rhs begin
        BSImpl.Term(; f, args) && if f === getindex && length(args) == 2 end => begin
            candidate = args[1]
            raw_idx = args[2]
            @match candidate begin
                BSImpl.Term(; f) && if f === INLINE_LINEAR_SCC_OP end => begin
                    (candidate, unwrap_const(raw_idx)::Int, true)
                end
                _ => no_match
            end
        end
        _ => no_match
    end
end
include("clock_inference/clock.jl")
include("clock_inference/state_machines.jl")
include("clock_inference/clock_inference.jl")
include("clock_inference/interface.jl")
end # module ModelingToolkitTearing