Skip to content

Commit 0202075

Browse files
Merge pull request #56 from JuliaComputing/as/fix-trivial-tearing
feat: implement new `trivial_tearing!` interface for MTK, update `mm`
2 parents b86c0fe + e5eae15 commit 0202075

5 files changed

Lines changed: 256 additions & 22 deletions

File tree

lib/ModelingToolkitTearing/src/stateselection_interface.jl

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -326,3 +326,38 @@ function manual_dispatch_is_small_int(@nospecialize(x::Number))::Int
326326
end
327327
end
328328
end
329+
330+
function StateSelection.rm_eqs_vars!(
331+
structure::SystemStructure, eqs_to_rm::Vector{Int}, vars_to_rm::Vector{Int};
332+
eqs_sorted_and_uniqued::Bool = false, vars_sorted_and_uniqued::Bool = false
333+
)
334+
old_to_new_eq, old_to_new_var = StateSelection.default_rm_eqs_vars!(
335+
structure, eqs_to_rm, vars_to_rm; eqs_sorted_and_uniqued, vars_sorted_and_uniqued
336+
)
337+
deleteat!(structure.state_priorities, vars_to_rm)
338+
deleteat!(structure.var_types, vars_to_rm)
339+
return old_to_new_eq, old_to_new_var
340+
end
341+
342+
function StateSelection.rm_eqs_vars!(
343+
state::TearingState, eqs_to_rm::Vector{Int}, vars_to_rm::Vector{Int};
344+
eqs_sorted_and_uniqued::Bool = false, vars_sorted_and_uniqued::Bool = false
345+
)
346+
(; structure, sys) = state
347+
old_to_new_eq, old_to_new_var = StateSelection.rm_eqs_vars!(
348+
structure, eqs_to_rm, vars_to_rm; eqs_sorted_and_uniqued, vars_sorted_and_uniqued
349+
)
350+
deleteat!(state.fullvars, vars_to_rm)
351+
eqs = copy(MTKBase.get_eqs(state.sys))
352+
deleteat!(eqs, eqs_to_rm)
353+
deleteat!(state.original_eqs, eqs_to_rm)
354+
if !isempty(state.eqs_source)
355+
deleteat!(state.eqs_source, eqs_to_rm)
356+
end
357+
358+
@set! sys.eqs = eqs
359+
state.sys = sys
360+
361+
return old_to_new_eq, old_to_new_var
362+
end
363+

lib/ModelingToolkitTearing/src/trivial_tearing_interface.jl

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -23,21 +23,8 @@ function StateSelection.possibly_explicit_equations(state::TearingState)
2323
return Iterators.map(mapfn, Iterators.filter(filterfn, eachindex(state.original_eqs)))
2424
end
2525

26-
function StateSelection.trivial_tearing_postprocess!(ts::TearingState, torn_eqs::OrderedSet{Int}, torn_vars::OrderedSet{Int})
27-
append!(ts.additional_observed, @view ts.original_eqs[collect(torn_eqs)])
28-
sort!(torn_vars)
29-
sort!(torn_eqs)
30-
if ts.structure.var_types !== nothing
31-
deleteat!(ts.structure.var_types, torn_vars)
32-
end
33-
deleteat!(ts.fullvars, torn_vars)
34-
deleteat!(ts.structure.state_priorities, torn_vars)
35-
deleteat!(ts.original_eqs, torn_eqs)
36-
sys = ts.sys
37-
eqs = copy(MTKBase.get_eqs(sys))
38-
deleteat!(eqs, torn_eqs)
39-
@set! sys.eqs = eqs
40-
ts.sys = sys
26+
function StateSelection.trivial_tearing_postprocess!(ts::TearingState, torn_eqs::Vector{Int}, torn_vars::Vector{Int})
27+
append!(ts.additional_observed, @view ts.original_eqs[torn_eqs])
4128
return ts
4229
end
4330

src/tearing.jl

