Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions lib/ModelingToolkitTearing/src/stateselection_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -326,3 +326,38 @@ function manual_dispatch_is_small_int(@nospecialize(x::Number))::Int
end
end
end

function StateSelection.rm_eqs_vars!(
structure::SystemStructure, eqs_to_rm::Vector{Int}, vars_to_rm::Vector{Int};
eqs_sorted_and_uniqued::Bool = false, vars_sorted_and_uniqued::Bool = false
)
old_to_new_eq, old_to_new_var = StateSelection.default_rm_eqs_vars!(
structure, eqs_to_rm, vars_to_rm; eqs_sorted_and_uniqued, vars_sorted_and_uniqued
)
deleteat!(structure.state_priorities, vars_to_rm)
deleteat!(structure.var_types, vars_to_rm)
return old_to_new_eq, old_to_new_var
end

function StateSelection.rm_eqs_vars!(
state::TearingState, eqs_to_rm::Vector{Int}, vars_to_rm::Vector{Int};
eqs_sorted_and_uniqued::Bool = false, vars_sorted_and_uniqued::Bool = false
)
(; structure, sys) = state
old_to_new_eq, old_to_new_var = StateSelection.rm_eqs_vars!(
structure, eqs_to_rm, vars_to_rm; eqs_sorted_and_uniqued, vars_sorted_and_uniqued
)
deleteat!(state.fullvars, vars_to_rm)
eqs = copy(MTKBase.get_eqs(state.sys))
deleteat!(eqs, eqs_to_rm)
deleteat!(state.original_eqs, eqs_to_rm)
if !isempty(state.eqs_source)
deleteat!(state.eqs_source, eqs_to_rm)
end

@set! sys.eqs = eqs
state.sys = sys

return old_to_new_eq, old_to_new_var
end

17 changes: 2 additions & 15 deletions lib/ModelingToolkitTearing/src/trivial_tearing_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,21 +23,8 @@ function StateSelection.possibly_explicit_equations(state::TearingState)
return Iterators.map(mapfn, Iterators.filter(filterfn, eachindex(state.original_eqs)))
end

function StateSelection.trivial_tearing_postprocess!(ts::TearingState, torn_eqs::OrderedSet{Int}, torn_vars::OrderedSet{Int})
append!(ts.additional_observed, @view ts.original_eqs[collect(torn_eqs)])
sort!(torn_vars)
sort!(torn_eqs)
if ts.structure.var_types !== nothing
deleteat!(ts.structure.var_types, torn_vars)
end
deleteat!(ts.fullvars, torn_vars)
deleteat!(ts.structure.state_priorities, torn_vars)
deleteat!(ts.original_eqs, torn_eqs)
sys = ts.sys
eqs = copy(MTKBase.get_eqs(sys))
deleteat!(eqs, torn_eqs)
@set! sys.eqs = eqs
ts.sys = sys
function StateSelection.trivial_tearing_postprocess!(ts::TearingState, torn_eqs::Vector{Int}, torn_vars::Vector{Int})
append!(ts.additional_observed, @view ts.original_eqs[torn_eqs])
return ts
end

82 changes: 75 additions & 7 deletions src/tearing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ Preemptively identify observed equations in the system and tear them. This happe
any simplification. The equations torn by this process are ones that are already given in
an explicit form in the system and where the LHS is not present in any other equation of
the system except for other such preempitvely torn equations.

