Skip to content

Commit 7da4739

Browse files
Merge pull request #49 from JuliaComputing/as/refactor-alias-elim
feat: use state priorities in bareiss pivot selection
2 parents 058a7f5 + 3d363c0 commit 7da4739

1 file changed

Lines changed: 150 additions & 24 deletions

File tree

src/singularity_removal.jl

Lines changed: 150 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,21 @@ end
1515
level === nothing ? v : (v => level)
1616
end
1717

18-
function structural_singularity_removal!(state::TransformationState;
19-
variable_underconstrained! = force_var_to_zero!, kwargs...)
18+
function structural_singularity_removal!(
19+
state::TransformationState, ::Val{ReturnPivots} = Val{false}();
20+
variable_underconstrained! = force_var_to_zero!, kwargs...
21+
) where {ReturnPivots}
2022
mm = linear_subsys_adjmat!(state; kwargs...)
2123
if size(mm, 1) == 0
22-
return mm # No linear subsystems
24+
if ReturnPivots
25+
return mm, PivotInfo(0, 0, Int[])
26+
else
27+
return mm # No linear subsystems
28+
end
2329
end
2430

2531
(; graph, var_to_diff, solvable_graph) = state.structure
26-
mm = structural_singularity_removal!(state, mm; variable_underconstrained!)
32+
mm, pivotinfo = structural_singularity_removal!(state, mm, Val{true}(); variable_underconstrained!)
2733
s = state.structure
2834
for (ei, e) in enumerate(mm.nzrows)
2935
set_neighbors!(s.graph, e, mm.row_cols[ei])
@@ -34,7 +40,11 @@ function structural_singularity_removal!(state::TransformationState;
3440
end
3541
end
3642

37-
return mm
43+
if ReturnPivots
44+
return mm, pivotinfo
45+
else
46+
return mm
47+
end
3848
end
3949

4050
# For debug purposes
@@ -55,7 +65,7 @@ the `constraint`.
5565
@inline function find_first_linear_variable(M::SparseMatrixCLIL,
5666
range,
5767
mask,
58-
constraint)
68+
constraint, ::Nothing = nothing)
5969
eadj = M.row_cols
6070
@inbounds for i in range
6171
vertices = eadj[i]
@@ -70,10 +80,33 @@ the `constraint`.
7080
return nothing
7181
end
7282

83+
@inline function find_first_linear_variable(
84+
M::SparseMatrixCLIL,
85+
range,
86+
mask,
87+
constraint, var_priorities::AbstractVector{Int}
88+
)
89+
eadj = M.row_cols
90+
@inbounds for i in range
91+
vertices = eadj[i]
92+
constraint(length(vertices)) || continue
93+
candidate_v = 0
94+
candidate_val = 0
95+
for (j, v) in enumerate(vertices)
96+
mask === nothing || mask[v] || continue
97+
iszero(candidate_v) || var_priorities[v] < var_priorities[candidate_v] || continue
98+
candidate_v = v
99+
candidate_val = M.row_vals[i][j]
100+
end
101+
iszero(candidate_v) || return CartesianIndex(i, candidate_v), candidate_val
102+
end
103+
return nothing
104+
end
105+
73106
@inline function find_first_linear_variable(M::AbstractMatrix,
74107
range,
75108
mask,
76-
constraint)
109+
constraint, ::Nothing = nothing)
77110
@inbounds for i in range
78111
row = @view M[i, :]
79112
if constraint(count(!iszero, row))
@@ -87,12 +120,36 @@ end
87120
return nothing
88121
end
89122

90-
function find_masked_pivot(variables, M, k)
91-
r = find_first_linear_variable(M, k:size(M, 1), variables, isequal(1))
123+
@inline function find_first_linear_variable(
124+
M::AbstractMatrix,
125+
range,
126+
mask,
127+
constraint, var_priorities::AbstractVector{Int}
128+
)
129+
@inbounds for i in range
130+
row = @view M[i, :]
131+
constraint(count(!iszero, row)) || continue
132+
candidate_v = 0
133+
candidate_val = 0
134+
for (v, val) in enumerate(row)
135+
mask === nothing || mask[v] || continue
136+
if iszero(candidate_v) || var_priorities[v] < var_priorities[candidate_v]
137+
candidate_v = v
138+
candidate_val = val
139+
end
140+
end
141+
iszero(candidate_v) && return nothing
142+
return CartesianIndex(i, candidate_v), candidate_val
143+
end
144+
return nothing
145+
end
146+
147+
function find_masked_pivot(variables, M, k, var_priorities)
148+
r = find_first_linear_variable(M, k:size(M, 1), variables, isequal(1), var_priorities)
92149
r !== nothing && return r
93-
r = find_first_linear_variable(M, k:size(M, 1), variables, isequal(2))
150+
r = find_first_linear_variable(M, k:size(M, 1), variables, isequal(2), var_priorities)
94151
r !== nothing && return r
95-
r = find_first_linear_variable(M, k:size(M, 1), variables, _ -> true)
152+
r = find_first_linear_variable(M, k:size(M, 1), variables, _ -> true, var_priorities)
96153
return r
97154
end
98155

