Skip to content

Commit 6dee8aa

Browse files
author
Charles Vielzeuf
committed
crossover:v0
1 parent 9c8296e commit 6dee8aa

5 files changed

Lines changed: 561 additions & 7 deletions

File tree

src/CoolPDLP.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ include("public.jl")
4242
include("components/restart.jl")
4343
include("components/generic.jl")
4444
include("components/termination.jl")
45+
include("components/crossover.jl")
4546

4647
include("algorithms/common.jl")
4748
include("algorithms/pdhg.jl")

src/algorithms/common.jl

Lines changed: 76 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ struct Algorithm{
1818
restart::RestartParameters{T}
1919
generic::GenericParameters
2020
termination::TerminationParameters{T}
21+
crossover::CrossoverParameters{T}
2122
end
2223

2324
"""
@@ -46,6 +47,13 @@ end
4647
termination_reltol = 1.0e-4,
4748
max_kkt_passes = 10^5,
4849
time_limit = 100.0,
50+
# crossover (V1: threshold snapping to bounds)
51+
crossover = true,
52+
crossover_threshold = 1.0e-6,
53+
crossover_fixed_tol = 1.0e-8,
54+
crossover_rollback_on_kkt_regression = true,
55+
crossover_kkt_rtol = 0.0,
56+
crossover_use_effective_bounds = true,
4957
)
5058
5159
Constructor for algorithm configs.
@@ -75,6 +83,13 @@ function Algorithm{A}(
7583
termination_reltol = 1.0e-4,
7684
max_kkt_passes = 10^5,
7785
time_limit = 100.0,
86+
# crossover
87+
crossover = true,
88+
crossover_threshold = 1.0e-6,
89+
crossover_fixed_tol = 1.0e-8,
90+
crossover_rollback_on_kkt_regression = true,
91+
crossover_kkt_rtol = 0.0,
92+
crossover_use_effective_bounds = true,
7893
) where {A, T, Ti, M, B}
7994

8095
conversion = ConversionParameters(
@@ -104,19 +119,29 @@ function Algorithm{A}(
104119
max_kkt_passes,
105120
time_limit
106121
)
122+
reltol_T = _T(termination_reltol)
123+
crossover_params = CrossoverParameters(;
124+
enabled = crossover,
125+
threshold = max(_T(crossover_threshold), reltol_T / 20),
126+
fixed_tol = _T(crossover_fixed_tol),
127+
rollback_on_kkt_regression = crossover_rollback_on_kkt_regression,
128+
kkt_rtol = _T(crossover_kkt_rtol),
129+
use_effective_bounds = crossover_use_effective_bounds,
130+
)
107131

108132
return Algorithm{A, T, Ti, M, B}(
109133
conversion,
110134
preconditioning,
111135
step_size,
112136
restart,
113137
generic,
114-
termination
138+
termination,
139+
crossover_params
115140
)
116141
end
117142

118143
function Base.show(io::IO, algo::Algorithm{A}) where {A}
119-
(; conversion, preconditioning, step_size, restart, generic, termination) = algo
144+
(; conversion, preconditioning, step_size, restart, generic, termination, crossover) = algo
120145
return print(
121146
io, """
122147
$A algorithm:
@@ -125,12 +150,57 @@ function Base.show(io::IO, algo::Algorithm{A}) where {A}
125150
- $step_size
126151
- $restart
127152
- $generic
128-
- $termination"""
153+
- $termination
154+
- $crossover"""
129155
)
130156
end
131157

132158
abstract type AbstractState{T, V} end
133159