Optionally passing in the linear coefficient matrix `mm` will update it according to the
newly torn equations.
"""
function trivial_tearing!(ts::TransformationState, mm::Union{SparseMatrixCLIL, Nothing} = nothing)
if is_only_discrete(ts.structure)
Expand Down Expand Up @@ -95,33 +98,98 @@ function trivial_tearing!(ts::TransformationState, mm::Union{SparseMatrixCLIL, N
# if we didn't add an equation this iteration, we won't add one next iteration
added_equation || break
end
torn_vars_idxs = collect(matched_vars)
torn_eqs_idxs = collect(trivial_idxs)

# For backward compatibility
if hasmethod(rm_eqs_vars!, Tuple{typeof(ts), Vector{Int}, Vector{Int}})
# `deleteat!` requires sorted indices, but we want to maintain relative order to pass
# to `trivial_tearing_postprocess!`
torn_vars_idxs = collect(matched_vars)
torn_eqs_idxs = collect(trivial_idxs)
trivial_tearing_postprocess!(ts, trivial_idxs, matched_vars)
trivial_tearing_postprocess!(ts, torn_eqs_idxs, torn_vars_idxs)
sort!(torn_eqs_idxs)
sort!(torn_vars_idxs)
rm_eqs_vars!(
old_to_new_eq, old_to_new_var = rm_eqs_vars!(
ts, torn_eqs_idxs, torn_vars_idxs; eqs_sorted_and_uniqued = true,
vars_sorted_and_uniqued = true
)
else
# `deleteat!` requires sorted indices, but we want to maintain relative order to pass
# to `trivial_tearing_postprocess!`
torn_vars_idxs = collect(matched_vars)
sort!(torn_vars_idxs)
torn_eqs_idxs = collect(trivial_idxs)
sort!(torn_eqs_idxs)
rm_eqs_vars!(
old_to_new_eq, old_to_new_var = rm_eqs_vars!(
ts.structure, torn_eqs_idxs, torn_vars_idxs; eqs_sorted_and_uniqued = true,
vars_sorted_and_uniqued = true
)
trivial_tearing_postprocess!(ts, trivial_idxs, matched_vars)
end
if mm !== nothing
aliases = Dict{Int, SparseArrays.SparseVector{eltype(mm), Int}}()
linear_eqs = Dict{Int, Int}()
for (i, eq) in enumerate(mm.nzrows)
linear_eqs[eq] = i
end
torn_vars_set = BitSet(torn_vars_idxs)
perm = Int[]
for (var, eq) in zip(torn_vars_idxs, torn_eqs_idxs)
ieq = get(linear_eqs, eq, 0)
iszero(ieq) && continue
# `trivial_tearing!` considers equations already written by the user in a
# solvable form, so we know that the coefficient of `var` must be `-1` and do
# not need to be concerned with fractions.
eq_vars = mm.row_cols[ieq]
eq_coeffs = mm.row_vals[ieq]

I = Int[]
V = Int[]
sizehint!(I, length(eq_vars))
sizehint!(V, length(eq_vars))
can_be_aliased = true
for (v, cf) in zip(eq_vars, eq_coeffs)
if v == var
@assert cf == -1
continue
end
alias = get(aliases, v, nothing)
if alias === nothing
# If this variable is not aliased to a linear combination,
# and is torn, then this equation is no longer a linear equation
# that can be retained in `mm`.
if v in torn_vars_set
can_be_aliased = false
break
end
push!(I, v)
push!(V, cf)
continue
end
_I, _V = SparseArrays.findnz(alias)
append!(I, _I)
append!(V, Iterators.map(Base.Fix1(*, cf), _V))
end
can_be_aliased || continue

resize!(perm, length(I))
sortperm!(perm, I)
I = I[perm]
V = V[perm]
i = 1
for j in Iterators.drop(eachindex(I), 1)
if I[j] == I[i]
V[i] += V[j]
else
i += 1
I[i] = I[j]
V[i] = V[j]
end
end

aliases[var] = SparseArrays.SparseVector(size(mm, 2), I, V)
end
mm = get_new_mm(aliases, old_to_new_eq, old_to_new_var, mm)
return ts, mm
end

return ts
end

Expand Down
117 changes: 117 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -245,3 +245,120 @@ function sorted_incidence_matrix(ts::TransformationState, val = true; only_algeq
end
SparseArrays.sparse(I, J, val, nsrcs(graph), ndsts(graph))
end

function add_row_coeffs!(
row_col_i::Vector{Int}, row_val_i::Vector{T}, old_to_new_var::Vector{Int},
aliases::Dict{Int, Int}, old_var::Int, coeff::T
) where {T <: Integer}
alias = get(aliases, old_var, 0)
iszero(alias) && return
push!(row_col_i, old_to_new_var[alias])
push!(row_val_i, coeff)
return nothing
end

function add_row_coeffs!(
row_col_i::Vector{Int}, row_val_i::Vector{T}, old_to_new_var::Vector{Int},
aliases::Dict{Int, SparseArrays.SparseVector{T, Int}}, old_var::Int, coeff::T
) where {T <: Integer}
alias = get(aliases, old_var, nothing)
if alias isa SparseArrays.SparseVector{T, Int}
I, V = SparseArrays.findnz(alias)
for (i, v) in zip(I, V)
iszero(old_to_new_var[i])
end
append!(row_col_i, Iterators.map(Base.Fix1(getindex, old_to_new_var), I))
append!(row_val_i, Iterators.map(Base.Fix1(*, coeff), V))
end
return nothing
end

"""
$TYPEDSIGNATURES

Construct the new coefficient matrix with:

