Skip to content

Commit b7ac903

Browse files
Merge pull request #4518 from SciML/as/scc-nlp-no-obs
refactor: use `full_equations` in `SCCNonlinearProblem`
2 parents 2442d62 + 75b7faf commit b7ac903

2 files changed

Lines changed: 109 additions & 148 deletions

File tree

lib/ModelingToolkitBase/src/utils.jl

Lines changed: 82 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -1263,79 +1263,94 @@ function observed_equations_used_by(
12631263
end
12641264

12651265
"""
1266-
$(TYPEDSIGNATURES)
1267-
1268-
Given an expression `expr`, return a dictionary mapping subexpressions of `expr` that do
1269-
not involve variables in `vars` to anonymous symbolic variables. Also return the modified
1270-
`expr` with the substitutions indicated by the dictionary. If `expr` is a function
1271-
of only `vars`, then all of the returned subexpressions can be precomputed.
1272-
1273-
Note that this will only process subexpressions floating point value. Additionally,
1274-
array variables must be passed in both scalarized and non-scalarized forms in `vars`.
1275-
"""
1276-
function subexpressions_not_involving_vars(expr, vars)
1277-
expr = unwrap(expr)
1278-
vars = map(unwrap, vars)
1279-
state = Dict()
1280-
newexpr = subexpressions_not_involving_vars!(expr, vars, state)
1281-
return state, newexpr
1282-
end
1283-
1284-
"""
1285-
$(TYPEDSIGNATURES)
1266+
$TYPEDSIGNATURES
12861267
1287-
Mutating version of `subexpressions_not_involving_vars` which writes to `state`. Only
1288-
returns the modified `expr`.
1268+
Given a list of expressions `exprs`, find all top-level subexpressions in `exprs`
1269+
that do not involve variables in `banned_vars`. "Top-level" implies that for all
1270+
such subexpressions, any parent of theirs in `exprs` will involve something in
1271+
`banned_vars`. `state` will be populated as a map from the identified subexpressions
1272+
to anonymous symbols they can be replaced by.
12891273
"""
1290-
function subexpressions_not_involving_vars!(expr, vars, state::Dict{Any, Any})
1291-
expr = unwrap(expr)
1292-
if symbolic_type(expr) == NotSymbolic()
1293-
if is_array_of_symbolics(expr)
1294-
return map(expr) do el
1295-
subexpressions_not_involving_vars!(el, vars, state)
1296-
end
1274+
function subexpressions_not_involving_vars!(
1275+
ir::SU.IRStructure{VartypeT}, exprs::AbstractArray{SymbolicT},
1276+
banned_vars::Set{SymbolicT}, state::Dict{SymbolicT, SymbolicT}
1277+
)
1278+
# Populate the IR first to ensure that the `RecursiveDFS` has a correctly sized
1279+
# `visited` buffer.
1280+
for x in exprs
1281+
populate_ir!(ir, x)
1282+
end
1283+
for x in banned_vars
1284+
populate_ir!(ir, x)
1285+
end
1286+
# Get the nodes that are reachable from `exprs`. We need to find the topologically
1287+
# earliest subset of these nodes that do not contain `banned_vars`.
1288+
reachable_nodes = Set{Int}()
1289+
# We only want to descent into non-atomic nodes. Otherwise e.g. the index `1` in `x[1]`
1290+
# will end up being a subexpression not involving `banned_vars`.
1291+
filtered_nbors = let ir = ir
1292+
function __filtered_nbors(graph, i)
1293+
# Use `Iterators.filter` instead of just checking `ir[i]` and returning `()` for
1294+
# type-stability. The early-exit infers as a `Union`.
1295+
is_atomic = SU.default_is_atomic(ir[i])
1296+
# We also don't want to descend into constants - those are not worth caching.
1297+
Iterators.filter(j -> !is_atomic && !SU.isconst(ir[j]), Graphs.outneighbors(graph, i))
12971298
end
1298-
return expr
1299-
end
1300-
any(isequal(expr), vars) && return expr
1301-
iscall(expr) || return expr
1302-
symbolic_has_known_size(expr) || return expr
1303-
haskey(state, expr) && return state[expr]
1304-
op = operation(expr)
1305-
args = arguments(expr)
1306-
# if this is a `getindex` and the getindex-ed value is a `Sym`
1307-
# or it is not a called parameter
1308-
# OR
1309-
# none of `vars` are involved in `expr`
1310-
if op === getindex && (issym(args[1]) || !iscalledparameter(args[1])) ||
1311-
(vs = SU.search_variables(expr); intersect!(vs, vars); isempty(vs))
1312-
sym = gensym(:subexpr)
1313-
var = similar_variable(expr, sym)
1314-
state[expr] = var
1315-
return var
1316-
end
1317-
1318-
if (op == (+) || op == (*)) && symbolic_type(expr) !== ArraySymbolic()
1319-
indep_args = SymbolicT[]
1320-
dep_args = SymbolicT[]
1321-
for arg in args
1322-
_vs = SU.search_variables(arg)
1323-
intersect!(_vs, vars)
1324-
if !isempty(_vs)
1325-
push!(dep_args, subexpressions_not_involving_vars!(arg, vars, state))
1326-
else
1327-
push!(indep_args, arg)
1328-
end
1299+
end
1300+
rdfs = SU.RecursiveDFS(
1301+
ir.dependency_graph; neighbors_fn = filtered_nbors,
1302+
on_exit = Base.Fix1(push!, reachable_nodes)
1303+
)
1304+
for x in exprs
1305+
for idx in ir.weak_definitions[x]
1306+
rdfs(idx)
1307+
end
1308+
end
1309+
# We want to retain the reachability information for later
1310+
unbanned_subexprs = copy(reachable_nodes)
1311+
# Walk through the usages of `banned_vars` that are in `unbanned_subexprs`, and remove
1312+
# from the candidates any expression we encounter. The remaining vertices are
1313+
# ones that do not use `banned_vars`.
1314+
unbanned_nbors_fn = let unbanned_subexprs = unbanned_subexprs
1315+
function __unbanned_nbors_fn(graph, i)
1316+
return Iterators.filter(in(unbanned_subexprs), Graphs.inneighbors(graph, i))
1317+
end
1318+
end
1319+
rdfs = SU.RecursiveDFS(
1320+
ir.dependency_graph;
1321+
neighbors_fn = unbanned_nbors_fn,
1322+
on_exit = Base.Fix1(delete!, unbanned_subexprs)
1323+
)
1324+
for var in banned_vars
1325+
for idx in ir.weak_definitions[var]
1326+
rdfs(idx)
13291327
end
1330-
indep_term = reduce(op, indep_args; init = Int(op == (*)))
1331-
indep_term = subexpressions_not_involving_vars!(indep_term, vars, state)
1332-
dep_term = reduce(op, dep_args; init = Int(op == (*)))
1333-
return op(indep_term, dep_term)
13341328
end
1335-
newargs = map(args) do arg
1336-
subexpressions_not_involving_vars!(arg, vars, state)
1329+
1330+
# Now, the nodes we actually care about and will populate `state` with are ones
1331+
# present in `unbanned_subexprs` and are used by a node in
1332+
# `setdiff(reachable_nodes, unbanned_subexprs)`. All of `setdiff!`, the subsequent
1333+
# `filter!`, and then populating `state` will iterate over `unbanned_subexprs`. We
1334+
# might as well combine this into one loop.
1335+
for node in unbanned_subexprs
1336+
nbors = Graphs.inneighbors(ir.dependency_graph, node)
1337+
is_valid_node = false
1338+
for nbor in nbors
1339+
nbor in reachable_nodes || continue
1340+
nbor in unbanned_subexprs && continue
1341+
is_valid_node = true
1342+
break
1343+
end
1344+
is_valid_node || continue
1345+
1346+
expr = ir[node]
1347+
haskey(state, expr) && continue
1348+
anon_sym = Symbolics.SSym(
1349+
Symbol(:__cached_, length(state));
1350+
type = SU.symtype(expr), shape = SU.shape(expr)
1351+
)
1352+
state[expr] = anon_sym
13371353
end
1338-
return maketerm(typeof(expr), op, newargs, metadata(expr))
13391354
end
13401355

13411356
"""

0 commit comments

Comments
 (0)