Skip to content

Commit 709e668

Browse files
Merge pull request #4497 from SciML/as/alias-elim-irreducibles
fix: do not eliminate positive state priority variables in alias elimination
2 parents 5f594d5 + 78a0948 commit 709e668

1 file changed

Lines changed: 12 additions & 7 deletions

File tree

src/systems/alias_elimination.jl

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,12 @@ unknowns, so one of them is chosen as the target when the group contains any. Ot
4848
the variable with the highest `state_priority` wins.
4949
"""
5050
function pick_alias_target(
51-
fullvars::Vector{SymbolicT}, group_vars::Vector{Int}, state_priorities
51+
fullvars::Vector{SymbolicT}, group_vars::Vector{Int}, state_priorities, irreducibles::AtomicSetT
52+
)
53+
irr_idx = findfirst(
54+
Base.Fix1(contains_possibly_indexed_element, irreducibles) Base.Fix1(getindex, fullvars),
55+
group_vars
5256
)
53-
irr_idx = findfirst(v -> isirreducible(fullvars[v]), group_vars)
5457
irr_idx === nothing || return group_vars[irr_idx]
5558
_, target_idx = findmax(Base.Fix1(getindex, state_priorities), group_vars)
5659
return group_vars[target_idx]
@@ -69,14 +72,15 @@ function find_perfect_aliases!(
6972
state::TearingState, eqs_to_rm::Vector{Int}, vars_to_rm::Vector{Int}
7073
)
7174
(; sys, fullvars, structure) = state
72-
(; graph, solvable_graph, var_to_diff) = structure
75+
(; graph, solvable_graph, var_to_diff, state_priorities) = structure
7376

7477
@assert solvable_graph === nothing
7578
diff_to_var = invview(var_to_diff)
7679
aliases = Dict{Int, Int}()
7780
subs = Dict{SymbolicT, SymbolicT}()
7881
eqs = collect(equations(state))
7982
original_eqs = state.original_eqs
83+
irreducibles = get_irreducibles(sys)
8084

8185
# Not `IntDisjointSet` because we don't want singleton sets for every single variable
8286
alias_groups = DisjointSet{Int}()
@@ -118,7 +122,7 @@ function find_perfect_aliases!(
118122

119123
group_target = Dict{Int, Int}()
120124
for (root, group_vars) in alias_sets
121-
group_target[root] = pick_alias_target(fullvars, group_vars, state.structure.state_priorities)
125+
group_target[root] = pick_alias_target(fullvars, group_vars, state_priorities, irreducibles)
122126
end
123127

124128
# Queue an alias equation for removal only if both endpoints collapse onto the
@@ -128,8 +132,8 @@ function find_perfect_aliases!(
128132
# rewrites the kept equation into `I ~ T` form automatically.
129133
for (ieq, v1, v2) in candidate_eqs
130134
target = group_target[DataStructures.find_root!(alias_groups, v1)]
131-
c1 = isirreducible(fullvars[v1]) ? v1 : target
132-
c2 = isirreducible(fullvars[v2]) ? v2 : target
135+
c1 = contains_possibly_indexed_element(irreducibles, fullvars[v1]) || state_priorities[v1] > 0 ? v1 : target
136+
c2 = contains_possibly_indexed_element(irreducibles, fullvars[v2]) || state_priorities[v2] > 0 ? v2 : target
133137
c1 == c2 && push!(eqs_to_rm, ieq)
134138
end
135139

@@ -140,10 +144,11 @@ function find_perfect_aliases!(
140144
v == target && continue
141145
# Irreducibles other than the target stay as unknowns; only non-irreducibles
142146
# are eliminated in favor of the target.
143-
if isirreducible(fullvars[v])
147+
if contains_possibly_indexed_element(irreducibles, fullvars[v]) || state_priorities[v] > 0
144148
state.always_present[v] = true
145149
continue
146150
end
151+
147152
push!(vars_to_rm, v)
148153
subs[fullvars[v]] = fullvars[target]
149154
push!(state.additional_observed, fullvars[v] ~ fullvars[target])

0 commit comments

Comments
 (0)