- Some equations removed, as indicated by `old_to_new_eq` which is obtained from
`get_old_to_new_idxs`.
- Some variables removed and aliased to other variables, as indicated by `old_to_new_var`
and `aliases`. `aliases` can be a `Dict{Int, Int}` indicating variables exactly aliased
to others, or `Dict{Int, SparseVector{eltype(mm)}}` indicating variables aliased to linear
combinations of others. Note that this is not recursive - if one variable
depends on another aliased variable, it will lead to incorrect results.
"""
function get_new_mm(
aliases::Dict{Int}, old_to_new_eq::Vector{Int}, old_to_new_var::Vector{Int},
mm::CLIL.SparseMatrixCLIL
)

new_nparentrows = mm.nparentrows
new_row_cols = eltype(mm.row_cols)[]
new_row_vals = eltype(mm.row_vals)[]
new_nzrows = Int[]
indices = Int[]

for (i, eq) in enumerate(mm.nzrows)
old_to_new_eq[eq] > 0 || continue
new_row_col_i = eltype(eltype(new_row_cols))[]
new_row_val_i = eltype(eltype(new_row_vals))[]
sizehint!(new_row_col_i, length(mm.row_cols[i]))
sizehint!(new_row_val_i, length(mm.row_vals[i]))
still_valid_eq = true
for (var, coeff) in zip(mm.row_cols[i], mm.row_vals[i])
if old_to_new_var[var] > 0
push!(new_row_col_i, old_to_new_var[var])
push!(new_row_val_i, coeff)
continue
end
# This variable is removed, but not aliased to an integer coefficient linear
# combination. As a result, this equation cannot be retained in `mm`.
if !haskey(aliases, var)
still_valid_eq = false
break
end

add_row_coeffs!(new_row_col_i, new_row_val_i, old_to_new_var, aliases, var, coeff)
end

bad_idx = findfirst(iszero, new_row_col_i)
if bad_idx isa Int
throw(BadMMAliasError(bad_idx))
end
still_valid_eq || continue
empty!(indices)
append!(indices, LinearIndices(new_row_col_i))
sortperm!(indices, new_row_col_i)
final_row_cols = empty(new_row_col_i)
final_row_vals = empty(new_row_val_i)
push!(final_row_cols, new_row_col_i[indices[1]])
push!(final_row_vals, new_row_val_i[indices[1]])
for i in Iterators.drop(eachindex(indices), 1)
if new_row_col_i[indices[i]] == new_row_col_i[indices[i - 1]]
final_row_vals[end] += new_row_val_i[indices[i]]
else
push!(final_row_cols, new_row_col_i[indices[i]])
push!(final_row_vals, new_row_val_i[indices[i]])
end
end

push!(new_row_cols, final_row_cols)
push!(new_row_vals, final_row_vals)
push!(new_nzrows, old_to_new_eq[eq])
end

return typeof(mm)(new_nparentrows, count(!iszero, old_to_new_var), new_nzrows, new_row_cols, new_row_vals)
end

struct BadMMAliasError <: Exception
eq::Int
end

function Base.showerror(io::IO, err::BadMMAliasError)
return print(
io, """
When processing equation $(err.eq), the list of aliases resulted in a linear
combination of a removed variable. No variable should be aliased to a removed
variable.
"""
)
end
27 changes: 27 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,31 @@
using StateSelection
import StateSelection as SSel
using SparseArrays
using Test

include("bareiss.jl")

@testset "`get_new_mm`" begin
mm = SSel.CLIL.SparseMatrixCLIL(
[
-1 0 2 0
2 1 2 0
0 0 1 2
]
)
# We're using the first equation to solve for the first variable,
# and removing the last variable (assuming it's torn via some other
# equation)
old_to_new_eq = [0, 2, 3]
old_to_new_var = [0, 1, 2, 0]
aliases = Dict(1 => sparse([0, 0, 2, 0]))
mm2 = SSel.get_new_mm(aliases, old_to_new_eq, old_to_new_var, mm)
# Eq#3 can't be retained, because it depends on variable 4 which isn't an
# integer coefficient linear combination
@test mm2.nzrows == [2]
@test mm2.row_cols == [[1, 2]]
@test mm2.row_vals == [[1, 6]]
@test mm2.ncols == 2
aliases = Dict(1 => sparse([0, 0, 2, 1]))
@test_throws SSel.BadMMAliasError SSel.get_new_mm(aliases, old_to_new_eq, old_to_new_var, mm)
end
Loading