Skip to content

Commit 34cd268

Browse files
Merge pull request #65 from JuliaComputing/as/fix-find-eq-solvables
fix(`find_eq_solvables!`): handle pushing to `coeffs` when `may_be_zero`
2 parents cce01d4 + 7ec7e30 commit 34cd268

4 files changed

Lines changed: 114 additions & 3 deletions

File tree

ext/StateSelectionDeepDiffsExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
module StateSelectionDeepDiffsExt
22

33
using DeepDiffs
4-
using BipartiteGraphs: Label, BipartiteAdjacencyList, unassigned, HighlightInt
4+
using BipartiteGraphs: Label, BipartiteAdjacencyList, unassigned, HighlightInt, Unassigned
55
using StateSelection: SystemStructure,
66
MatchedSystemStructure,
77
SystemStructurePrintMatrix

lib/ModelingToolkitTearing/src/stateselection_interface.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -268,8 +268,8 @@ function StateSelection.find_eq_solvables!(state::TearingState, ieq, to_rm = Int
268268
conservative && continue
269269
elseif conservative && abs(a) > 1
270270
continue
271-
else
272-
coeffs === nothing || push!(coeffs, a)
271+
elseif coeffs !== nothing && (!iszero(a) || !may_be_zero)
272+
push!(coeffs, a)
273273
end
274274

275275
if !iszero(a)

lib/ModelingToolkitTearing/src/utils.jl

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,3 +82,97 @@ macro union_split_var(annotated_var::Expr, block::Expr)
8282

8383
return esc(expr)
8484
end
85+
86+
"""
87+
$TYPEDSIGNATURES
88+
89+
Debugging tool useful for comparing two `SystemStructure`s. Return a copy of `structure` with
90+
variables reordered according to `oldtonewvar` and equations according to `oldtoneweq`.
91+
"""
92+
function permute(structure::SystemStructure, oldtonewvar::Vector{Int}, oldtoneweq::Vector{Int})
93+
graph = BipartiteGraph(nsrcs(structure.graph), ndsts(structure.graph))
94+
for e in 𝑠vertices(structure.graph)
95+
for v in 𝑠neighbors(structure.graph, e)
96+
add_edge!(graph, oldtoneweq[e], oldtonewvar[v])
97+
end
98+
end
99+
solvable_graph = BipartiteGraph(nsrcs(structure.solvable_graph), ndsts(structure.solvable_graph))
100+
for e in 𝑠vertices(structure.solvable_graph)
101+
for v in 𝑠neighbors(structure.solvable_graph, e)
102+
add_edge!(solvable_graph, oldtoneweq[e], oldtonewvar[v])
103+
end
104+
end
105+
var_to_diff = StateSelection.DiffGraph(ndsts(graph))
106+
for i in 𝑑vertices(structure.graph)
107+
if structure.var_to_diff[i] isa Int
108+
var_to_diff[oldtonewvar[i]] = oldtonewvar[structure.var_to_diff[i]]
109+
end
110+
end
111+
eq_to_diff = StateSelection.DiffGraph(nsrcs(graph))
112+
for i in 𝑠vertices(structure.graph)
113+
if structure.eq_to_diff[i] isa Int
114+
eq_to_diff[oldtoneweq[i]] = oldtoneweq[structure.eq_to_diff[i]]
115+
end
116+
end
117+
118+
var_types = similar(structure.var_types)
119+
sps = similar(structure.state_priorities)
120+
for i in 𝑑vertices(structure.graph)
121+
var_types[oldtonewvar[i]] = structure.var_types[i]
122+
sps[oldtonewvar[i]] = structure.state_priorities[i]
123+
end
124+
125+
return SystemStructure(var_to_diff, eq_to_diff, graph, solvable_graph, var_types, sps, structure.only_discrete)
126+
end
127+
128+
"""
129+
$TYPEDSIGNATURES
130+
131+
Debugging tool useful for comparing two `TearingState`s. Return a copy of `ts` with
132+
variables reordered according to `oldtonewvar` and equations according to `oldtoneweq`.
133+
"""
134+
function permute(ts::TearingState, oldtonewvar::Vector{Int}, oldtoneweq::Vector{Int})
135+
structure = permute(ts.structure, oldtonewvar, oldtoneweq)
136+
fullvars = similar(ts.fullvars)
137+
always_present = similar(ts.always_present)
138+
for i in eachindex(fullvars)
139+
fullvars[oldtonewvar[i]] = ts.fullvars[i]
140+
always_present[oldtonewvar[i]] = ts.always_present[i]
141+
end
142+
original_eqs = similar(ts.original_eqs)
143+
eqs_source = similar(ts.eqs_source)
144+
eqs = collect(equations(ts))
145+
for i in eachindex(original_eqs)
146+
original_eqs[oldtoneweq[i]] = ts.original_eqs[i]
147+
eqs_source[oldtoneweq[i]] = ts.eqs_source[i]
148+
eqs[oldtoneweq[i]] = equations(ts)[i]
149+
end
150+
151+
sys = ts.sys
152+
@set! sys.eqs = eqs
153+
@set! sys.unknowns = fullvars
154+
return TearingState(sys, fullvars, structure, Equation[], ts.param_derivative_map, ts.no_deriv_params, original_eqs, ts.additional_observed, always_present, ts.statemachines, eqs_source)
155+
end
156+
157+
"""
158+
$TYPEDSIGNATURES
159+
160+
Debugging tool useful for comparing two `SparseMatrixCLIL`s. Return a copy of `mm` with
161+
variables reordered according to `oldtonewvar` and equations according to `oldtoneweq`.
162+
"""
163+
function permute(mm::StateSelection.SparseMatrixCLIL{S, T}, oldtonewvar::Vector{Int}, oldtoneweq::Vector{Int}) where {S, T}
164+
nzrows = copy(mm.nzrows)
165+
rowcols = copy(mm.row_cols)
166+
rowvals = copy(mm.row_vals)
167+
for i in eachindex(nzrows)
168+
nzrows[i] = oldtoneweq[nzrows[i]]
169+
for j in eachindex(rowcols[i])
170+
rowcols[i][j] = oldtonewvar[rowcols[i][j]]
171+
end
172+
perm = sortperm(rowcols[i])
173+
rowcols[i] = rowcols[i][perm]
174+
rowvals[i] = rowvals[i][perm]
175+
end
176+
177+
return StateSelection.SparseMatrixCLIL{S, T}(mm.nparentrows, mm.ncols, nzrows, rowcols, rowvals)
178+
end

lib/ModelingToolkitTearing/test/runtests.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ using Graphs
88
import StateSelection
99
using ModelingToolkit: t_nounits as t, D_nounits as D
1010
import SymbolicUtils as SU
11+
using Setfield
1112
using ForwardDiff
1213

1314
@testset "`InferredDiscrete` validation" begin
@@ -138,3 +139,19 @@ end
138139
prob.f.f.f_iip(du, prob.u0, prob.p, t)
139140
end
140141
end
142+
143+
@testset "`find_eq_solvables!` with `may_be_zero = true` doesn't push 0 elements to `coeffs`" begin
144+
@variables x y z
145+
@named sys = System([x + y + z ~ 0])
146+
ts = TearingState(sys)
147+
# Artificially remove symbolic incidence
148+
@set! ts.sys.eqs = [0 ~ -x]
149+
ts.structure.solvable_graph = BipartiteGraph(1, 3)
150+
to_rm = Int[]
151+
coeffs = Int[]
152+
StateSelection.find_eq_solvables!(ts, 1, to_rm, coeffs)
153+
@test issetequal(ts.fullvars[to_rm], [y, z])
154+
# Previously, this would have zeros corresponding to `y` and `z`, and yet the incidence for those
155+
# variables is removed from `ts.structure.graph`. This would cause incorrect values in `linear_subsys_adjmat!`
156+
@test coeffs == [-1]
157+
end

0 commit comments

Comments
 (0)