1515level === nothing ? v : (v => level)
1616end
1717
18- function structural_singularity_removal! (state:: TransformationState ;
19- variable_underconstrained! = force_var_to_zero!, kwargs... )
18+ function structural_singularity_removal! (
19+ state:: TransformationState , :: Val{ReturnPivots} = Val {false} ();
20+ variable_underconstrained! = force_var_to_zero!, kwargs...
21+ ) where {ReturnPivots}
2022 mm = linear_subsys_adjmat! (state; kwargs... )
2123 if size (mm, 1 ) == 0
22- return mm # No linear subsystems
24+ if ReturnPivots
25+ return mm, PivotInfo (0 , 0 , Int[])
26+ else
27+ return mm # No linear subsystems
28+ end
2329 end
2430
2531 (; graph, var_to_diff, solvable_graph) = state. structure
26- mm = structural_singularity_removal! (state, mm; variable_underconstrained!)
32+ mm, pivotinfo = structural_singularity_removal! (state, mm, Val {true} () ; variable_underconstrained!)
2733 s = state. structure
2834 for (ei, e) in enumerate (mm. nzrows)
2935 set_neighbors! (s. graph, e, mm. row_cols[ei])
@@ -34,7 +40,11 @@ function structural_singularity_removal!(state::TransformationState;
3440 end
3541 end
3642
37- return mm
43+ if ReturnPivots
44+ return mm, pivotinfo
45+ else
46+ return mm
47+ end
3848end
3949
4050# For debug purposes
@@ -55,7 +65,7 @@ the `constraint`.
5565@inline function find_first_linear_variable (M:: SparseMatrixCLIL ,
5666 range,
5767 mask,
58- constraint)
68+ constraint, :: Nothing = nothing )
5969 eadj = M. row_cols
6070 @inbounds for i in range
6171 vertices = eadj[i]
@@ -70,10 +80,33 @@ the `constraint`.
7080 return nothing
7181end
7282
83+ @inline function find_first_linear_variable (
84+ M:: SparseMatrixCLIL ,
85+ range,
86+ mask,
87+ constraint, var_priorities:: AbstractVector{Int}
88+ )
89+ eadj = M. row_cols
90+ @inbounds for i in range
91+ vertices = eadj[i]
92+ constraint (length (vertices)) || continue
93+ candidate_v = 0
94+ candidate_val = 0
95+ for (j, v) in enumerate (vertices)
96+ mask === nothing || mask[v] || continue
97+ iszero (candidate_v) || var_priorities[v] < var_priorities[candidate_v] || continue
98+ candidate_v = v
99+ candidate_val = M. row_vals[i][j]
100+ end
101+ iszero (candidate_v) || return CartesianIndex (i, candidate_v), candidate_val
102+ end
103+ return nothing
104+ end
105+
73106@inline function find_first_linear_variable (M:: AbstractMatrix ,
74107 range,
75108 mask,
76- constraint)
109+ constraint, :: Nothing = nothing )
77110 @inbounds for i in range
78111 row = @view M[i, :]
79112 if constraint (count (! iszero, row))
87120 return nothing
88121end
89122
90- function find_masked_pivot (variables, M, k)
91- r = find_first_linear_variable (M, k: size (M, 1 ), variables, isequal (1 ))
123+ @inline function find_first_linear_variable (
124+ M:: AbstractMatrix ,
125+ range,
126+ mask,
127+ constraint, var_priorities:: AbstractVector{Int}
128+ )
129+ @inbounds for i in range
130+ row = @view M[i, :]
131+ constraint (count (! iszero, row)) || continue
132+ candidate_v = 0
133+ candidate_val = 0
134+ for (v, val) in enumerate (row)
135+ mask === nothing || mask[v] || continue
136+ if iszero (candidate_v) || var_priorities[v] < var_priorities[candidate_v]
137+ candidate_v = v
138+ candidate_val = val
139+ end
140+ end
141+ iszero (candidate_v) && return nothing
142+ return CartesianIndex (i, candidate_v), candidate_val
143+ end
144+ return nothing
145+ end
146+
147+ function find_masked_pivot (variables, M, k, var_priorities)
148+ r = find_first_linear_variable (M, k: size (M, 1 ), variables, isequal (1 ), var_priorities)
92149 r != = nothing && return r
93- r = find_first_linear_variable (M, k: size (M, 1 ), variables, isequal (2 ))
150+ r = find_first_linear_variable (M, k: size (M, 1 ), variables, isequal (2 ), var_priorities )
94151 r != = nothing && return r
95- r = find_first_linear_variable (M, k: size (M, 1 ), variables, _ -> true )
152+ r = find_first_linear_variable (M, k: size (M, 1 ), variables, _ -> true , var_priorities )
96153 return r
97154end
98155
@@ -207,14 +264,15 @@ function aag_bareiss!(structure, mm_orig::SparseMatrixCLIL{T, Ti}) where {T, Ti}
207264 end
208265 end
209266 solvable_variables = findall (is_linear_variables)
267+ var_priorities = has_state_priorities (structure) ? get_state_priorities (structure) : nothing
210268
211269 local bar
212270 try
213- bar = do_bareiss! (mm, mm_orig, is_linear_variables, is_highest_diff)
271+ bar = do_bareiss! (mm, mm_orig, is_linear_variables, is_highest_diff, var_priorities )
214272 catch e
215273 e isa OverflowError || rethrow (e)
216274 mm = convert (SparseMatrixCLIL{BigInt, Ti}, mm_orig)
217- bar = do_bareiss! (mm, mm_orig, is_linear_variables, is_highest_diff)
275+ bar = do_bareiss! (mm, mm_orig, is_linear_variables, is_highest_diff, var_priorities )
218276 end
219277
220278 # This phrasing infers the return type as `Union{Tuple{...}}` instead of
243301(s:: SyncedSwapRows{Nothing} )(M, i:: Int , j:: Int ) = Base. swaprows! (M, i, j)
244302(s:: SyncedSwapRows )(M, i:: Int , j:: Int ) = (Base. swaprows! (s. Mold, i, j); Base. swaprows! (M, i, j))
245303
304+ """
305+ $TYPEDEF
306+
307+ Lazy `&&` of two boolean masks. Only implements whatever is required for `find_masked_pivot`.
308+ """
309+ struct LazyMaskAnd{V1 <: AbstractVector{Bool} , V2 <: AbstractVector{Bool} }
310+ mask1:: V1
311+ mask2:: V2
312+ end
313+
314+ Base. getindex (lma:: LazyMaskAnd , i:: Integer ) = lma. mask1[i] && lma. mask2[i]
315+
246316"""
247317 $(TYPEDEF)
248318
@@ -253,12 +323,21 @@ Mutable state threaded through the Bareiss factorization callbacks.
253323- `pivots`: accumulates the column index of every pivot chosen during elimination.
254324- `is_linear_variables`/`is_highest_diff`: masks used for the tiered pivot search.
255325"""
256- mutable struct BareissContext{V1 <: AbstractVector{Bool} , V2 <: AbstractVector{Bool} }
326+ mutable struct BareissContext{V1 <: AbstractVector{Bool} , V2 <: AbstractVector{Bool} , P <: Union{Nothing, AbstractVector{Int}} }
257327 rank1:: Union{Nothing, Int}
258328 rank2:: Union{Nothing, Int}
259329 pivots:: Vector{Int}
260330 is_linear_variables:: V1
261331 is_highest_diff:: V2
332+ valid_pivot_mask:: BitVector
333+ var_priorities:: P
334+ end
335+
336+ function BareissContext (is_linear_variables, is_highest_diff, var_priorities = nothing )
337+ return BareissContext (
338+ nothing , nothing , Int[], is_linear_variables, is_highest_diff,
339+ trues (length (is_linear_variables)), var_priorities
340+ )
262341end
263342
264343"""
@@ -273,15 +352,17 @@ The column index of every selected pivot is appended to `ctx.pivots`.
273352"""
274353function (ctx:: BareissContext )(M, k:: Int )
275354 if ctx. rank1 === nothing
276- r = find_masked_pivot (ctx. is_linear_variables, M, k)
355+ mask = LazyMaskAnd (ctx. is_linear_variables, ctx. valid_pivot_mask)
356+ r = find_masked_pivot (ctx. is_linear_variables, M, k, ctx. var_priorities)
277357 if r != = nothing
278358 push! (ctx. pivots, r[1 ][2 ])
279359 return r
280360 end
281361 ctx. rank1 = k - 1
282362 end
283363 if ctx. rank2 === nothing
284- r = find_masked_pivot (ctx. is_highest_diff, M, k)
364+ mask = LazyMaskAnd (ctx. is_highest_diff, ctx. valid_pivot_mask)
365+ r = find_masked_pivot (ctx. is_highest_diff, M, k, ctx. var_priorities)
285366 if r != = nothing
286367 push! (ctx. pivots, r[1 ][2 ])
287368 return r
@@ -291,16 +372,28 @@ function (ctx::BareissContext)(M, k::Int)
291372 # TODO : It would be better to sort the variables by
292373 # derivative order here to enable more elimination
293374 # opportunities.
294- r = find_masked_pivot (nothing , M, k)
375+ r = find_masked_pivot (nothing , M, k, ctx . var_priorities )
295376 r != = nothing && push! (ctx. pivots, r[1 ][2 ])
296377 return r
297378end
298379
380+ struct BareissContextUpdate{C <: BareissContext , F}
381+ context:: C
382+ inner_update:: F
383+ end
384+
385+ function (bcu:: BareissContextUpdate )(zero!, M, k, swapto, pivot, last_pivot; kw... )
386+ ctx = bcu. context
387+ col = swapto[2 ]
388+ ctx. valid_pivot_mask[col] = false
389+ return bcu. inner_update (zero!, M, k, swapto, pivot, last_pivot; kw... )
390+ end
391+
299392function do_bareiss! (M, Mold, is_linear_variables:: AbstractVector{Bool} ,
300- is_highest_diff:: AbstractVector{Bool} )
301- ctx = BareissContext (nothing , nothing , Int[], is_linear_variables, is_highest_diff)
302- bareiss_ops = (noop_colswap, SyncedSwapRows (Mold),
303- bareiss_update_virtual_colswap_mtk !, bareiss_zero!)
393+ is_highest_diff:: AbstractVector{Bool} , var_priorities = nothing )
394+ ctx = BareissContext (is_linear_variables, is_highest_diff, var_priorities )
395+ update! = BareissContextUpdate (ctx, bareiss_update_virtual_colswap_mtk!)
396+ bareiss_ops = (noop_colswap, SyncedSwapRows (Mold), update !, bareiss_zero!)
304397 rank3, = bareiss! (M, bareiss_ops; find_pivot = ctx)
305398 rank2 = something (ctx. rank2, rank3)
306399 rank1 = something (ctx. rank1, rank2)
@@ -321,8 +414,37 @@ function force_var_to_zero!(structure::SystemStructure, ils::SparseMatrixCLIL, v
321414 return ils
322415end
323416
324- function structural_singularity_removal! (state:: TransformationState , ils:: SparseMatrixCLIL ;
325- variable_underconstrained! = force_var_to_zero!)
417+ """
418+ $TYPEDSIGNATURES
419+
420+ Information about the pivots chosen by Bareiss during `structural_singularity_removal!`.
421+ This can be returned from `structural_singularity_removal!` by passing `Val(true)` as the last
422+ positional argument.
423+
424+ $TYPEDFIELDS
425+ """
426+ struct PivotInfo
427+ """
428+ The length of the prefix of `pivots` that is variables which _only_ occur in linear
429+ equations of the sort considered by this pass. These variables must be solved for
430+ using the integer coefficient equations considered by this pass.
431+ """
432+ n_linear_vars:: Int
433+ """
434+ Number of elements in `pivots` after `n_linear_vars` corresponding to highest order
435+ derivative variables.
436+ """
437+ n_highest_diff_vars:: Int
438+ """
439+ The list of pivots chosen by the Bareiss algorithm.
440+ """
441+ pivots:: Vector{Int}
442+ end
443+
444+ function structural_singularity_removal! (
445+ state:: TransformationState , ils:: SparseMatrixCLIL , :: Val{ReturnPivots} = Val {false} ();
446+ variable_underconstrained! = force_var_to_zero!
447+ ) where {ReturnPivots}
326448 (; structure) = state
327449 (; graph, solvable_graph, var_to_diff, eq_to_diff) = state. structure
328450 # Step 1: Perform Bareiss factorization on the adjacency matrix of the linear
@@ -337,5 +459,9 @@ function structural_singularity_removal!(state::TransformationState, ils::Sparse
337459 ils = variable_underconstrained! (structure, ils, v)
338460 end
339461
340- return ils
462+ if ReturnPivots
463+ return ils, PivotInfo (rank1, rank2, pivots)
464+ else
465+ return ils
466+ end
341467end
0 commit comments