diff --git a/lib/ModelingToolkitTearing/src/stateselection_interface.jl b/lib/ModelingToolkitTearing/src/stateselection_interface.jl index 19abe4c..b079e0f 100644 --- a/lib/ModelingToolkitTearing/src/stateselection_interface.jl +++ b/lib/ModelingToolkitTearing/src/stateselection_interface.jl @@ -326,3 +326,38 @@ function manual_dispatch_is_small_int(@nospecialize(x::Number))::Int end end end + +function StateSelection.rm_eqs_vars!( + structure::SystemStructure, eqs_to_rm::Vector{Int}, vars_to_rm::Vector{Int}; + eqs_sorted_and_uniqued::Bool = false, vars_sorted_and_uniqued::Bool = false + ) + old_to_new_eq, old_to_new_var = StateSelection.default_rm_eqs_vars!( + structure, eqs_to_rm, vars_to_rm; eqs_sorted_and_uniqued, vars_sorted_and_uniqued + ) + deleteat!(structure.state_priorities, vars_to_rm) + deleteat!(structure.var_types, vars_to_rm) + return old_to_new_eq, old_to_new_var +end + +function StateSelection.rm_eqs_vars!( + state::TearingState, eqs_to_rm::Vector{Int}, vars_to_rm::Vector{Int}; + eqs_sorted_and_uniqued::Bool = false, vars_sorted_and_uniqued::Bool = false + ) + (; structure, sys) = state + old_to_new_eq, old_to_new_var = StateSelection.rm_eqs_vars!( + structure, eqs_to_rm, vars_to_rm; eqs_sorted_and_uniqued, vars_sorted_and_uniqued + ) + deleteat!(state.fullvars, vars_to_rm) + eqs = copy(MTKBase.get_eqs(state.sys)) + deleteat!(eqs, eqs_to_rm) + deleteat!(state.original_eqs, eqs_to_rm) + if !isempty(state.eqs_source) + deleteat!(state.eqs_source, eqs_to_rm) + end + + @set! sys.eqs = eqs + state.sys = sys + + return old_to_new_eq, old_to_new_var +end + diff --git a/lib/ModelingToolkitTearing/src/trivial_tearing_interface.jl b/lib/ModelingToolkitTearing/src/trivial_tearing_interface.jl index 5be66ad..f715554 100644 --- a/lib/ModelingToolkitTearing/src/trivial_tearing_interface.jl +++ b/lib/ModelingToolkitTearing/src/trivial_tearing_interface.jl @@ -23,21 +23,8 @@ function StateSelection.possibly_explicit_equations(state::TearingState) return Iterators.map(mapfn, Iterators.filter(filterfn, eachindex(state.original_eqs))) end -function StateSelection.trivial_tearing_postprocess!(ts::TearingState, torn_eqs::OrderedSet{Int}, torn_vars::OrderedSet{Int}) - append!(ts.additional_observed, @view ts.original_eqs[collect(torn_eqs)]) - sort!(torn_vars) - sort!(torn_eqs) - if ts.structure.var_types !== nothing - deleteat!(ts.structure.var_types, torn_vars) - end - deleteat!(ts.fullvars, torn_vars) - deleteat!(ts.structure.state_priorities, torn_vars) - deleteat!(ts.original_eqs, torn_eqs) - sys = ts.sys - eqs = copy(MTKBase.get_eqs(sys)) - deleteat!(eqs, torn_eqs) - @set! sys.eqs = eqs - ts.sys = sys +function StateSelection.trivial_tearing_postprocess!(ts::TearingState, torn_eqs::Vector{Int}, torn_vars::Vector{Int}) + append!(ts.additional_observed, @view ts.original_eqs[torn_eqs]) return ts end diff --git a/src/tearing.jl b/src/tearing.jl index 3b80da9..59b2b62 100644 --- a/src/tearing.jl +++ b/src/tearing.jl @@ -34,6 +34,9 @@ Preemptively identify observed equations in the system and tear them. This happe any simplification. The equations torn by this process are ones that are already given in an explicit form in the system and where the LHS is not present in any other equation of the system except for other such preempitvely torn equations. + +Optionally passing in the linear coefficient matrix `mm` will update it according to the +newly torn equations. """ function trivial_tearing!(ts::TransformationState, mm::Union{SparseMatrixCLIL, Nothing} = nothing) if is_only_discrete(ts.structure) @@ -95,33 +98,98 @@ function trivial_tearing!(ts::TransformationState, mm::Union{SparseMatrixCLIL, N # if we didn't add an equation this iteration, we won't add one next iteration added_equation || break end + torn_vars_idxs = collect(matched_vars) + torn_eqs_idxs = collect(trivial_idxs) # For backward compatibility if hasmethod(rm_eqs_vars!, Tuple{typeof(ts), Vector{Int}, Vector{Int}}) # `deleteat!` requires sorted indices, but we want to maintain relative order to pass # to `trivial_tearing_postprocess!` - torn_vars_idxs = collect(matched_vars) - torn_eqs_idxs = collect(trivial_idxs) - trivial_tearing_postprocess!(ts, trivial_idxs, matched_vars) + trivial_tearing_postprocess!(ts, torn_eqs_idxs, torn_vars_idxs) sort!(torn_eqs_idxs) sort!(torn_vars_idxs) - rm_eqs_vars!( + old_to_new_eq, old_to_new_var = rm_eqs_vars!( ts, torn_eqs_idxs, torn_vars_idxs; eqs_sorted_and_uniqued = true, vars_sorted_and_uniqued = true ) else # `deleteat!` requires sorted indices, but we want to maintain relative order to pass # to `trivial_tearing_postprocess!` - torn_vars_idxs = collect(matched_vars) sort!(torn_vars_idxs) - torn_eqs_idxs = collect(trivial_idxs) sort!(torn_eqs_idxs) - rm_eqs_vars!( + old_to_new_eq, old_to_new_var = rm_eqs_vars!( ts.structure, torn_eqs_idxs, torn_vars_idxs; eqs_sorted_and_uniqued = true, vars_sorted_and_uniqued = true ) trivial_tearing_postprocess!(ts, trivial_idxs, matched_vars) end + if mm !== nothing + aliases = Dict{Int, SparseArrays.SparseVector{eltype(mm), Int}}() + linear_eqs = Dict{Int, Int}() + for (i, eq) in enumerate(mm.nzrows) + linear_eqs[eq] = i + end + torn_vars_set = BitSet(torn_vars_idxs) + perm = Int[] + for (var, eq) in zip(torn_vars_idxs, torn_eqs_idxs) + ieq = get(linear_eqs, eq, 0) + iszero(ieq) && continue + # `trivial_tearing!` considers equations already written by the user in a + # solvable form, so we know that the coefficient of `var` must be `-1` and do + # not need to be concerned with fractions. + eq_vars = mm.row_cols[ieq] + eq_coeffs = mm.row_vals[ieq] + + I = Int[] + V = Int[] + sizehint!(I, length(eq_vars)) + sizehint!(V, length(eq_vars)) + can_be_aliased = true + for (v, cf) in zip(eq_vars, eq_coeffs) + if v == var + @assert cf == -1 + continue + end + alias = get(aliases, v, nothing) + if alias === nothing + # If this variable is not aliased to a linear combination, + # and is torn, then this equation is no longer a linear equation + # that can be retained in `mm`. + if v in torn_vars_set + can_be_aliased = false + break + end + push!(I, v) + push!(V, cf) + continue + end + _I, _V = SparseArrays.findnz(alias) + append!(I, _I) + append!(V, Iterators.map(Base.Fix1(*, cf), _V)) + end + can_be_aliased || continue + + resize!(perm, length(I)) + sortperm!(perm, I) + I = I[perm] + V = V[perm] + i = 1 + for j in Iterators.drop(eachindex(I), 1) + if I[j] == I[i] + V[i] += V[j] + else + i += 1 + I[i] = I[j] + V[i] = V[j] + end + end + + aliases[var] = SparseArrays.SparseVector(size(mm, 2), I, V) + end + mm = get_new_mm(aliases, old_to_new_eq, old_to_new_var, mm) + return ts, mm + end + return ts end diff --git a/src/utils.jl b/src/utils.jl index 40fda28..c5daba6 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -245,3 +245,120 @@ function sorted_incidence_matrix(ts::TransformationState, val = true; only_algeq end SparseArrays.sparse(I, J, val, nsrcs(graph), ndsts(graph)) end + +function add_row_coeffs!( + row_col_i::Vector{Int}, row_val_i::Vector{T}, old_to_new_var::Vector{Int}, + aliases::Dict{Int, Int}, old_var::Int, coeff::T + ) where {T <: Integer} + alias = get(aliases, old_var, 0) + iszero(alias) && return + push!(row_col_i, old_to_new_var[alias]) + push!(row_val_i, coeff) + return nothing +end + +function add_row_coeffs!( + row_col_i::Vector{Int}, row_val_i::Vector{T}, old_to_new_var::Vector{Int}, + aliases::Dict{Int, SparseArrays.SparseVector{T, Int}}, old_var::Int, coeff::T + ) where {T <: Integer} + alias = get(aliases, old_var, nothing) + if alias isa SparseArrays.SparseVector{T, Int} + I, V = SparseArrays.findnz(alias) + for (i, v) in zip(I, V) + iszero(old_to_new_var[i]) + end + append!(row_col_i, Iterators.map(Base.Fix1(getindex, old_to_new_var), I)) + append!(row_val_i, Iterators.map(Base.Fix1(*, coeff), V)) + end + return nothing +end + +""" + $TYPEDSIGNATURES + +Construct the new coefficient matrix with: + +- Some equations removed, as indicated by `old_to_new_eq` which is obtained from + `get_old_to_new_idxs`. +- Some variables removed and aliased to other variables, as indicated by `old_to_new_var` + and `aliases`. `aliases` can be a `Dict{Int, Int}` indicating variables exactly aliased + to others, or `Dict{Int, SparseVector{eltype(mm)}}` indicating variables aliased to linear + combinations of others. Note that this is not recursive - if one variable + depends on another aliased variable, it will lead to incorrect results. +""" +function get_new_mm( + aliases::Dict{Int}, old_to_new_eq::Vector{Int}, old_to_new_var::Vector{Int}, + mm::CLIL.SparseMatrixCLIL + ) + + new_nparentrows = mm.nparentrows + new_row_cols = eltype(mm.row_cols)[] + new_row_vals = eltype(mm.row_vals)[] + new_nzrows = Int[] + indices = Int[] + + for (i, eq) in enumerate(mm.nzrows) + old_to_new_eq[eq] > 0 || continue + new_row_col_i = eltype(eltype(new_row_cols))[] + new_row_val_i = eltype(eltype(new_row_vals))[] + sizehint!(new_row_col_i, length(mm.row_cols[i])) + sizehint!(new_row_val_i, length(mm.row_vals[i])) + still_valid_eq = true + for (var, coeff) in zip(mm.row_cols[i], mm.row_vals[i]) + if old_to_new_var[var] > 0 + push!(new_row_col_i, old_to_new_var[var]) + push!(new_row_val_i, coeff) + continue + end + # This variable is removed, but not aliased to an integer coefficient linear + # combination. As a result, this equation cannot be retained in `mm`. + if !haskey(aliases, var) + still_valid_eq = false + break + end + + add_row_coeffs!(new_row_col_i, new_row_val_i, old_to_new_var, aliases, var, coeff) + end + + bad_idx = findfirst(iszero, new_row_col_i) + if bad_idx isa Int + throw(BadMMAliasError(bad_idx)) + end + still_valid_eq || continue + empty!(indices) + append!(indices, LinearIndices(new_row_col_i)) + sortperm!(indices, new_row_col_i) + final_row_cols = empty(new_row_col_i) + final_row_vals = empty(new_row_val_i) + push!(final_row_cols, new_row_col_i[indices[1]]) + push!(final_row_vals, new_row_val_i[indices[1]]) + for i in Iterators.drop(eachindex(indices), 1) + if new_row_col_i[indices[i]] == new_row_col_i[indices[i - 1]] + final_row_vals[end] += new_row_val_i[indices[i]] + else + push!(final_row_cols, new_row_col_i[indices[i]]) + push!(final_row_vals, new_row_val_i[indices[i]]) + end + end + + push!(new_row_cols, final_row_cols) + push!(new_row_vals, final_row_vals) + push!(new_nzrows, old_to_new_eq[eq]) + end + + return typeof(mm)(new_nparentrows, count(!iszero, old_to_new_var), new_nzrows, new_row_cols, new_row_vals) +end + +struct BadMMAliasError <: Exception + eq::Int +end + +function Base.showerror(io::IO, err::BadMMAliasError) + return print( + io, """ + When processing equation $(err.eq), the list of aliases resulted in a linear + combination of a removed variable. No variable should be aliased to a removed + variable. + """ + ) +end diff --git a/test/runtests.jl b/test/runtests.jl index f44c6f9..399f56f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,4 +1,31 @@ using StateSelection +import StateSelection as SSel +using SparseArrays using Test include("bareiss.jl") + +@testset "`get_new_mm`" begin + mm = SSel.CLIL.SparseMatrixCLIL( + [ + -1 0 2 0 + 2 1 2 0 + 0 0 1 2 + ] + ) + # We're using the first equation to solve for the first variable, + # and removing the last variable (assuming it's torn via some other + # equation) + old_to_new_eq = [0, 2, 3] + old_to_new_var = [0, 1, 2, 0] + aliases = Dict(1 => sparse([0, 0, 2, 0])) + mm2 = SSel.get_new_mm(aliases, old_to_new_eq, old_to_new_var, mm) + # Eq#3 can't be retained, because it depends on variable 4 which isn't an + # integer coefficient linear combination + @test mm2.nzrows == [2] + @test mm2.row_cols == [[1, 2]] + @test mm2.row_vals == [[1, 6]] + @test mm2.ncols == 2 + aliases = Dict(1 => sparse([0, 0, 2, 1])) + @test_throws SSel.BadMMAliasError SSel.get_new_mm(aliases, old_to_new_eq, old_to_new_var, mm) +end