Skip to content

Commit 517d0f1

Browse files
Merge pull request #4547 from SciML/as/cache-linear-expander
feat: allow caching `LinearExpander`s in `System`
2 parents 2451b8e + 5d71d96 commit 517d0f1

4 files changed

Lines changed: 39 additions & 3 deletions

File tree

lib/ModelingToolkitBase/src/systems/abstractsystem.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -998,7 +998,12 @@ Invalidate cached jacobians, etc.
998998
"""
999999
function invalidate_cache!(sys::AbstractSystem)
10001000
has_metadata(sys) || return sys
1001-
empty!(getmetadata(sys, MutableCacheKey, nothing))
1001+
cache = getmetadata(sys, MutableCacheKey, nothing)
1002+
if cache isa MutableCacheT
1003+
# Avoid clearing the linear expansion cache. It doesn't depend on anything in the
1004+
# system, just useful to have around.
1005+
filter!(Base.Fix2(===, LinearExpansionCache) first, cache)
1006+
end
10021007
return sys
10031008
end
10041009

lib/ModelingToolkitBase/src/systems/codegen.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1450,7 +1450,7 @@ function calculate_A_b(sys::System; sparse = false, throw = true)
14501450
# `linear_expansion` caches values based on `var`. This loop ordering helps
14511451
# avoid invalidating the cache frequently.
14521452
for (j, var) in enumerate(dvs)
1453-
lex = Symbolics.LinearExpander(var; strict = true)
1453+
lex = get_linear_expander_for!(sys, var, true)
14541454
for (i, resid) in enumerate(rhss)
14551455
p, q, islinear = lex(resid)
14561456
# An equation such as `0 ~ 1 + x * y` will return `(x, 1, true)` for `y`.

lib/ModelingToolkitBase/src/systems/nonlinear/initializesystem.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,12 @@ function generate_initializesystem_timevarying(
199199
kwargs...
200200
)
201201
diffcache_params = SU.getmetadata(sys, DiffCacheParams, Dict{SymbolicT, Int}())::Dict{SymbolicT, Int}
202-
isys = SU.setmetadata(isys, DiffCacheParams, diffcache_params)
202+
isys = SU.setmetadata(isys, DiffCacheParams, copy(diffcache_params))
203+
# Reuse `LinearExpander` cache for the initialization system
204+
linear_expansion_cache = check_mutable_cache(sys, LinearExpansionCache, LinearExpansionCacheT, nothing)
205+
if linear_expansion_cache !== nothing
206+
store_to_mutable_cache!(isys, LinearExpansionCache, linear_expansion_cache)
207+
end
203208
return isys
204209
end
205210

lib/ModelingToolkitBase/src/systems/system.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1627,3 +1627,29 @@ function should_invalidate_mutable_cache_entry(
16271627
)
16281628
return true
16291629
end
1630+
1631+
abstract type LinearExpansionCache end
1632+
1633+
const LinearExpansionCacheT = Dict{Tuple{SymbolicT, Bool}, Symbolics.LinearExpander}
1634+
1635+
function should_invalidate_mutable_cache_entry(
1636+
::Type{LinearExpansionCache}, @nospecialize(patch::NamedTuple)
1637+
)
1638+
return false
1639+
end
1640+
1641+
"""
1642+
$TYPEDSIGNATURES
1643+
1644+
Get a cached `Symbolics.LinearExpander` for variable `var` and strictness `strict`. This
1645+
allows reusing significant cached information between `LinearExpander` calls. This cache
1646+
is never invalidated.
1647+
"""
1648+
function get_linear_expander_for!(sys::System, var::SymbolicT, strict::Bool)
1649+
cache = check_mutable_cache(sys, LinearExpansionCache, LinearExpansionCacheT, nothing)
1650+
if cache === nothing
1651+
cache = LinearExpansionCacheT()
1652+
store_to_mutable_cache!(sys, LinearExpansionCache, cache)
1653+
end
1654+
return get!(() -> Symbolics.LinearExpander(var; strict), cache, (var, strict))
1655+
end

0 commit comments

Comments
 (0)