Skip to content

Commit 5d71d96

Browse files
refactor: use cached LinearExpanders
1 parent b3fbd16 commit 5d71d96

2 files changed

Lines changed: 7 additions & 2 deletions

File tree

lib/ModelingToolkitBase/src/systems/codegen.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1446,7 +1446,7 @@ function calculate_A_b(sys::System; sparse = false, throw = true)
14461446
# `linear_expansion` caches values based on `var`. This loop ordering helps
14471447
# avoid invalidating the cache frequently.
14481448
for (j, var) in enumerate(dvs)
1449-
lex = Symbolics.LinearExpander(var; strict = true)
1449+
lex = get_linear_expander_for!(sys, var, true)
14501450
for (i, resid) in enumerate(rhss)
14511451
p, q, islinear = lex(resid)
14521452
# An equation such as `0 ~ 1 + x * y` will return `(x, 1, true)` for `y`.

lib/ModelingToolkitBase/src/systems/nonlinear/initializesystem.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,12 @@ function generate_initializesystem_timevarying(
199199
kwargs...
200200
)
201201
diffcache_params = SU.getmetadata(sys, DiffCacheParams, Dict{SymbolicT, Int}())::Dict{SymbolicT, Int}
202-
isys = SU.setmetadata(isys, DiffCacheParams, diffcache_params)
202+
isys = SU.setmetadata(isys, DiffCacheParams, copy(diffcache_params))
203+
# Reuse `LinearExpander` cache for the initialization system
204+
linear_expansion_cache = check_mutable_cache(sys, LinearExpansionCache, LinearExpansionCacheT, nothing)
205+
if linear_expansion_cache !== nothing
206+
store_to_mutable_cache!(isys, LinearExpansionCache, linear_expansion_cache)
207+
end
203208
return isys
204209
end
205210

0 commit comments

Comments
 (0)