Skip to content

Commit 8401e68

Browse files
fix: fix subexpressions_not_involving_vars!
Now handles callable parameters and array variables correctly
1 parent 7dba00e commit 8401e68

2 files changed

Lines changed: 35 additions & 3 deletions

File tree

lib/ModelingToolkitBase/src/utils.jl

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1293,8 +1293,19 @@ function subexpressions_not_involving_vars!(
12931293
# Use `Iterators.filter` instead of just checking `ir[i]` and returning `()` for
12941294
# type-stability. The early-exit infers as a `Union`.
12951295
is_atomic = SU.default_is_atomic(ir[i])
1296+
# If the operation is symbolic, it is the first `outneighbor`. Symbolic operations
1297+
# are already parameters, we don't want to cache it.
1298+
drop = Moshi.Match.@match ir[i] begin
1299+
BSImpl.Term(; f) && if f isa SymbolicT end => begin
1300+
1
1301+
end
1302+
_ => 0
1303+
end
12961304
# We also don't want to descend into constants - those are not worth caching.
1297-
Iterators.filter(j -> !is_atomic && !SU.isconst(ir[j]), Graphs.outneighbors(graph, i))
1305+
return Iterators.filter(
1306+
j -> !is_atomic && !SU.isconst(ir[j]),
1307+
Iterators.drop(Graphs.outneighbors(graph, i), drop)
1308+
)
12981309
end
12991310
end
13001311
rdfs = SU.RecursiveDFS(
@@ -1311,9 +1322,16 @@ function subexpressions_not_involving_vars!(
13111322
# Walk through the usages of `banned_vars` that are in `unbanned_subexprs`, and remove
13121323
# from the candidates any expression we encounter. The remaining vertices are
13131324
# ones that do not use `banned_vars`.
1314-
unbanned_nbors_fn = let unbanned_subexprs = unbanned_subexprs
1325+
unbanned_nbors_fn = let unbanned_subexprs = unbanned_subexprs, ir = ir
13151326
function __unbanned_nbors_fn(graph, i)
1316-
return Iterators.filter(in(unbanned_subexprs), Graphs.inneighbors(graph, i))
1327+
# If an `inneighbor` is atomic, it means `ir[i]` is an array variable and
1328+
# the neighbor is a scalarized element. We want to cache specific parts of
1329+
# symbolic arrays that are unbanned, and only ban usages of the full symbolic
1330+
# array.
1331+
return Iterators.filter(
1332+
j -> j in unbanned_subexprs && !SU.default_is_atomic(ir[j]),
1333+
Graphs.inneighbors(graph, i)
1334+
)
13171335
end
13181336
end
13191337
rdfs = SU.RecursiveDFS(

test/scc_nonlinear_problem.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@ using OrdinaryDiffEqBDF
55
using SciMLBase, Symbolics
66
using StaticArrays
77
using LinearAlgebra, Test
8+
using SymbolicUtils
9+
import ModelingToolkitBase
10+
using Symbolics: SymbolicT, VartypeT, unwrap
811
using ModelingToolkit: t_nounits as t, D_nounits as D
912

1013
@testset "Trivial case" begin
@@ -388,3 +391,14 @@ end
388391
sol = solve(prob)
389392
@test SciMLBase.successful_retcode(sol)
390393
end
394+
395+
@testset "`subexpressions_not_involving_vars!`" begin
396+
@variables x[1:3]
397+
@parameters (f::Function)(..)
398+
ir = IRStructure{VartypeT}()
399+
expr = f([x[1], 2x[2], x[2]^2 + x[3]]) + f(x) + f(x[2] + 1)
400+
state = Dict{SymbolicT, SymbolicT}()
401+
banned_vars = Set{SymbolicT}([x[3], x])
402+
ModelingToolkitBase.subexpressions_not_involving_vars!(ir, [unwrap(expr)], banned_vars, state)
403+
@test issetequal(keys(state), [x[1], 2x[2], x[2]^2, f(x[2] + 1)])
404+
end

0 commit comments

Comments
 (0)