Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
352 changes: 352 additions & 0 deletions src/spanning_tree.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,36 @@ struct SpanningTreeLMO{G} <: FrankWolfe.LinearMinimizationOracle
graph::G
end

"""
Union-find find with path compression.
Used to detect cycles and connectivity in fixing checks.
"""
function uf_find!(parent::Vector{Int}, x::Int)
while parent[x] != x
parent[x] = parent[parent[x]]
x = parent[x]
end
return x
end

"""
Union-find union.
Returns `false` when `a` and `b` are already connected (cycle).
"""
function uf_union!(parent::Vector{Int}, a::Int, b::Int)
ra = uf_find!(parent, a)
rb = uf_find!(parent, b)
if ra == rb
return false
end
parent[rb] = ra
return true
end

"""
Minimum spanning tree LMO using Kruskal on the weighted graph.
Returns an incidence vector over `edges(g)`.
"""
function FrankWolfe.compute_extreme_point(
lmo::SpanningTreeLMO,
direction::M;
Expand All @@ -34,3 +64,325 @@ function FrankWolfe.compute_extreme_point(
end
return v
end

"""
Bound-aware LMO for spanning trees.
Contracts forced edges, removes forbidden edges, and runs Kruskal
on the reduced graph, then lifts the solution to the original graph.
"""
function Boscia.bounded_compute_extreme_point(
lmo::SpanningTreeLMO,
direction,
lb,
ub,
int_vars;
kwargs...,
)
N = length(direction)
edges_iter = collect(Graphs.edges(lmo.graph))
@assert length(edges_iter) == N "length(edges_iter) = $(length(edges_iter)) != N = $N"

# Contract all forced edges into components via union-find.
# Connected nodes will have the same parent node at the end.
parent = collect(1:Graphs.nv(lmo.graph))
for (i, edge) in enumerate(edges_iter)
if lb[i] ≈ 1
@assert ub[i] ≈ 1
uf_union!(parent, src(edge), dst(edge))
end
end
# Count the number of connected components
# which will be contracted to super nodes.
comp_id = Dict{Int,Int}()
comp = Vector{Int}(undef, Graphs.nv(lmo.graph))
k = 0
for vtx in 1:Graphs.nv(lmo.graph)
root = uf_find!(parent, vtx)
if !haskey(comp_id, root)
k += 1
comp_id[root] = k
end
comp[vtx] = comp_id[root]
end

# Initialize solution with all forced edges.
v = spzeros(N)
for i in 1:N
if lb[i] ≈ 1
v[i] = 1
end
end

# Build reduced graph on components using the cheapest allowed
# edge for each component pair. (Note that after contracted the graph
# we might have multiple edges between the same super nodes.)
# Then run Kruskal on the reduced graph.
if k > 1
edge_choice = Dict{Tuple{Int,Int},Tuple{eltype(direction),Int}}()
for (i, edge) in enumerate(edges_iter)
# Edge is forbidden.
if ub[i] ≈ 0
continue
end
c1 = comp[src(edge)]
c2 = comp[dst(edge)]
# source and destination nodes lie in the same component,
# so the edge can be ignored.
if c1 == c2
continue
end
# order independent key
if c1 > c2
c1, c2 = c2, c1
end
key = (c1, c2)
w = direction[i]
# If multiple edges connect the same super nodes
# choose the cheapest one.
if !haskey(edge_choice, key) || w < edge_choice[key][1]
edge_choice[key] = (w, i)
end
end
# Build reduced graph and weight matrix.
reduced_graph = SimpleGraph(k)
for key in keys(edge_choice)
add_edge!(reduced_graph, key[1], key[2])
end
distmx = spzeros(k, k)
for (key, (w, _)) in edge_choice
distmx[key[1], key[2]] = w
distmx[key[2], key[1]] = w
end
reduced_span = Graphs.kruskal_mst(reduced_graph, distmx)
for edge in reduced_span
c1 = src(edge)
c2 = dst(edge)
# order independent key
if c1 > c2
c1, c2 = c2, c1
end
idx = edge_choice[(c1, c2)][2]
v[idx] = 1
end
end
# Optional sanity check.
@debug begin
for i in 1:N
if ub[i] ≈ 0
@assert v[i] ≈ 0
elseif lb[i] ≈ 1
@assert v[i] ≈ 1
end
end
end
return v
end

"""
Lightweight feasibility check for a candidate `v`.
Enforces total edge count and singleton-cut constraints.
"""
function Boscia.is_simple_linear_feasible(lmo::SpanningTreeLMO, v)
n = Graphs.nv(lmo.graph)
if n == 0
return true
end
total = sum(v)
if abs(total - (n - 1)) > 1e-4
return false
end
# detect cycles among edges that are (almost) fully selected
parent = collect(1:n)
for (idx, edge) in enumerate(edges(lmo.graph))
if v[idx] < 1 - 1e-4
continue
end
if !(uf_union!(parent, src(edge), dst(edge)))
return false
end
end
# singleton cut constraints: each vertex must have at least one incident edge
degrees = zeros(eltype(v), n)
for (idx, edge) in enumerate(edges(lmo.graph))
if v[idx] ≈ 0
continue
end
degrees[src(edge)] += v[idx]
degrees[dst(edge)] += v[idx]
end
if minimum(degrees) < 1 - 1e-4
return false
end
# ensure support is connected (prevents disjoint forests passing)
parent = collect(1:n)
for (idx, edge) in enumerate(edges(lmo.graph))
if v[idx] <= 1e-4
continue
end
uf_union!(parent, src(edge), dst(edge))
end
# All nodes should have a common root if the graph is connected.
root = uf_find!(parent, 1)
for vtx in 2:n
if uf_find!(parent, vtx) != root
return false
end
end
return true
end

"""
Feasibility of bounds alone for spanning trees.
Returns `Boscia.OPTIMAL` if some spanning tree can satisfy the bounds.
"""
function Boscia.check_feasibility(
lmo::SpanningTreeLMO,
lb,
ub,
int_vars,
n
)
edges_iter = collect(Graphs.edges(lmo.graph))
n_local = Graphs.nv(lmo.graph)
@debug "n_local = $n_local n = $n"
if n_local <= 1
return Boscia.OPTIMAL
end
# The forced edges (lb=ub=1) must be acyclic.
parent = collect(1:n_local)
for (i, edge) in enumerate(edges_iter)
if lb[i] ≈ 1
if !uf_union!(parent, src(edge), dst(edge))
@debug "Forced edges form a cycle"
return Boscia.INFEASIBLE
end
end
end
# The graph must stay connected after removing forbidden edges.
parent = collect(1:n_local)
for (i, edge) in enumerate(edges_iter)
if !(ub[i] ≈ 0)
uf_union!(parent, src(edge), dst(edge))
end
end
root = uf_find!(parent, 1)
for vtx in 2:n_local
if uf_find!(parent, vtx) != root
@debug "Forbidden edges disconnect graph"
return Boscia.INFEASIBLE
end
end
return Boscia.OPTIMAL
end

"""
Spanning trees support decomposition-invariant oracles since in-faces
are defined by fixing edges to 0/1.
"""
function Boscia.is_decomposition_invariant_oracle_simple(::SpanningTreeLMO)
return true
end

function FrankWolfe.is_decomposition_invariant_oracle(::SpanningTreeLMO)
return true
end

"""
In-face bounded LMO: adds fixings implied by `x` (0/1 entries)
on top of existing bounds, then solves the bounded LMO.
"""
function Boscia.bounded_compute_inface_extreme_point(
lmo::SpanningTreeLMO,
direction,
x,
lb,
ub,
int_vars;
kwargs...,
)
N = length(direction)
lb2 = copy(lb)
ub2 = copy(ub)
# In-face fixings: iterates are floating; using machine `eps()` is too strict and
# can prevent DICG from identifying the correct minimal face.
tol = 1e-8
for i in 1:N
if x[i] ≤ tol
ub2[i] = 0.0
elseif x[i] ≥ 1 - tol
lb2[i] = 1.0
end
end
return Boscia.bounded_compute_extreme_point(lmo, direction, lb2, ub2, int_vars; kwargs...)
end

"""
Check whether `a` lies on the minimal face of `x` and respects bounds.
"""
function Boscia.is_simple_inface_feasible(
lmo::SpanningTreeLMO,
a,
x,
lb,
ub,
int_vars,
)
tol = 1e-8
for i in eachindex(a)
if x[i] ≤ tol && a[i] > 1e-6
return false
elseif x[i] ≥ 1 - tol && a[i] < 1 - 1e-6
return false
end
end
return true
end

"""
Maximum DICG step size respecting bounds and in-face fixings from `x`.
"""
function Boscia.bounded_dicg_maximum_step(
lmo::SpanningTreeLMO,
x,
direction,
lb,
ub,
int_vars;
kwargs...,
)
T = promote_type(eltype(x), eltype(direction))
gamma_max = one(T)
# Safety margin: keep x - gamma*direction strictly inside bounds
# to prevent FrankWolfe's domain oracle from repeatedly rejecting points
# due to floating-point noise during line search.
tol = T(1e-12)
for i in eachindex(x)
if direction[i] == 0
continue
end
lower = lb[i]
upper = ub[i]
# FrankWolfe updates as: x_new = x - gamma * direction.
# We need gamma >= 0 such that lower <= x_new <= upper.
if direction[i] > 0
# x_new decreases with gamma; only the lower bound can become active:
# lower <= x[i] - gamma*dir => gamma <= (x[i] - lower)/dir
if x[i] ≤ lower + tol
return zero(gamma_max)
end
num = x[i] - lower - tol
num <= 0 && return zero(gamma_max)
gamma_max = min(gamma_max, num / direction[i])
else
# direction[i] < 0: x_new increases with gamma; only the upper bound can become active:
# x[i] - gamma*dir <= upper => gamma <= (upper - x[i]) / (-dir)
if x[i] ≥ upper - tol
return zero(gamma_max)
end
num = upper - x[i] - tol
num <= 0 && return zero(gamma_max)
gamma_max = min(gamma_max, num / (-direction[i]))
end
end
return gamma_max
end
Loading
Loading