160+
"""
161+
apply_crossover!(state, milp, algo)
162+
163+
Apply crossover to the current iterate when enabled.
164+
165+
If `rollback_on_kkt_regression` is set, compare KKT errors before and after snapping
166+
and restore the pre-crossover primal when the certificate would be lost or relative KKT
167+
error increases beyond `kkt_rtol`. Updates `state.stats.crossover_applied` and
168+
`state.stats.crossover_rolled_back`.
169+
"""
170+
function apply_crossover!(
171+
state::AbstractState,
172+
milp::MILP,
173+
algo::Algorithm,
174+
)
175+
params = algo.crossover
176+
stats = state.stats
177+
if !params.enabled
178+
stats.crossover_applied = false
179+
stats.crossover_rolled_back = false
180+
stats.crossover_n_snapped = 0
181+
return nothing
182+
end
183+
(; termination_reltol) = algo.termination
184+
err_before = stats.err
185+
x_backup = copy(state.sol.x)
186+
crossover_threshold!(state.sol, milp, params)
187+
n_changed = crossover_n_changed(state.sol.x, x_backup)
188+
err_after = kkt_errors!(state.scratch, state.sol, milp)
189+
if n_changed > 0 &&
190+
crossover_kkt_acceptable(err_before, err_after, termination_reltol, params)
191+
stats.err = err_after
192+
stats.crossover_applied = true
193+
stats.crossover_rolled_back = false
194+
stats.crossover_n_snapped = n_changed
195+
else
196+
state.sol.x .= x_backup
197+
stats.crossover_applied = false
198+
stats.crossover_rolled_back = n_changed > 0
199+
stats.crossover_n_snapped = 0
200+
end
201+
return nothing
202+
end
203+
134204
function prog_showvalues(state::AbstractState)
135205
err = state.stats.err
136206
(; primal, primal_scale, dual, dual_scale, gap, gap_scale) = err
@@ -198,6 +268,9 @@ function solve(
198268
return get_solution(state, milp), state.stats
199269
end
200270
solve!(state, milp, algo)
271+
if state.stats.termination_status == OPTIMAL
272+
apply_crossover!(state, milp, algo)
273+
end
201274
return get_solution(state, milp), state.stats
202275
end
203276

src/components/crossover.jl

Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,189 @@
1+
"""
2+
CrossoverParameters
3+
4+
Post-solve crossover settings (V1: primal threshold snapping to bounds).
5+
6+
# Fields
7+
8+
$(TYPEDFIELDS)
9+
"""
10+
@kwdef struct CrossoverParameters{T <: Number}
11+
"whether to apply crossover after PDLP/PDHG terminates optimally"
12+
enabled::Bool = true
13+
"distance to a bound below which the primal is snapped to that bound"
14+
threshold::T = 1.0e-6
15+
"tolerance for treating lower and upper bounds as equal (fixed variable)"
16+
fixed_tol::T = 1.0e-8
17+
"revert to the pre-crossover primal if KKT errors regress beyond tolerances"
18+
rollback_on_kkt_regression::Bool = true
19+
"relative KKT increase tolerated above the pre-crossover value (0 = no increase)"
20+
kkt_rtol::T = 0.0
21+
"tighten infinite bounds from equality rows before snapping"
22+
use_effective_bounds::Bool = true
23+
end
24+
25+
function Base.show(io::IO, params::CrossoverParameters)
26+
(; enabled, threshold, fixed_tol, rollback_on_kkt_regression, kkt_rtol, use_effective_bounds) = params
27+
return print(
28+
io,
29+
"CrossoverParameters: enabled=$enabled, threshold=$threshold, fixed_tol=$fixed_tol, ",
30+
"rollback_on_kkt_regression=$rollback_on_kkt_regression, kkt_rtol=$kkt_rtol, ",
31+
"use_effective_bounds=$use_effective_bounds",
32+
)
33+
end
34+
35+
"""
36+
crossover_kkt_acceptable(err_before, err_after, termination_reltol, params)
37+
38+
Return `true` if the post-crossover KKT errors should be kept.
39+
40+
When `rollback_on_kkt_regression` is true, reject the crossover if either:
41+
- `relative(err_after) > termination_reltol` (would invalidate the optimality certificate), or
42+
- `relative(err_after) > relative(err_before) * (1 + kkt_rtol)`.
43+
"""
44+
function crossover_kkt_acceptable(
45+
err_before::KKTErrors,
46+
err_after::KKTErrors,
47+
termination_reltol,
48+
params::CrossoverParameters,
49+
)
50+
params.rollback_on_kkt_regression || return true
51+
rel_before = relative(err_before)
52+
rel_after = relative(err_after)
53+
rel_after <= termination_reltol || return false
54+
rel_after <= rel_before * (1 + params.kkt_rtol) || return false
55+
return true
56+
end
57+
58+
"""True if coordinate `j` is on a finite box bound (within `atol`)."""
59+
function _crossover_at_box_bound(xj, lj, uj; atol::Real = 1.0e-12)
60+
return (isfinite(lj) && abs(xj - lj) <= atol) ||
61+
(isfinite(uj) && abs(xj - uj) <= atol)
62+
end
63+
64+
"""
65+
crossover_effective_bounds(milp, x)
66+
67+
Box bounds tightened with implied limits from equality rows.
68+
69+
When an equality row has exactly one variable not yet on a finite box bound, that
70+
row's implied bound is used to fill in an infinite bound (e.g. `x₁ ≤ 1` from
71+
`x₁ + x₂ = 1` when `x₂` is already on its lower bound).
72+
"""
73+
function crossover_effective_bounds(
74+
milp::MILP{T},
75+
x::AbstractVector{T};
76+
bound_atol::Real = 1.0e-12,
77+
eq_atol::Real = 1.0e-12,
78+
) where {T}
79+
lv_eff = copy(milp.lv)
80+
uv_eff = copy(milp.uv)
81+
m, n = size(milp.A)
82+
at_box = [
83+
_crossover_at_box_bound(x[j], milp.lv[j], milp.uv[j]; atol = bound_atol) for j in 1:n
84+
]
85+
for i in 1:m
86+
isapprox(milp.lc[i], milp.uc[i]; atol = eq_atol) || continue
87+
free = Int[]
88+
slack = milp.lc[i]
89+
@inbounds for j in 1:n
90+
aij = milp.A[i, j]
91+
aij == 0 && continue
92+
if at_box[j]
93+
slack -= aij * x[j]
94+
else
95+
push!(free, j)
96+
end
97+
end
98+
length(free) == 1 || continue
99+
j = only(free)
100+
aij = milp.A[i, j]
101+
if aij > 0
102+
implied = slack / aij
103+
if !isfinite(uv_eff[j])
104+
uv_eff[j] = implied
105+
else
106+
uv_eff[j] = min(uv_eff[j], implied)
107+
end
108+
elseif aij < 0
109+
implied = slack / aij
110+
if !isfinite(lv_eff[j])
111+
lv_eff[j] = implied
112+
else
113+
lv_eff[j] = max(lv_eff[j], implied)
114+
end
115+
end
116+
end
117+
return lv_eff, uv_eff
118+
end
119+
120+
"""
121+
crossover_threshold!(x, lv, uv, params::CrossoverParameters)
122+
123+
Snap primal `x` to variable bounds using a fixed threshold (naive crossover).
124+
125+
For each coordinate: fixed variables are set to their bound; otherwise, if `x` is
126+
within `threshold` of a finite lower or upper bound, it is moved to that bound.
127+
"""
128+
function crossover_threshold!(
129+
x::AbstractVector{T},
130+
lv::AbstractVector{T},
131+
uv::AbstractVector{T},
132+
params::CrossoverParameters{T},
133+
) where {T}
134+
(; threshold, fixed_tol) = params
135+
# Broadcast (GPU-safe); same order as the former scalar loop: fixed → lower → upper.
136+
fixed = abs.(lv .- uv) .<= fixed_tol
137+
@. x = ifelse(fixed, lv, x)
138+
near_l = isfinite.(lv) .& (x .- lv .<= threshold)
139+
@. x = ifelse(near_l, lv, x)
140+
near_u = isfinite.(uv) .& (uv .- x .<= threshold)
141+
@. x = ifelse(near_u, uv, x)
142+
return x
143+
end
144+
145+
function crossover_threshold!(
146+
x::AbstractVector{T},
147+
milp::MILP{T},
148+
params::CrossoverParameters{T},
149+
) where {T}
150+
if params.use_effective_bounds
151+
lv_eff, uv_eff = crossover_effective_bounds(milp, x)
152+
else
153+
lv_eff, uv_eff = milp.lv, milp.uv
154+
end
155+
crossover_threshold!(x, lv_eff, uv_eff, params)
156+
return x
157+
end
158+
159+
function crossover_threshold!(
160+
sol::PrimalDualSolution{T},
161+
milp::MILP{T},
162+
params::CrossoverParameters{T},
163+
) where {T}
164+
crossover_threshold!(sol.x, milp, params)
165+
return sol
166+
end
167+
168+
"""Count coordinates changed by crossover (exact compare)."""
169+
function crossover_n_changed(x_after, x_before)
170+
return count(i -> x_after[i] != x_before[i], eachindex(x_before))
171+
end
172+
173+
"""
174+
fraction_at_bounds(x, milp; atol=1e-12)
175+
176+
Fraction of coordinates equal to a finite bound (diagnostic for crossover effect).
177+
"""
178+
function fraction_at_bounds(
179+
x::AbstractVector,
180+
milp::MILP;
181+
atol::Real = 1.0e-12,
182+
)
183+
(; lv, uv) = milp
184+
n = length(x)
185+
n == 0 && return 0.0
186+
at_l = isfinite.(lv) .& (abs.(x .- lv) .<= atol)
187+
at_u = isfinite.(uv) .& (abs.(x .- uv) .<= atol)
188+
return sum(at_l .| at_u) / n
189+
end

src/components/termination.jl

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,12 @@ mutable struct ConvergenceStats{T <: Number}
5454
termination_status::TerminationStatus
5555
"history of KKT errors, indexed by number of KKT passes"
5656
const error_history::Vector{Tuple{Int, KKTErrors{T}}}
57+
"true if the last crossover was applied and kept"
58+
crossover_applied::Bool
59+
"true if the last crossover was reverted to preserve KKT quality"
60+
crossover_rolled_back::Bool
61+
"number of primal coordinates snapped and kept by the last crossover"
62+
crossover_n_snapped::Int
5763

5864
function ConvergenceStats(
5965
::Type{T};
@@ -62,27 +68,42 @@ mutable struct ConvergenceStats{T <: Number}
6268
time_elapsed = 0.0,
6369
kkt_passes = 0,
6470
termination_status = STILL_RUNNING,
65-
error_history = Tuple{Int, KKTErrors{T}}[]
71+
error_history = Tuple{Int, KKTErrors{T}}[],
72+
crossover_applied = false,
73+
crossover_rolled_back = false,
74+
crossover_n_snapped = 0,
6675
) where {T}
6776
return new{T}(
6877
err,
6978
starting_time,
7079
time_elapsed,
7180
kkt_passes,
7281
termination_status,
73-
error_history
82+
error_history,
83+
crossover_applied,
84+
crossover_rolled_back,
85+
crossover_n_snapped,
7486
)
7587
end
7688
end
7789

7890
function Base.show(io::IO, stats::ConvergenceStats)
79-
(; err, time_elapsed, kkt_passes, termination_status) = stats
91+
(; err, time_elapsed, kkt_passes, termination_status, crossover_applied, crossover_rolled_back, crossover_n_snapped) =
92+
stats
93+
xover = if crossover_rolled_back
94+
"crossover rolled back"
95+
elseif crossover_applied
96+
"crossover applied ($crossover_n_snapped coords)"
97+
else
98+
"crossover not applied"
99+
end
80100
return print(
81101
io,
82102
"""Convergence stats with termination status $termination_status:
83103
- $err
84104
- time elapsed: $(round(time_elapsed; digits = 3)) seconds
85-
- KKT passes: $kkt_passes""",
105+
- KKT passes: $kkt_passes
106+
- $xover""",
86107
)
87108
end
88109

0 commit comments

Comments
 (0)