@@ -207,14 +264,15 @@ function aag_bareiss!(structure, mm_orig::SparseMatrixCLIL{T, Ti}) where {T, Ti}
207264
end
208265
end
209266
solvable_variables = findall(is_linear_variables)
267+
var_priorities = has_state_priorities(structure) ? get_state_priorities(structure) : nothing
210268

211269
local bar
212270
try
213-
bar = do_bareiss!(mm, mm_orig, is_linear_variables, is_highest_diff)
271+
bar = do_bareiss!(mm, mm_orig, is_linear_variables, is_highest_diff, var_priorities)
214272
catch e
215273
e isa OverflowError || rethrow(e)
216274
mm = convert(SparseMatrixCLIL{BigInt, Ti}, mm_orig)
217-
bar = do_bareiss!(mm, mm_orig, is_linear_variables, is_highest_diff)
275+
bar = do_bareiss!(mm, mm_orig, is_linear_variables, is_highest_diff, var_priorities)
218276
end
219277

220278
# This phrasing infers the return type as `Union{Tuple{...}}` instead of
@@ -243,6 +301,18 @@ end
243301
(s::SyncedSwapRows{Nothing})(M, i::Int, j::Int) = Base.swaprows!(M, i, j)
244302
(s::SyncedSwapRows)(M, i::Int, j::Int) = (Base.swaprows!(s.Mold, i, j); Base.swaprows!(M, i, j))
245303

304+
"""
305+
$TYPEDEF
306+
307+
Lazy `&&` of two boolean masks. Only implements whatever is required for `find_masked_pivot`.
308+
"""
309+
struct LazyMaskAnd{V1 <: AbstractVector{Bool}, V2 <: AbstractVector{Bool}}
310+
mask1::V1
311+
mask2::V2
312+
end
313+
314+
Base.getindex(lma::LazyMaskAnd, i::Integer) = lma.mask1[i] && lma.mask2[i]
315+
246316
"""
247317
$(TYPEDEF)
248318
@@ -253,12 +323,21 @@ Mutable state threaded through the Bareiss factorization callbacks.
253323
- `pivots`: accumulates the column index of every pivot chosen during elimination.
254324
- `is_linear_variables`/`is_highest_diff`: masks used for the tiered pivot search.
255325
"""
256-
mutable struct BareissContext{V1 <: AbstractVector{Bool}, V2 <: AbstractVector{Bool}}
326+
mutable struct BareissContext{V1 <: AbstractVector{Bool}, V2 <: AbstractVector{Bool}, P <: Union{Nothing, AbstractVector{Int}}}
257327
rank1::Union{Nothing, Int}
258328
rank2::Union{Nothing, Int}
259329
pivots::Vector{Int}
260330
is_linear_variables::V1
261331
is_highest_diff::V2
332+
valid_pivot_mask::BitVector
333+
var_priorities::P
334+
end
335+
336+
function BareissContext(is_linear_variables, is_highest_diff, var_priorities = nothing)
337+
return BareissContext(
338+
nothing, nothing, Int[], is_linear_variables, is_highest_diff,
339+
trues(length(is_linear_variables)), var_priorities
340+
)
262341
end
263342

