@@ -548,7 +548,6 @@ function get_linear_scc_linsol(state::TearingState, alg_eqs::Vector{Int},
548548 N = length (alg_eqs)
549549 vars = Symbolics. fixpoint_sub (fullvars[alg_vars], total_sub; maxiters = max (length (total_sub), 10 ))
550550
551- # Linear coefficients
552551 A = fill (Num (Symbolics. COMMON_ZERO), N, N)
553552 b = fill (Symbolics. COMMON_ZERO, N)
554553
@@ -626,6 +625,33 @@ function get_linear_scc_linsol(state::TearingState, alg_eqs::Vector{Int},
626625 else
627626 reference = fullvars[state_idx]
628627 end
628+ # Use the `ArrayMaker` form for `A` and `b`
629+ A_regions = SU. RegionsT ()
630+ A_values = Symbolics. SArgsT ()
631+ b_regions = SU. RegionsT ()
632+ b_values = Symbolics. SArgsT ()
633+ # fill the entire thing with zeros
634+ push! (A_regions, SU. ShapeVecT ((1 : N, 1 : N)))
635+ push! (A_values, SU. Fill (A_regions[1 ])(Symbolics. COMMON_ZERO))
636+ push! (b_regions, SU. ShapeVecT ((1 : N,)))
637+ push! (b_values, SU. Fill (b_regions[1 ])(Symbolics. COMMON_ZERO))
638+
639+ for i in axes (A, 1 ), j in axes (A, 2 )
640+ coeff = unwrap (A[i, j])
641+ SU. _iszero (coeff) && continue
642+ push! (A_regions, SU. ShapeVecT ((i: i, j: j)))
643+ push! (A_values, Symbolics. SConst ([coeff;;]))
644+ end
645+
646+ for (i, resid) in enumerate (b)
647+ SU. _iszero (resid) && continue
648+ push! (b_regions, SU. ShapeVecT ((i: i,)))
649+ push! (b_values, Symbolics. SConst ([resid]))
650+ end
651+
652+ A = SU. ArrayMaker {VartypeT} (A_regions, A_values; shape = SU. ShapeVecT ((1 : N, 1 : N)))
653+ b = SU. ArrayMaker {VartypeT} (b_regions, b_values; shape = SU. ShapeVecT ((1 : N,)))
654+
629655 reference_args = Symbolics. SArgsT ((reference, MTKBase. get_iv (sys):: SymbolicT ))
630656 inps = MTKBase. inputs (sys)
631657 if ! isempty (inps)
@@ -635,13 +661,14 @@ function get_linear_scc_linsol(state::TearingState, alg_eqs::Vector{Int},
635661 promote, reference_args;
636662 type = Vector{Real}, shape = [1 : length (reference_args)]
637663 )[1 ]
638- sys, A_cache = MTKBase. add_diffcache (sys, length (A) )
664+ sys, A_cache = MTKBase. add_diffcache (sys, N * N )
639665 A_allocator = A_cache (reference)
640666 A = SU. Code. with_allocator (A_allocator, SU. Const {VartypeT} (A))
641- sys, b_cache = MTKBase. add_diffcache (sys, length (b) )
667+ sys, b_cache = MTKBase. add_diffcache (sys, N )
642668 b_allocator = b_cache (reference)
643669 b = SU. Code. with_allocator (b_allocator, SU. Const {VartypeT} (b))
644670 state. sys = sys
671+
645672 return INLINE_LINEAR_SCC_OP (A, b)
646673end
647674
0 commit comments