Skip to content
Merged
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ MathOptInterface = "b8f27783-ece8-5eb3-8dc8-9495eed66fee"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Hungarian = "e91730f6-4275-51fb-a7a0-7064cfbd3b39"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

[compat]
Boscia = "0.2"
Expand All @@ -26,7 +27,8 @@ Hungarian = "0.7.0"

[extras]
HiGHS = "87dc4568-4c63-4d18-b0c0-bb2238e4078b"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "HiGHS"]
test = ["Test", "HiGHS", "StableRNGs"]
195 changes: 161 additions & 34 deletions src/birkhoff_polytope.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
"""
BirkhoffLMO
BirkhoffLMO
Comment thread
matbesancon marked this conversation as resolved.
Outdated

A bounded LMO for the Birkhoff polytope. This oracle computes an extreme point subject to
node-specific bounds on the integer variables.
A bounded Linear Minimization Oracle (LMO) for the Birkhoff polytope. The oracle
computes extreme points (permutation matrices) possibly under node-specific bound
constraints on a subset of integer variables. It also supports mixed-integer
variants, partial fixings, and in-face oracles used by DiCG/BCG-like methods.
"""
mutable struct BirkhoffLMO <: FrankWolfe.LinearMinimizationOracle
append_by_column::Bool
Expand All @@ -17,10 +19,14 @@ mutable struct BirkhoffLMO <: FrankWolfe.LinearMinimizationOracle
updated_lmo::Bool
atol::Float64
rtol::Float64
boscia_use::Bool
end

# return an integer-type BirkhoffLMO
"""
BirkhoffLMO(dim, int_vars; append_by_column=true, atol=1e-6, rtol=1e-3)

Constructor for a mixed-integer Birkhoff LMO. All variables listed in
`int_vars` are treated as integer with default bounds `[0,1]`.
"""
BirkhoffLMO(dim, int_vars; append_by_column=true, atol=1e-6, rtol=1e-3) = BirkhoffLMO(
append_by_column,
dim,
Expand All @@ -34,10 +40,13 @@ BirkhoffLMO(dim, int_vars; append_by_column=true, atol=1e-6, rtol=1e-3) = Birkho
true,
atol,
rtol,
true,
)