264343
"""
@@ -273,15 +352,17 @@ The column index of every selected pivot is appended to `ctx.pivots`.
273352
"""
274353
function (ctx::BareissContext)(M, k::Int)
275354
if ctx.rank1 === nothing
276-
r = find_masked_pivot(ctx.is_linear_variables, M, k)
355+
mask = LazyMaskAnd(ctx.is_linear_variables, ctx.valid_pivot_mask)
356+
r = find_masked_pivot(ctx.is_linear_variables, M, k, ctx.var_priorities)
277357
if r !== nothing
278358
push!(ctx.pivots, r[1][2])
279359
return r
280360
end
281361
ctx.rank1 = k - 1
282362
end
283363
if ctx.rank2 === nothing
284-
r = find_masked_pivot(ctx.is_highest_diff, M, k)
364+
mask = LazyMaskAnd(ctx.is_highest_diff, ctx.valid_pivot_mask)
365+
r = find_masked_pivot(ctx.is_highest_diff, M, k, ctx.var_priorities)
285366
if r !== nothing
286367
push!(ctx.pivots, r[1][2])
287368
return r
@@ -291,16 +372,28 @@ function (ctx::BareissContext)(M, k::Int)
291372
# TODO: It would be better to sort the variables by
292373
# derivative order here to enable more elimination
293374
# opportunities.
294-
r = find_masked_pivot(nothing, M, k)
375+
r = find_masked_pivot(nothing, M, k, ctx.var_priorities)
295376
r !== nothing && push!(ctx.pivots, r[1][2])
296377
return r
297378
end
298379

380+
struct BareissContextUpdate{C <: BareissContext, F}
381+
context::C
382+
inner_update::F
383+
end
384+
385+
function (bcu::BareissContextUpdate)(zero!, M, k, swapto, pivot, last_pivot; kw...)
386+
ctx = bcu.context
387+
col = swapto[2]
388+
ctx.valid_pivot_mask[col] = false
389+
return bcu.inner_update(zero!, M, k, swapto, pivot, last_pivot; kw...)
390+
end
391+
299392
function do_bareiss!(M, Mold, is_linear_variables::AbstractVector{Bool},
300-
is_highest_diff::AbstractVector{Bool})
301-
ctx = BareissContext(nothing, nothing, Int[], is_linear_variables, is_highest_diff)
302-
bareiss_ops = (noop_colswap, SyncedSwapRows(Mold),
303-
bareiss_update_virtual_colswap_mtk!, bareiss_zero!)
393+
is_highest_diff::AbstractVector{Bool}, var_priorities = nothing)
394+
ctx = BareissContext(is_linear_variables, is_highest_diff, var_priorities)
395+
update! = BareissContextUpdate(ctx, bareiss_update_virtual_colswap_mtk!)
396+
bareiss_ops = (noop_colswap, SyncedSwapRows(Mold), update!, bareiss_zero!)
304397
rank3, = bareiss!(M, bareiss_ops; find_pivot = ctx)
305398
rank2 = something(ctx.rank2, rank3)
306399
rank1 = something(ctx.rank1, rank2)
@@ -321,8 +414,37 @@ function force_var_to_zero!(structure::SystemStructure, ils::SparseMatrixCLIL, v
321414
return ils
322415
end
323416

324-
function structural_singularity_removal!(state::TransformationState, ils::SparseMatrixCLIL;
325-
variable_underconstrained! = force_var_to_zero!)
417+
"""
418+
$TYPEDSIGNATURES
419+
420+
Information about the pivots chosen by Bareiss during `structural_singularity_removal!`.
421+
This can be returned from `structural_singularity_removal!` by passing `Val(true)` as the last
422+
positional argument.
423+
424+
$TYPEDFIELDS
425+
"""
426+
struct PivotInfo
427+
"""
428+
The length of the prefix of `pivots` that is variables which _only_ occur in linear
429+
equations of the sort considered by this pass. These variables must be solved for
430+
using the integer coefficient equations considered by this pass.
431+
"""
432+
n_linear_vars::Int
433+
"""
434+
Number of elements in `pivots` after `n_linear_vars` corresponding to highest order
435+
derivative variables.
436+
"""
437+
n_highest_diff_vars::Int
438+
"""
439+
The list of pivots chosen by the Bareiss algorithm.
440+
"""
441+
pivots::Vector{Int}
442+
end
443+
444+
function structural_singularity_removal!(
445+
state::TransformationState, ils::SparseMatrixCLIL, ::Val{ReturnPivots} = Val{false}();
446+
variable_underconstrained! = force_var_to_zero!
447+
) where {ReturnPivots}
326448
(; structure) = state
327449
(; graph, solvable_graph, var_to_diff, eq_to_diff) = state.structure
328450
# Step 1: Perform Bareiss factorization on the adjacency matrix of the linear
@@ -337,5 +459,9 @@ function structural_singularity_removal!(state::TransformationState, ils::Sparse
337459
ils = variable_underconstrained!(structure, ils, v)
338460
end
339461

340-
return ils
462+
if ReturnPivots
463+
return ils, PivotInfo(rank1, rank2, pivots)
464+
else
465+
return ils
466+
end
341467
end

0 commit comments

Comments
 (0)