Skip to content

Commit 95e9bff

Browse files
Merge pull request #4501 from SciML/as/fix-inputs-to-params
fix: use `rm_eqs_vars!` in `inputs_to_parameters!`
2 parents 9eb8f8f + 3693632 commit 95e9bff

1 file changed

Lines changed: 14 additions & 45 deletions

File tree

src/systems/systemstructure.jl

Lines changed: 14 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -64,54 +64,23 @@ end
6464
Turn input variables into parameters of the system.
6565
"""
6666
function inputs_to_parameters!(state::TearingState, inputsyms::OrderedSet{SymbolicT}, outputsyms::OrderedSet{SymbolicT})
67-
@unpack structure, fullvars, sys = state
68-
@unpack var_to_diff, graph, solvable_graph = structure
69-
@assert solvable_graph === nothing
67+
(; sys, fullvars) = state
7068

71-
var_reidx = zeros(Int, length(fullvars))
72-
nvar = 0
73-
new_fullvars = SymbolicT[]
74-
for (i, v) in enumerate(fullvars)
75-
if v in inputsyms
76-
if var_to_diff[i] !== nothing
77-
error("Input $(fullvars[i]) is differentiated!")
78-
end
79-
var_reidx[i] = -1
80-
else
81-
nvar += 1
82-
var_reidx[i] = nvar
83-
push!(new_fullvars, v)
84-
end
85-
end
86-
ninputs = length(inputsyms)
87-
@set! sys.inputs = inputsyms
88-
@set! sys.outputs = outputsyms
89-
if ninputs == 0
69+
if isempty(inputsyms)
70+
@set! sys.inputs = inputsyms
71+
@set! sys.outputs = outputsyms
9072
state.sys = sys
9173
return state
9274
end
9375

94-
nvars = ndsts(graph) - ninputs
95-
new_graph = BipartiteGraph(nsrcs(graph), nvars, Val(false))
96-
97-
for ie in 1:nsrcs(graph)
98-
for iv in 𝑠neighbors(graph, ie)
99-
iv = var_reidx[iv]
100-
iv > 0 || continue
101-
add_edge!(new_graph, ie, iv)
102-
end
103-
end
104-
105-
new_var_to_diff = StateSelection.DiffGraph(nvars, true)
106-
for (i, v) in enumerate(var_to_diff)
107-
new_i = var_reidx[i]
108-
(new_i < 1 || v === nothing) && continue
109-
new_v = var_reidx[v]
110-
@assert new_v > 0
111-
new_var_to_diff[new_i] = new_v
76+
vars_to_rm = Int[]
77+
for (i, v) in enumerate(fullvars)
78+
v in inputsyms && push!(vars_to_rm, i)
11279
end
113-
@set! structure.var_to_diff = complete(new_var_to_diff)
114-
@set! structure.graph = complete(new_graph)
80+
StateSelection.rm_eqs_vars!(
81+
state, Int[], vars_to_rm; eqs_sorted_and_uniqued = true,
82+
vars_sorted_and_uniqued = true
83+
)
11584

11685
binds = copy(parent(bindings(sys)))
11786
for var in inputsyms
@@ -120,11 +89,11 @@ function inputs_to_parameters!(state::TearingState, inputsyms::OrderedSet{Symbol
12089
@set! sys.unknowns = setdiff(unknowns(sys), inputsyms)
12190
ps = copy(parameters(sys))
12291
append!(ps, inputsyms)
92+
@set! sys.inputs = inputsyms
93+
@set! sys.outputs = outputsyms
12394
@set! sys.ps = ps
12495
@set! sys.bindings = ROSymmapT(binds)
125-
@set! state.sys = sys
126-
@set! state.fullvars = Vector{SymbolicT}(new_fullvars)
127-
@set! state.structure = structure
96+
state.sys = sys
12897
return state
12998
end
13099

0 commit comments

Comments
 (0)