# return a continuous BirkhoffLMO
"""
BirkhoffLMO(dim; append_by_column=true, atol=1e-6, rtol=1e-3)

Constructor for a continuous Birkhoff LMO (no integer variables).
"""
Comment thread
matbesancon marked this conversation as resolved.
Outdated
BirkhoffLMO(dim; append_by_column=true, atol=1e-6, rtol=1e-3) = BirkhoffLMO(
append_by_column,
dim,
Expand All @@ -51,15 +60,24 @@ BirkhoffLMO(dim; append_by_column=true, atol=1e-6, rtol=1e-3) = BirkhoffLMO(
true,
atol,
rtol,
false,
)

## Necessary

"""
Computes the extreme point given an direction d, the current lower and upper bounds on the integer variables, and the set of integer variables.
FrankWolfe.compute_extreme_point(lmo::BirkhoffLMO, d::AbstractMatrix{T}; kwargs...) where {T}

Compute an extreme point (a permutation matrix) minimizing the linear form
`⟨d, X⟩` over the current feasible face of the (possibly reduced) Birkhoff polytope,
subject to integer bounds and fixings maintained by `lmo`.

Return a sparse `n×n` matrix with `0/1` entries representing the selected permutation.
"""
function Boscia.compute_extreme_point(lmo::BirkhoffLMO, d::AbstractMatrix{T}; kwargs...) where {T}
function FrankWolfe.compute_extreme_point(
lmo::BirkhoffLMO,
d::AbstractMatrix{T};
kwargs...,
) where {T}
n = lmo.dim

fixed_to_one_rows = lmo.fixed_to_one_rows
Expand Down Expand Up @@ -133,7 +151,18 @@ function Boscia.compute_extreme_point(lmo::BirkhoffLMO, d::AbstractMatrix{T}; kw
return m
end

function Boscia.compute_extreme_point(lmo::BirkhoffLMO, d::AbstractVector{T}; kwargs...) where {T}
"""
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

given the number of methods for compute_extreme_point, we should not document all of them except rare cases, it's usually not very discoverable. Any critical information should be in the LMO type docstringm, the rest removed

FrankWolfe.compute_extreme_point(lmo::BirkhoffLMO, d::AbstractVector{T}; kwargs...) where {T}

Vector form of [`compute_extreme_point`](@ref), where `d` is a vectorized cost.
Handles the reshape/transposition according to `append_by_column` and returns a
sparse vectorized permutation of length `n^2`.
"""
function FrankWolfe.compute_extreme_point(
lmo::BirkhoffLMO,
d::AbstractVector{T};
kwargs...,
) where {T}
n = lmo.dim
d = lmo.append_by_column ? reshape(d, (n, n)) : transpose(reshape(d, (n, n)))
m = Boscia.compute_extreme_point(lmo, d; kwargs...)
Expand All @@ -153,7 +182,13 @@ function Boscia.compute_extreme_point(lmo::BirkhoffLMO, d::AbstractVector{T}; kw
end

"""
Computes the extreme point given an direction d, the current lower and upper bounds on the integer variables, and the set of integer variables.
FrankWolfe.compute_inface_extreme_point(lmo::BirkhoffLMO, direction::AbstractMatrix{T}, x::AbstractMatrix{T}; kwargs...) where {T}

Compute a vertex that minimizes the linear form `⟨direction, X⟩` on the minimal face containing
the current iterate `x`, given current fixings and bounds. Entries already at `1` and `0` in
`x` are kept fixed.

Return a sparse `n×n` permutation matrix consistent with the in-face constraints.
"""
function FrankWolfe.compute_inface_extreme_point(
lmo::BirkhoffLMO,
Expand Down Expand Up @@ -186,7 +221,7 @@ function FrankWolfe.compute_inface_extreme_point(
for i in 1:nreduced
row_orig = index_map_rows[i]
col_orig = index_map_cols[j]
if x[row_orig, col_orig] >= 1-eps()
if x[row_orig, col_orig] >= 1 - eps()
push!(fixed_to_one_rows, row_orig)
push!(fixed_to_one_cols, col_orig)

Expand All @@ -210,9 +245,9 @@ function FrankWolfe.compute_inface_extreme_point(
for i in 1:nreduced
row_orig = index_map_rows[i]
if lmo.append_by_column
orig_linear_idx = (col_orig-1)*n+row_orig
orig_linear_idx = (col_orig - 1) * n + row_orig
else
orig_linear_idx = (row_orig-1)*n+col_orig
orig_linear_idx = (row_orig - 1) * n + col_orig
end
if x[row_orig, col_orig] <= eps()
if lmo.append_by_column
Expand Down Expand Up @@ -262,6 +297,12 @@ function FrankWolfe.compute_inface_extreme_point(
return m
end

"""
FrankWolfe.compute_inface_extreme_point(lmo::BirkhoffLMO, direction::AbstractVector{T}, x::AbstractVector{T}; kwargs...) where {T}

Vector form of the in-face oracle; reshapes inputs/outputs according to
`append_by_column` and returns a sparse vectorized permutation.
"""
function FrankWolfe.compute_inface_extreme_point(
lmo::BirkhoffLMO,
direction::AbstractVector{T},
Expand Down Expand Up @@ -290,8 +331,12 @@ function FrankWolfe.compute_inface_extreme_point(
end

"""
LMO-like operation which computes a vertex minimizing in `direction` on the face defined by the current fixings.
Fixings are maintained by the oracle (or deduced from `x` itself).
FrankWolfe.dicg_maximum_step(lmo::BirkhoffLMO, direction, x; kwargs...)

Compute the maximum feasible step-size `γ_max` along a given direction
for DICG updates on the hypercube constraints `0 ≤ x ≤ 1`. If moving in the
positive (increasing) direction hits the `1`-bound or in the negative (decreasing)
direction hits the `0`-bound, the step is clipped accordingly.
"""
function FrankWolfe.dicg_maximum_step(lmo::BirkhoffLMO, direction, x; kwargs...)
n = lmo.dim
Expand All @@ -317,12 +362,21 @@ function FrankWolfe.dicg_maximum_step(lmo::BirkhoffLMO, direction, x; kwargs...)

end

"""
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same for all FW methods

FrankWolfe.is_decomposition_invariant_oracle(lmo::BirkhoffLMO)

Indicate that this oracle is decomposition invariant.
"""
function FrankWolfe.is_decomposition_invariant_oracle(lmo::BirkhoffLMO)
return true
end

"""
The sum of each row and column has to be equal to 1.
Boscia.is_linear_feasible(blmo::BirkhoffLMO, v::AbstractVector)

Check whether vector `v` is feasible for the Birkhoff polytope (row/column sums
are `1` under the configured vectorization) and consistent with the current
integer bounds `lower_bounds/upper_bounds` for indices in `int_vars`.
"""
function Boscia.is_linear_feasible(blmo::BirkhoffLMO, v::AbstractVector)
for (i, int_var) in enumerate(blmo.int_vars)
Expand Down Expand Up @@ -351,7 +405,12 @@ function Boscia.is_linear_feasible(blmo::BirkhoffLMO, v::AbstractVector)
return true
end

# Read global bounds from the problem.
"""
Boscia.build_global_bounds(blmo::BirkhoffLMO, integer_variables)

Build a `Boscia.IntegerBounds()` object from the current lower/upper bounds stored
in the oracle for all integer variables.
"""
function Boscia.build_global_bounds(blmo::BirkhoffLMO, integer_variables)
global_bounds = Boscia.IntegerBounds()
for (idx, int_var) in enumerate(blmo.int_vars)
Expand All @@ -361,34 +420,58 @@ function Boscia.build_global_bounds(blmo::BirkhoffLMO, integer_variables)
return global_bounds
end

# Get list of variables indices.
# If the problem has n variables, they are expected to contiguous and ordered from 1 to n.
"""
Boscia.get_list_of_variables(blmo::BirkhoffLMO)

Return the number of variables (`n = dim^2`) and the list of their linear indices
`1:n` under the current storage order.
"""
function Boscia.get_list_of_variables(blmo::BirkhoffLMO)
n = blmo.dim^2
return n, collect(1:n)
end

# Get list of integer variables
"""
Boscia.get_integer_variables(blmo::BirkhoffLMO)

Return the vector of linear indices of integer-constrained variables.
"""
function Boscia.get_integer_variables(blmo::BirkhoffLMO)
return blmo.int_vars
end

# Get the index of the integer variable the bound is working on.
"""
Boscia.get_int_var(blmo::BirkhoffLMO, cidx)

Map the internal bound index `cidx` to its corresponding variable linear index.
"""
function Boscia.get_int_var(blmo::BirkhoffLMO, cidx)
return blmo.int_vars[cidx]
end

# Get the list of lower bounds.
"""
Boscia.get_lower_bound_list(blmo::BirkhoffLMO)

Return the list of indices for the lower-bound constraints (i.e., `1:length(lower_bounds)`).
"""
function Boscia.get_lower_bound_list(blmo::BirkhoffLMO)
return collect(1:length(blmo.lower_bounds))
end

# Get the list of upper bounds.
"""
Boscia.get_upper_bound_list(blmo::BirkhoffLMO)

Return the list of indices for the upper-bound constraints (i.e., `1:length(upper_bounds)`).
"""
function Boscia.get_upper_bound_list(blmo::BirkhoffLMO)
return collect(1:length(blmo.upper_bounds))
end

# Read bound value for c_idx.
"""
Boscia.get_bound(blmo::BirkhoffLMO, c_idx, sense::Symbol)

Read the bound value for constraint index `c_idx` with `sense ∈ {:lessthan, :greaterthan}`.
"""
function Boscia.get_bound(blmo::BirkhoffLMO, c_idx, sense::Symbol)
if sense == :lessthan
return blmo.upper_bounds[c_idx]
Expand All @@ -401,7 +484,14 @@ end

## Changing the bounds constraints.

# Change the value of the bound c_idx.
"""
Boscia.set_bound!(blmo::BirkhoffLMO, c_idx, value, sense::Symbol)

Change the value of an existing bound constraint at index `c_idx` with
`sense ∈ {:lessthan, :greaterthan}`. If a lower bound is set to `1.0`, the
corresponding `(i,j)` entry is fixed to one and the reduced index maps are
refreshed on demand.
"""
function Boscia.set_bound!(blmo::BirkhoffLMO, c_idx, value, sense::Symbol)
# Reset the lmo if necessary
if blmo.updated_lmo
Expand Down Expand Up @@ -432,7 +522,13 @@ function Boscia.set_bound!(blmo::BirkhoffLMO, c_idx, value, sense::Symbol)
end
end

# Delete bounds.
"""
Boscia.delete_bounds!(blmo::BirkhoffLMO, cons_delete)

Delete a collection of bounds given as pairs `(idx, sense)`. Lower bounds
are set to `0.0`, upper bounds to `1.0`. Also rebuild the reduced index maps
based on entries fixed to one.
"""
function Boscia.delete_bounds!(blmo::BirkhoffLMO, cons_delete)
for (d_idx, sense) in cons_delete
if sense == :greaterthan
Expand Down Expand Up @@ -469,7 +565,13 @@ function Boscia.delete_bounds!(blmo::BirkhoffLMO, cons_delete)
return true
end

# Add bound constraint.
"""
Boscia.add_bound_constraint!(blmo::BirkhoffLMO, key, value, sense::Symbol)

Add or overwrite a single bound for the integer variable with linear index `key`.
If a lower bound is set to `1.0`, the corresponding entry is fixed to one and the
fixing bookkeeping is updated.
"""
function Boscia.add_bound_constraint!(blmo::BirkhoffLMO, key, value, sense::Symbol)
idx = findfirst(x -> x == key, blmo.int_vars)
if sense == :greaterthan
Expand Down Expand Up @@ -497,25 +599,43 @@ end

## Checks

# Check if the subject of the bound c_idx is an integer variable (recorded in int_vars).
"""
Boscia.is_constraint_on_int_var(blmo::BirkhoffLMO, c_idx, int_vars)

Check whether the subject of bound index `c_idx` corresponds to an integer variable
in the provided `int_vars` set.
"""
function Boscia.is_constraint_on_int_var(blmo::BirkhoffLMO, c_idx, int_vars)
return blmo.int_vars[c_idx] in int_vars
end

# To check if there is bound for the variable in the global or node bounds.
"""
Boscia.is_bound_in(blmo::BirkhoffLMO, c_idx, bounds)

Return `true` if there is a bound for the variable targeted by constraint index
`c_idx` inside the `bounds` dictionary-like structure.
"""
function Boscia.is_bound_in(blmo::BirkhoffLMO, c_idx, bounds)
return haskey(bounds, blmo.int_vars[c_idx])
end

# Has variable an integer constraint?
"""
Boscia.has_integer_constraint(blmo::BirkhoffLMO, idx)

Return `true` if linear index `idx` is constrained to be integer (i.e., in `int_vars`).
"""
function Boscia.has_integer_constraint(blmo::BirkhoffLMO, idx)
return idx in blmo.int_vars
end

## Safety Functions

# Check if the bounds were set correctly in build_LMO.
# Safety check only.
"""
Boscia.build_LMO_correct(blmo::BirkhoffLMO, node_bounds)

Verify that the bounds recorded in `blmo` match those in
`node_bounds` (for both lower and upper maps). Returns `true` if consistent.
"""
function Boscia.build_LMO_correct(blmo::BirkhoffLMO, node_bounds)
for key in keys(node_bounds.lower_bounds)
idx = findfirst(x -> x == key, blmo.int_vars)
Expand All @@ -534,6 +654,13 @@ end

## Optional

"""
Boscia.check_feasibility(blmo::BirkhoffLMO)

Quick feasibility test for the bounds alone (without a specific `x`). It validates
that `ub ≥ lb` componentwise and that row/column sums can still achieve `1` given
the accumulated lower/upper bounds on the integer variables present in each row/column.
"""
function Boscia.check_feasibility(blmo::BirkhoffLMO)
for (lb, ub) in zip(blmo.lower_bounds, blmo.upper_bounds)
if ub < lb
Expand Down
Loading
Loading