diff --git a/lib/ModelingToolkitTearing/src/ModelingToolkitTearing.jl b/lib/ModelingToolkitTearing/src/ModelingToolkitTearing.jl index 23046ae..d2d8b73 100644 --- a/lib/ModelingToolkitTearing/src/ModelingToolkitTearing.jl +++ b/lib/ModelingToolkitTearing/src/ModelingToolkitTearing.jl @@ -107,7 +107,7 @@ function MTKBase.unhack_system(sys::System) resize!(obs_mask, length(obseqs)) fill!(obs_mask, true) additional_eqs = Equation[] - additional_vars = SymbolicT[] + additional_vars = Set{SymbolicT}() additional_subs = Dict{SymbolicT, SymbolicT}() # Also need to update schedule @@ -142,21 +142,15 @@ function MTKBase.unhack_system(sys::System) resid = A * x - b for res in resid + SU._iszero(res) && continue + # If a linear SCC contains both `D(w)` and `w_t`, it'll contain the equation `D(w) ~ w_t`. + # When unhacking it, `D(w)` will be totermed into `w_t`. Avoid adding the `0 ~ 0` equations. + # The duplicate variables are automatically removed by the `Set`. + # See https://github.com/SciML/ModelingToolkit.jl/issues/4196 for further details. push!(additional_eqs, Symbolics.COMMON_ZERO ~ res) end end @assert length(additional_eqs) == length(additional_vars) - # If a linear SCC contains both `D(w)` and `w_t`, it'll contain the equation `D(w) ~ w_t`. - # When unhacking it, `D(w)` will be totermed into `w_t`. This, `additional_vars` contains - # two `w_t` and an equation that is `0 ~ 0`. Find the `0 ~ 0` equations, and remove them - # along with the duplicate variables. - # See https://github.com/SciML/ModelingToolkit.jl/issues/4196 for further details. - additional_eqs_mask = trues(length(additional_eqs)) - for (i, eq) in enumerate(additional_eqs) - additional_eqs_mask[i] = !SU._iszero(eq.rhs) - end - additional_eqs = additional_eqs[additional_eqs_mask] - additional_vars = additional_vars[additional_eqs_mask] subst = SU.Substituter{false}(additional_subs, SU.default_substitute_filter) obseqs = obseqs[obs_mask] map!(subst, obseqs, obseqs) @@ -167,7 +161,7 @@ function MTKBase.unhack_system(sys::System) map!(subst, values(sched.dummy_sub)) end - dvs = [unknowns(sys); additional_vars] + dvs = [unknowns(sys); collect(additional_vars)] newsys = @set sys.observed = obseqs @set! newsys.eqs = eqs