Skip to content

Commit 70eb739

Browse files
Merge pull request #61 from JuliaComputing/as/unused-var-interface
feat: add `is_unused_var` interface function
2 parents 4e1e3f1 + 4bdcce4 commit 70eb739

2 files changed

Lines changed: 17 additions & 3 deletions

File tree

src/tearing.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ function trivial_tearing!(ts::TransformationState, mm::Union{SparseMatrixCLIL, N
131131
end
132132
torn_vars_set = BitSet(torn_vars_idxs)
133133
perm = Int[]
134-
for (var, eq) in zip(torn_vars_idxs, torn_eqs_idxs)
134+
for (var, eq) in zip(matched_vars, trivial_idxs)
135135
ieq = get(linear_eqs, eq, 0)
136136
iszero(ieq) && continue
137137
# `trivial_tearing!` considers equations already written by the user in a

src/utils.jl

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,14 +94,24 @@ function singular_check(state::TransformationState)
9494
unassigned_var = eltype(get_fullvars(state))[]
9595
for (vj, eq) in enumerate(extended_var_eq_matching)
9696
vj > nvars && break
97-
if eq === unassigned && !isempty(𝑑neighbors(graph, vj))
97+
if eq === unassigned && !is_unused_var(state, vj)
9898
push!(unassigned_var, fullvars[vj])
9999
end
100100
end
101101
return unassigned_var
102102

103103
end
104104

105+
"""
106+
$TYPEDSIGNATURES
107+
108+
Check if the variable `var` is unsed by any equation in `state`. By default, checks
109+
the incidence graph.
110+
"""
111+
function is_unused_var(state::TransformationState, var::Integer)
112+
return isempty(𝑑neighbors(state.structure.graph, var))
113+
end
114+
105115
"""
106116
$(TYPEDSIGNATURES)
107117
@@ -117,7 +127,7 @@ function check_consistency(state::TransformationState, orig_inputs; nothrow = fa
117127
n_highest_vars = 0
118128
for (v, h) in enumerate(highest_vars)
119129
h || continue
120-
isempty(𝑑neighbors(graph, v)) && continue
130+
is_unused_var(state, v) && continue
121131
n_highest_vars += 1
122132
end
123133
is_balanced = n_highest_vars == neqs
@@ -335,6 +345,10 @@ function get_new_mm(
335345
for i in Iterators.drop(eachindex(indices), 1)
336346
if new_row_col_i[indices[i]] == new_row_col_i[indices[i - 1]]
337347
final_row_vals[end] += new_row_val_i[indices[i]]
348+
if iszero(final_row_vals[end])
349+
pop!(final_row_cols)
350+
pop!(final_row_vals)
351+
end
338352
else
339353
push!(final_row_cols, new_row_col_i[indices[i]])
340354
push!(final_row_vals, new_row_val_i[indices[i]])

0 commit comments

Comments
 (0)