Lines changed: 75 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@ Preemptively identify observed equations in the system and tear them. This happe
3434
any simplification. The equations torn by this process are ones that are already given in
3535
an explicit form in the system and where the LHS is not present in any other equation of
3636
the system except for other such preempitvely torn equations.
37+
38+
Optionally passing in the linear coefficient matrix `mm` will update it according to the
39+
newly torn equations.
3740
"""
3841
function trivial_tearing!(ts::TransformationState, mm::Union{SparseMatrixCLIL, Nothing} = nothing)
3942
if is_only_discrete(ts.structure)
@@ -95,33 +98,98 @@ function trivial_tearing!(ts::TransformationState, mm::Union{SparseMatrixCLIL, N
9598
# if we didn't add an equation this iteration, we won't add one next iteration
9699
added_equation || break
97100
end
101+
torn_vars_idxs = collect(matched_vars)
102+
torn_eqs_idxs = collect(trivial_idxs)
98103

99104
# For backward compatibility
100105
if hasmethod(rm_eqs_vars!, Tuple{typeof(ts), Vector{Int}, Vector{Int}})
101106
# `deleteat!` requires sorted indices, but we want to maintain relative order to pass
102107
# to `trivial_tearing_postprocess!`
103-
torn_vars_idxs = collect(matched_vars)
104-
torn_eqs_idxs = collect(trivial_idxs)
105-
trivial_tearing_postprocess!(ts, trivial_idxs, matched_vars)
108+
trivial_tearing_postprocess!(ts, torn_eqs_idxs, torn_vars_idxs)
106109
sort!(torn_eqs_idxs)
107110
sort!(torn_vars_idxs)
108-
rm_eqs_vars!(
111+
old_to_new_eq, old_to_new_var = rm_eqs_vars!(
109112
ts, torn_eqs_idxs, torn_vars_idxs; eqs_sorted_and_uniqued = true,
110113
vars_sorted_and_uniqued = true
111114
)
112115
else
113116
# `deleteat!` requires sorted indices, but we want to maintain relative order to pass
114117
# to `trivial_tearing_postprocess!`
115-
torn_vars_idxs = collect(matched_vars)
116118
sort!(torn_vars_idxs)
117-
torn_eqs_idxs = collect(trivial_idxs)
118119
sort!(torn_eqs_idxs)
119-
rm_eqs_vars!(
120+
old_to_new_eq, old_to_new_var = rm_eqs_vars!(
120121
ts.structure, torn_eqs_idxs, torn_vars_idxs; eqs_sorted_and_uniqued = true,
121122
vars_sorted_and_uniqued = true
122123
)
123124
trivial_tearing_postprocess!(ts, trivial_idxs, matched_vars)
124125
end
126+
if mm !== nothing
127+
aliases = Dict{Int, SparseArrays.SparseVector{eltype(mm), Int}}()
128+
linear_eqs = Dict{Int, Int}()
129+
for (i, eq) in enumerate(mm.nzrows)
130+
linear_eqs[eq] = i
131+
end
132+
torn_vars_set = BitSet(torn_vars_idxs)
133+
perm = Int[]
134+
for (var, eq) in zip(torn_vars_idxs, torn_eqs_idxs)
135+
ieq = get(linear_eqs, eq, 0)
136+
iszero(ieq) && continue
137+
# `trivial_tearing!` considers equations already written by the user in a
138+
# solvable form, so we know that the coefficient of `var` must be `-1` and do
139+
# not need to be concerned with fractions.
140+
eq_vars = mm.row_cols[ieq]
141+
eq_coeffs = mm.row_vals[ieq]
142+
143+
I = Int[]
144+
V = Int[]
145+
sizehint!(I, length(eq_vars))
146+
sizehint!(V, length(eq_vars))
147+
can_be_aliased = true
148+
for (v, cf) in zip(eq_vars, eq_coeffs)
149+
if v == var
150+
@assert cf == -1
151+
continue
152+
end
153+
alias = get(aliases, v, nothing)
154+
if alias === nothing
155+
# If this variable is not aliased to a linear combination,
156+
# and is torn, then this equation is no longer a linear equation
157+
# that can be retained in `mm`.
158+
if v in torn_vars_set
159+
can_be_aliased = false
160+
break
161+
end
162+
push!(I, v)
163+
push!(V, cf)
164+
continue
165+
end
166+
_I, _V = SparseArrays.findnz(alias)
167+
append!(I, _I)
168+
append!(V, Iterators.map(Base.Fix1(*, cf), _V))
169+
end
170+
can_be_aliased || continue
171+
172+
resize!(perm, length(I))
173+
sortperm!(perm, I)
174+
I = I[perm]
175+
V = V[perm]
176+
i = 1
177+
for j in Iterators.drop(eachindex(I), 1)
178+
if I[j] == I[i]
179+
V[i] += V[j]
180+
else
181+
i += 1
182+
I[i] = I[j]
183+
V[i] = V[j]
184+
end
185+
end
186+
187+
aliases[var] = SparseArrays.SparseVector(size(mm, 2), I, V)
188+
end
189+
mm = get_new_mm(aliases, old_to_new_eq, old_to_new_var, mm)
190+
return ts, mm
191+
end
192+
125193
return ts
126194
end
127195

src/utils.jl

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,3 +245,120 @@ function sorted_incidence_matrix(ts::TransformationState, val = true; only_algeq
245245
end
246246
SparseArrays.sparse(I, J, val, nsrcs(graph), ndsts(graph))
247247
end
248+
249+
function add_row_coeffs!(
250+
row_col_i::Vector{Int}, row_val_i::Vector{T}, old_to_new_var::Vector{Int},
251+
aliases::Dict{Int, Int}, old_var::Int, coeff::T
252+
) where {T <: Integer}
253+
alias = get(aliases, old_var, 0)
254+
iszero(alias) && return
255+
push!(row_col_i, old_to_new_var[alias])
256+
push!(row_val_i, coeff)
257+
return nothing
258+
end
259+
260+
function add_row_coeffs!(
261+
row_col_i::Vector{Int}, row_val_i::Vector{T}, old_to_new_var::Vector{Int},
262+
aliases::Dict{Int, SparseArrays.SparseVector{T, Int}}, old_var::Int, coeff::T
263+
) where {T <: Integer}
264+
alias = get(aliases, old_var, nothing)
265+
if alias isa SparseArrays.SparseVector{T, Int}
266+
I, V = SparseArrays.findnz(alias)
267+
for (i, v) in zip(I, V)
268+
iszero(old_to_new_var[i])
269+
end
270+
append!(row_col_i, Iterators.map(Base.Fix1(getindex, old_to_new_var), I))
271+
append!(row_val_i, Iterators.map(Base.Fix1(*, coeff), V))
272+
end
273+
return nothing
274+
end
275+
276+
"""
277+
$TYPEDSIGNATURES
278+
279+
Construct the new coefficient matrix with:
280+
281+
- Some equations removed, as indicated by `old_to_new_eq` which is obtained from
282+
`get_old_to_new_idxs`.
283+
- Some variables removed and aliased to other variables, as indicated by `old_to_new_var`
284+
and `aliases`. `aliases` can be a `Dict{Int, Int}` indicating variables exactly aliased
285+
to others, or `Dict{Int, SparseVector{eltype(mm)}}` indicating variables aliased to linear
286+
combinations of others. Note that this is not recursive - if one variable
287+
depends on another aliased variable, it will lead to incorrect results.
288+
"""
289+
function get_new_mm(
290+
aliases::Dict{Int}, old_to_new_eq::Vector{Int}, old_to_new_var::Vector{Int},
291+
mm::CLIL.SparseMatrixCLIL
292+
)
293+
294+
new_nparentrows = mm.nparentrows
295+
new_row_cols = eltype(mm.row_cols)[]
296+
new_row_vals = eltype(mm.row_vals)[]
297+
new_nzrows = Int[]
298+
indices = Int[]
299+
300+
for (i, eq) in enumerate(mm.nzrows)
301+
old_to_new_eq[eq] > 0 || continue
302+
new_row_col_i = eltype(eltype(new_row_cols))[]
303+
new_row_val_i = eltype(eltype(new_row_vals))[]
304+
sizehint!(new_row_col_i, length(mm.row_cols[i]))
305+
sizehint!(new_row_val_i, length(mm.row_vals[i]))
306+
still_valid_eq = true
307+
for (var, coeff) in zip(mm.row_cols[i], mm.row_vals[i])
308+
if old_to_new_var[var] > 0
309+
push!(new_row_col_i, old_to_new_var[var])
310+
push!(new_row_val_i, coeff)
311+
continue
312+
end
313+
# This variable is removed, but not aliased to an integer coefficient linear
314+
# combination. As a result, this equation cannot be retained in `mm`.
315+
if !haskey(aliases, var)
316+
still_valid_eq = false
317+
break
318+
end
319+
320+
add_row_coeffs!(new_row_col_i, new_row_val_i, old_to_new_var, aliases, var, coeff)
321+
end
322+
323+
bad_idx = findfirst(iszero, new_row_col_i)
324+
if bad_idx isa Int
325+
throw(BadMMAliasError(bad_idx))
326+
end
327+
still_valid_eq || continue
328+
empty!(indices)
329+
append!(indices, LinearIndices(new_row_col_i))
330+
sortperm!(indices, new_row_col_i)
331+
final_row_cols = empty(new_row_col_i)
332+
final_row_vals = empty(new_row_val_i)
333+
push!(final_row_cols, new_row_col_i[indices[1]])
334+
push!(final_row_vals, new_row_val_i[indices[1]])
335+
for i in Iterators.drop(eachindex(indices), 1)
336+
if new_row_col_i[indices[i]] == new_row_col_i[indices[i - 1]]
337+
final_row_vals[end] += new_row_val_i[indices[i]]
338+
else
339+
push!(final_row_cols, new_row_col_i[indices[i]])
340+
push!(final_row_vals, new_row_val_i[indices[i]])
341+
end
342+
end
343+
344+
push!(new_row_cols, final_row_cols)
345+
push!(new_row_vals, final_row_vals)
346+
push!(new_nzrows, old_to_new_eq[eq])
347+
end
348+
349+
return typeof(mm)(new_nparentrows, count(!iszero, old_to_new_var), new_nzrows, new_row_cols, new_row_vals)
350+
end
351+
352+
struct BadMMAliasError <: Exception
353+
eq::Int
354+
end
355+
356+
function Base.showerror(io::IO, err::BadMMAliasError)
357+
return print(
358+
io, """
359+
When processing equation $(err.eq), the list of aliases resulted in a linear
360+
combination of a removed variable. No variable should be aliased to a removed
361+
variable.
362+
"""
363+
)
364+
end

test/runtests.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,31 @@
11
using StateSelection
2+
import StateSelection as SSel
3+
using SparseArrays
24
using Test
35

46
include("bareiss.jl")
7+
8+
@testset "`get_new_mm`" begin
9+
mm = SSel.CLIL.SparseMatrixCLIL(
10+
[
11+
-1 0 2 0
12+
2 1 2 0
13+
0 0 1 2
14+
]
15+
)
16+
# We're using the first equation to solve for the first variable,
17+
# and removing the last variable (assuming it's torn via some other
18+
# equation)
19+
old_to_new_eq = [0, 2, 3]
20+
old_to_new_var = [0, 1, 2, 0]
21+
aliases = Dict(1 => sparse([0, 0, 2, 0]))
22+
mm2 = SSel.get_new_mm(aliases, old_to_new_eq, old_to_new_var, mm)
23+
# Eq#3 can't be retained, because it depends on variable 4 which isn't an
24+
# integer coefficient linear combination
25+
@test mm2.nzrows == [2]
26+
@test mm2.row_cols == [[1, 2]]
27+
@test mm2.row_vals == [[1, 6]]
28+
@test mm2.ncols == 2
29+
aliases = Dict(1 => sparse([0, 0, 2, 1]))
30+
@test_throws SSel.BadMMAliasError SSel.get_new_mm(aliases, old_to_new_eq, old_to_new_var, mm)
31+
end

0 commit comments

Comments
 (0)