Skip to content

Commit fce8f60

Browse files
Merge pull request #73 from JuliaComputing/as/inline-linsolve-arraymaker
refactor: use `ArrayMaker` for inline linear SCCs
2 parents d145e32 + c6b3be0 commit fce8f60

2 files changed

Lines changed: 32 additions & 5 deletions

File tree

lib/ModelingToolkitTearing/Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,8 @@ Setfield = "0.7, 0.8, 1"
3939
SparseArrays = "1"
4040
StateSelection = "1.9"
4141
SymbolicIndexingInterface = "0.3"
42-
SymbolicUtils = "4.3"
43-
Symbolics = "7.15.1"
42+
SymbolicUtils = "4.25"
43+
Symbolics = "7.20"
4444
UUIDs = "1"
4545
julia = "1.10"
4646

lib/ModelingToolkitTearing/src/reassemble.jl

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -548,7 +548,6 @@ function get_linear_scc_linsol(state::TearingState, alg_eqs::Vector{Int},
548548
N = length(alg_eqs)
549549
vars = Symbolics.fixpoint_sub(fullvars[alg_vars], total_sub; maxiters = max(length(total_sub), 10))
550550

551-
# Linear coefficients
552551
A = fill(Num(Symbolics.COMMON_ZERO), N, N)
553552
b = fill(Symbolics.COMMON_ZERO, N)
554553

@@ -626,6 +625,33 @@ function get_linear_scc_linsol(state::TearingState, alg_eqs::Vector{Int},
626625
else
627626
reference = fullvars[state_idx]
628627
end
628+
# Use the `ArrayMaker` form for `A` and `b`
629+
A_regions = SU.RegionsT()
630+
A_values = Symbolics.SArgsT()
631+
b_regions = SU.RegionsT()
632+
b_values = Symbolics.SArgsT()
633+
# fill the entire thing with zeros
634+
push!(A_regions, SU.ShapeVecT((1:N, 1:N)))
635+
push!(A_values, SU.Fill(A_regions[1])(Symbolics.COMMON_ZERO))
636+
push!(b_regions, SU.ShapeVecT((1:N,)))
637+
push!(b_values, SU.Fill(b_regions[1])(Symbolics.COMMON_ZERO))
638+
639+
for i in axes(A, 1), j in axes(A, 2)
640+
coeff = unwrap(A[i, j])
641+
SU._iszero(coeff) && continue
642+
push!(A_regions, SU.ShapeVecT((i:i, j:j)))
643+
push!(A_values, Symbolics.SConst([coeff;;]))
644+
end
645+
646+
for (i, resid) in enumerate(b)
647+
SU._iszero(resid) && continue
648+
push!(b_regions, SU.ShapeVecT((i:i,)))
649+
push!(b_values, Symbolics.SConst([resid]))
650+
end
651+
652+
A = SU.ArrayMaker{VartypeT}(A_regions, A_values; shape = SU.ShapeVecT((1:N, 1:N)))
653+
b = SU.ArrayMaker{VartypeT}(b_regions, b_values; shape = SU.ShapeVecT((1:N,)))
654+
629655
reference_args = Symbolics.SArgsT((reference, MTKBase.get_iv(sys)::SymbolicT))
630656
inps = MTKBase.inputs(sys)
631657
if !isempty(inps)
@@ -635,13 +661,14 @@ function get_linear_scc_linsol(state::TearingState, alg_eqs::Vector{Int},
635661
promote, reference_args;
636662
type = Vector{Real}, shape = [1:length(reference_args)]
637663
)[1]
638-
sys, A_cache = MTKBase.add_diffcache(sys, length(A))
664+
sys, A_cache = MTKBase.add_diffcache(sys, N * N)
639665
A_allocator = A_cache(reference)
640666
A = SU.Code.with_allocator(A_allocator, SU.Const{VartypeT}(A))
641-
sys, b_cache = MTKBase.add_diffcache(sys, length(b))
667+
sys, b_cache = MTKBase.add_diffcache(sys, N)
642668
b_allocator = b_cache(reference)
643669
b = SU.Code.with_allocator(b_allocator, SU.Const{VartypeT}(b))
644670
state.sys = sys
671+
645672
return INLINE_LINEAR_SCC_OP(A, b)
646673
end
647674

0 commit comments

Comments
 (0)