@@ -1263,79 +1263,89 @@ function observed_equations_used_by(
12631263end
12641264
12651265"""
1266- $(TYPEDSIGNATURES)
1267-
1268- Given an expression `expr`, return a dictionary mapping subexpressions of `expr` that do
1269- not involve variables in `vars` to anonymous symbolic variables. Also return the modified
1270- `expr` with the substitutions indicated by the dictionary. If `expr` is a function
1271- of only `vars`, then all of the returned subexpressions can be precomputed.
1272-
1273- Note that this will only process subexpressions floating point value. Additionally,
1274- array variables must be passed in both scalarized and non-scalarized forms in `vars`.
1275- """
1276- function subexpressions_not_involving_vars (expr, vars)
1277- expr = unwrap (expr)
1278- vars = map (unwrap, vars)
1279- state = Dict ()
1280- newexpr = subexpressions_not_involving_vars! (expr, vars, state)
1281- return state, newexpr
1282- end
1283-
1284- """
1285- $(TYPEDSIGNATURES)
1266+ $TYPEDSIGNATURES
12861267
1287- Mutating version of `subexpressions_not_involving_vars` which writes to `state`. Only
1288- returns the modified `expr`.
1268+ Given a list of expressions `exprs`, find all top-level subexpressions in `exprs`
1269+ that do not involve variables in `banned_vars`. "Top-level" implies that for all
1270+ such subexpressions, any parent of theirs in `exprs` will involve something in
1271+ `banned_vars`. `state` will be populated as a map from the identified subexpressions
1272+ to anonymous symbols they can be replaced by.
12891273"""
1290- function subexpressions_not_involving_vars! (expr, vars, state:: Dict{Any, Any} )
1291- expr = unwrap (expr)
1292- if symbolic_type (expr) == NotSymbolic ()
1293- if is_array_of_symbolics (expr)
1294- return map (expr) do el
1295- subexpressions_not_involving_vars! (el, vars, state)
1296- end
1274+ function subexpressions_not_involving_vars! (
1275+ ir:: SU.IRStructure{VartypeT} , exprs:: AbstractArray{SymbolicT} ,
1276+ banned_vars:: Set{SymbolicT} , state:: Dict{SymbolicT, SymbolicT}
1277+ )
1278+ # Populate the IR first to ensure that the `RecursiveDFS` has a correctly sized
1279+ # `visited` buffer.
1280+ for x in exprs
1281+ populate_ir! (ir, x)
1282+ end
1283+ for x in banned_vars
1284+ populate_ir! (ir, x)
1285+ end
1286+ # Get the nodes that are reachable from `exprs`. We need to find the topologically
1287+ # earliest subset of these nodes that do not contain `banned_vars`.
1288+ reachable_nodes = Set {Int} ()
1289+ # We only want to descent into non-atomic nodes. Otherwise e.g. the index `1` in `x[1]`
1290+ # will end up being a subexpression not involving `banned_vars`.
1291+ filtered_nbors = let ir = ir
1292+ function __filtered_nbors (graph, i)
1293+ # Use `Iterators.filter` instead of just checking `ir[i]` and returning `()` for
1294+ # type-stability. The early-exit infers as a `Union`.
1295+ is_atomic = SU. default_is_atomic (ir[i])
1296+ # We also don't want to descend into constants - those are not worth caching.
1297+ Iterators. filter (j -> ! is_atomic && ! SU. isconst (ir[j]), Graphs. outneighbors (graph, i))
12971298 end
1298- return expr
1299- end
1300- any (isequal (expr), vars) && return expr
1301- iscall (expr) || return expr
1302- symbolic_has_known_size (expr) || return expr
1303- haskey (state, expr) && return state[expr]
1304- op = operation (expr)
1305- args = arguments (expr)
1306- # if this is a `getindex` and the getindex-ed value is a `Sym`
1307- # or it is not a called parameter
1308- # OR
1309- # none of `vars` are involved in `expr`
1310- if op === getindex && (issym (args[1 ]) || ! iscalledparameter (args[1 ])) ||
1311- (vs = SU. search_variables (expr); intersect! (vs, vars); isempty (vs))
1312- sym = gensym (:subexpr )
1313- var = similar_variable (expr, sym)
1314- state[expr] = var
1315- return var
1316- end
1317-
1318- if (op == (+ ) || op == (* )) && symbolic_type (expr) != = ArraySymbolic ()
1319- indep_args = SymbolicT[]
1320- dep_args = SymbolicT[]
1321- for arg in args
1322- _vs = SU. search_variables (arg)
1323- intersect! (_vs, vars)
1324- if ! isempty (_vs)
1325- push! (dep_args, subexpressions_not_involving_vars! (arg, vars, state))
1326- else
1327- push! (indep_args, arg)
1328- end
1299+ end
1300+ rdfs = SU. RecursiveDFS (
1301+ ir. dependency_graph; neighbors_fn = filtered_nbors,
1302+ on_exit = Base. Fix1 (push!, reachable_nodes)
1303+ )
1304+ for x in exprs
1305+ rdfs (ir[x])
1306+ end
1307+ # We want to retain the reachability information for later
1308+ unbanned_subexprs = copy (reachable_nodes)
1309+ # Walk through the usages of `banned_vars` that are in `unbanned_subexprs`, and remove
1310+ # from the candidates any expression we encounter. The remaining vertices are
1311+ # ones that do not use `banned_vars`.
1312+ unbanned_nbors_fn = let unbanned_subexprs = unbanned_subexprs
1313+ function __unbanned_nbors_fn (graph, i)
1314+ return Iterators. filter (in (unbanned_subexprs), Graphs. inneighbors (graph, i))
13291315 end
1330- indep_term = reduce (op, indep_args; init = Int (op == (* )))
1331- indep_term = subexpressions_not_involving_vars! (indep_term, vars, state)
1332- dep_term = reduce (op, dep_args; init = Int (op == (* )))
1333- return op (indep_term, dep_term)
13341316 end
1335- newargs = map (args) do arg
1336- subexpressions_not_involving_vars! (arg, vars, state)
1317+ rdfs = SU. RecursiveDFS (
1318+ ir. dependency_graph;
1319+ neighbors_fn = unbanned_nbors_fn,
1320+ on_exit = Base. Fix1 (delete!, unbanned_subexprs)
1321+ )
1322+ for var in banned_vars
1323+ rdfs (ir[var])
1324+ end
1325+
1326+ # Now, the nodes we actually care about and will populate `state` with are ones
1327+ # present in `unbanned_subexprs` and are used by a node in
1328+ # `setdiff(reachable_nodes, unbanned_subexprs)`. All of `setdiff!`, the subsequent
1329+ # `filter!`, and then populating `state` will iterate over `unbanned_subexprs`. We
1330+ # might as well combine this into one loop.
1331+ for node in unbanned_subexprs
1332+ nbors = Graphs. inneighbors (ir. dependency_graph, node)
1333+ is_valid_node = false
1334+ for nbor in nbors
1335+ nbor in reachable_nodes || continue
1336+ nbor in unbanned_subexprs && continue
1337+ is_valid_node = true
1338+ break
1339+ end
1340+ is_valid_node || continue
1341+
1342+ expr = ir[node]
1343+ anon_sym = Symbolics. SSym (
1344+ Symbol (:__cached_ , length (state));
1345+ type = SU. symtype (expr), shape = SU. shape (expr)
1346+ )
1347+ state[expr] = anon_sym
13371348 end
1338- return maketerm (typeof (expr), op, newargs, metadata (expr))
13391349end
13401350
13411351"""
0 commit comments