Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion docs/src/reference/2_lmo.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ The Linear Minimization Oracle (LMO) is a key component called at each iteration
v\in \argmin_{x\in \mathcal{C}} \langle d,x \rangle.
```

See [Combettes, Pokutta 2021](https://arxiv.org/abs/2101.10040) for references on most LMOs
See [Combettes, Pokutta 2021](https://arxiv.org/abs/2101.10040) for references on essential LMOs
implemented in the package and their comparison with projection operators.

## Interface and wrappers
Expand All @@ -19,6 +19,13 @@ All of them are subtypes of [`FrankWolfe.LinearMinimizationOracle`](@ref) and im
compute_extreme_point
```

Optionally, an LMO can implement a weak separation procedure based either on a heuristic or on an approximation algorithm:
```@docs
compute_weak_separation_point
```

Weak separation procedures will be used in the methods using an active set and lazified variants only.

We also provide some meta-LMOs wrapping another one with extended behavior:
```@docs
FrankWolfe.CachedLinearMinimizationOracle
Expand Down
113 changes: 108 additions & 5 deletions src/abstract_oracles.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,29 @@ All LMOs should accept keyword arguments that they can ignore.
"""
function compute_extreme_point end

"""
compute_weak_separation_point(lmo, direction, max_value) -> (vertex, gap)

Weak separation algorithm for a given oracle.
Unlike `compute_extreme_point`, `compute_weak_separation_point` may provide a suboptimal `vertex` with the following conditions:
- `vertex` is still a valid extreme point of the polytope.
IF an inexact vertex is computed:
- `⟨v, d⟩ ≤ max_value`, the pre-specified required improvement.
- `⟨v, d⟩ ≤ ⟨v_opt, d⟩ + gap`, with `v_opt` a vertex computed with the exact oracle.
- If the algorithm used to compute the inexact vertex provides a bound on optimality, the `gap` value must be valid.
- Otherwise, e.g. if the vertex is computed with a heuristic, `gap = ∞`.
ELSE (the oracle computes an optimal vertex):
- `⟨v, d⟩` may be greater than `max_value`, `gap` must be 0.
"""
function compute_weak_separation_point(lmo, direction, max_value; kwargs...) end

# default to computing an exact vertex.
function compute_weak_separation_point(lmo::LinearMinimizationOracle, direction, max_value; kwargs...)
v = compute_extreme_point(lmo, direction; kwargs...)
gap = zero(eltype(v)) * zero(eltype(direction))
return v, gap
end

"""
CachedLinearMinimizationOracle{LMO}

Expand All @@ -43,11 +66,29 @@ Vertices of `LMO` have to be of type `VT` if provided.
mutable struct SingleLastCachedLMO{LMO,A} <: CachedLinearMinimizationOracle{LMO}
last_vertex::Union{Nothing,A}
inner::LMO
store_cache::Bool
end

# initializes with no cache by default
SingleLastCachedLMO(lmo::LMO) where {LMO<:LinearMinimizationOracle} =
SingleLastCachedLMO{LMO,AbstractVector}(nothing, lmo)
SingleLastCachedLMO{LMO,AbstractVector}(nothing, lmo, true)

# gap is 0 if exact, ∞ if cached point
function compute_weak_separation_point(lmo::SingleLastCachedLMO, direction, max_value; kwargs...)
if lmo.last_vertex !== nothing && isfinite(max_value)
# cache is a sufficiently-decreasing direction
if fast_dot(lmo.last_vertex, direction) ≤ max_value
T = promote_type(eltype(lmo.last_vertex), eltype(direction))
return lmo.last_vertex, T(Inf)
end
end
v = compute_extreme_point(lmo.inner, direction, kwargs...)
if lmo.store_cache
lmo.last_vertex = v
end
T = promote_type(eltype(v), eltype(direction))
return v, zero(T)
end

function compute_extreme_point(
lmo::SingleLastCachedLMO,
Expand All @@ -62,7 +103,7 @@ function compute_extreme_point(
return lmo.last_vertex
end
end
v = compute_extreme_point(lmo.inner, direction, kwargs...)
v = compute_extreme_point(lmo.inner, direction, v=v, kwargs...)
if store_cache
lmo.last_vertex = v
end
Expand Down Expand Up @@ -188,14 +229,17 @@ mutable struct VectorCacheLMO{LMO<:LinearMinimizationOracle,VT} <:
CachedLinearMinimizationOracle{LMO}
vertices::Vector{VT}
inner::LMO
store_cache::Bool
greedy::Bool
weak_separation::Bool
end

function VectorCacheLMO{LMO,VT}(lmo::LMO) where {VT,LMO<:LinearMinimizationOracle}
return VectorCacheLMO{LMO,VT}(VT[], lmo)
return VectorCacheLMO{LMO,VT}(VT[], lmo, true, false, false)
end

function VectorCacheLMO(lmo::LMO) where {LMO<:LinearMinimizationOracle}
return VectorCacheLMO{LMO,Vector{Float64}}(AbstractVector[], lmo)
return VectorCacheLMO{LMO,Vector{Float64}}(AbstractVector[], lmo, true, false, false)
end

function Base.empty!(lmo::VectorCacheLMO)
Expand All @@ -205,6 +249,65 @@ end

Base.length(lmo::VectorCacheLMO) = length(lmo.vertices)

function compute_weak_separation_point(lmo::VectorCacheLMO, direction, max_value; kwargs...)
if isempty(lmo.vertices)
v, gap = if lmo.weak_separation
compute_weak_separation_point(lmo.inner, direction, max_value; kwargs...)
else
v = compute_extreme_point(lmo.inner, direction; kwargs...)
v, zero(eltype(v))
end
T = promote_type(eltype(v), eltype(direction))
if lmo.store_cache
push!(lmo.vertices, v)
end
return v, T(gap)
end
best_idx = -1
best_val = Inf
best_v = nothing
for idx in reverse(eachindex(lmo.vertices))
@inbounds v = lmo.vertices[idx]
new_val = fast_dot(v, direction)
if new_val ≤ max_value
T = promote_type(eltype(v), eltype(direction))
# stop and return
if lmo.greedy
return v, T(Inf)
end
# otherwise, compare to incumbent
if new_val < best_val
best_v = v
best_val = new_val
best_idx = idx
end
end
end
if best_idx > 0
T = promote_type(eltype(best_v), eltype(direction))
return best_v, T(Inf)
end
# no satisfactory vertex found, call oracle
v, gap = if lmo.weak_separation
compute_weak_separation_point(lmo.inner, direction, max_value; kwargs...)
else
v = compute_extreme_point(lmo.inner, direction; kwargs...)
v, zero(eltype(v))
end
if lmo.store_cache
# note: we do not check for duplicates. hence you might end up with more vertices,
# in fact up to number of dual steps many, that might be already in the cache
# in order to reach this point, if v was already in the cache is must not meet the threshold (otherwise we would have returned it)
# and it is the best possible, hence we will perform a dual step on the outside.
#
# note: another possibility could be to test against that in the if statement but then you might end you recalculating the same vertex a few times.
# as such this might be a better tradeoff, i.e., to not check the set for duplicates and potentially accept #dual_steps many duplicates.
push!(lmo.vertices, v)
end
T = promote_type(eltype(v), eltype(direction))
return v, T(gap)
end

function compute_extreme_point(
lmo::VectorCacheLMO,
direction;
Expand All @@ -228,7 +331,7 @@ function compute_extreme_point(
@inbounds v = lmo.vertices[idx]
new_val = fast_dot(v, direction)
if new_val ≤ threshold
# stop, store and return
# stop and return
if greedy
return v
end
Expand Down
64 changes: 46 additions & 18 deletions src/afw.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ function away_frank_wolfe(
use_extra_vertex_storage=false,
linesearch_workspace=nothing,
recompute_last_vertex=true,
weak_separation=false,
)
# add the first vertex to active set from initialization
active_set = ActiveSet([(1.0, x0)])
Expand Down Expand Up @@ -64,6 +65,7 @@ function away_frank_wolfe(
use_extra_vertex_storage=use_extra_vertex_storage,
linesearch_workspace=linesearch_workspace,
recompute_last_vertex=recompute_last_vertex,
weak_separation=weak_separation,
)
end

Expand Down Expand Up @@ -95,6 +97,7 @@ function away_frank_wolfe(
use_extra_vertex_storage=false,
linesearch_workspace=nothing,
recompute_last_vertex=true,
weak_separation=false,
)
# format string for output of the algorithm
format_string = "%6s %13s %14e %14e %14e %14e %14e %14i\n"
Expand Down Expand Up @@ -153,7 +156,7 @@ function away_frank_wolfe(
)
grad_type = typeof(gradient)
println(
"GRADIENTTYPE: $grad_type LAZY: $lazy lazy_tolerance: $lazy_tolerance MOMENTUM: $momentum AWAYSTEPS: $away_steps",
"GRADIENT TYPE: $grad_type LAZY: $lazy LAZY_TOLERANCE: $lazy_tolerance WEAK_SEPARATION: $weak_separation MOMENTUM: $momentum AWAYSTEPS: $away_steps",
)
println("Linear Minimization Oracle: $(typeof(lmo))")
if (use_extra_vertex_storage || add_dropped_vertices) && extra_vertex_storage === nothing
Expand Down Expand Up @@ -223,10 +226,16 @@ function away_frank_wolfe(
extra_vertex_storage=extra_vertex_storage,
lazy_tolerance=lazy_tolerance,
memory_mode=memory_mode,
weak_separation=weak_separation,
)
else
d, vertex, index, gamma_max, phi_value, away_step_taken, fw_step_taken, tt =
afw_step(x, gradient, lmo, active_set, epsilon, d, memory_mode=memory_mode)
afw_step(
x, gradient, lmo, active_set, epsilon, d,
memory_mode=memory_mode,
weak_separation=weak_separation,
lazy_tolerance=lazy_tolerance,
)
end
else
d, vertex, index, gamma_max, phi_value, away_step_taken, fw_step_taken, tt =
Expand All @@ -248,7 +257,7 @@ function away_frank_wolfe(
memory_mode,
)

gamma = min(gamma_max, gamma)
gamma = min(gamma_max, gamma)
# cleanup and renormalize every x iterations. Only for the fw steps.
renorm = mod(t, renorm_interval) == 0
if away_step_taken
Expand Down Expand Up @@ -368,9 +377,9 @@ function away_frank_wolfe(
return x, v, primal, dual_gap, traj_data, active_set
end

function lazy_afw_step(x, gradient, lmo, active_set, phi, epsilon, d; use_extra_vertex_storage=false, extra_vertex_storage=nothing, lazy_tolerance=2.0, memory_mode::MemoryEmphasis=InplaceEmphasis())
function lazy_afw_step(x, gradient, lmo, active_set, phi, epsilon, d; use_extra_vertex_storage=false, extra_vertex_storage=nothing, lazy_tolerance=2.0, memory_mode::MemoryEmphasis=InplaceEmphasis(), weak_separation::Bool=true)
_, v, v_loc, _, a_lambda, a, a_loc, _, _ = active_set_argminmax(active_set, gradient)
#Do lazy FW step
# do lazy FW step
grad_dot_lazy_fw_vertex = fast_dot(v, gradient)
grad_dot_x = fast_dot(x, gradient)
grad_dot_a = fast_dot(a, gradient)
Expand All @@ -385,7 +394,7 @@ function lazy_afw_step(x, gradient, lmo, active_set, phi, epsilon, d; use_extra_
fw_step_taken = true
index = v_loc
else
#Do away step, as it promises enough progress.
# do away step, as it promises enough progress.
if grad_dot_a - grad_dot_x > grad_dot_x - grad_dot_lazy_fw_vertex &&
grad_dot_a - grad_dot_x >= phi / lazy_tolerance
tt = away
Expand All @@ -395,7 +404,7 @@ function lazy_afw_step(x, gradient, lmo, active_set, phi, epsilon, d; use_extra_
away_step_taken = true
fw_step_taken = false
index = a_loc
#Resort to calling the LMO
# resort to calling the LMO
else
# optionally: try vertex storage
if use_extra_vertex_storage
Expand All @@ -406,26 +415,34 @@ function lazy_afw_step(x, gradient, lmo, active_set, phi, epsilon, d; use_extra_
@debug("Found acceptable lazy vertex in storage")
v = new_forward_vertex
tt = lazylazy
end
else
found_better_vertex = false
end
if !found_better_vertex
# compute new vertex with normal or weak oracle
if weak_separation
lazy_threshold = fast_dot(gradient, x) - phi / lazy_tolerance
(v, gap) = compute_weak_separation_point(lmo, gradient, lazy_threshold)
tt = gap == 0.0 ? regular : weaksep
else
v = compute_extreme_point(lmo, gradient)
gap = zero(eltype(v))
tt = regular
end
else
v = compute_extreme_point(lmo, gradient)
tt = regular
end
# Real dual gap promises enough progress.
grad_dot_fw_vertex = fast_dot(v, gradient)
dual_gap = grad_dot_x - grad_dot_fw_vertex
# Real dual gap promises enough progress.
if dual_gap >= phi / lazy_tolerance
gamma_max = one(a_lambda)
d = muladd_memory_mode(memory_mode, d, x, v)
vertex = v
away_step_taken = false
fw_step_taken = true
index = -1
#Lower our expectation for progress.
else
else # lower our expectation for progress.
@assert tt != weaksep
tt = dualstep
phi = min(dual_gap, phi / 2.0)
gamma_max = zero(a_lambda)
Expand All @@ -439,14 +456,25 @@ function lazy_afw_step(x, gradient, lmo, active_set, phi, epsilon, d; use_extra_
return d, vertex, index, gamma_max, phi, away_step_taken, fw_step_taken, tt
end

function afw_step(x, gradient, lmo, active_set, epsilon, d; memory_mode::MemoryEmphasis=InplaceEmphasis())
function afw_step(x, gradient, lmo, active_set, epsilon, d; memory_mode::MemoryEmphasis=InplaceEmphasis(), weak_separation::Bool=false, lazy_tolerance=2.0)
_, _, _, _, a_lambda, a, a_loc = active_set_argminmax(active_set, gradient)
v = compute_extreme_point(lmo, gradient)
grad_dot_x = fast_dot(x, gradient)
away_gap = fast_dot(a, gradient) - grad_dot_x
(v, gap) = if weak_separation
# Condition for taking a FW step
# ⟨∇f, x-v⟩ ≥ gₐ <=>
# ⟨∇f, v⟩ ≤ ⟨∇f, x⟩ - gₐ
# We ask for a bit more progress on the FW step
# to promote away steps when we can (and therefore sparsity)
# ⟨∇f, v⟩ ≤ ⟨∇f, x⟩ - K gₐ
lazy_threshold = grad_dot_x - lazy_tolerance * away_gap
compute_weak_separation_point(lmo, gradient, lazy_threshold)
else
(compute_extreme_point(lmo, gradient), zero(away_gap))
end
dual_gap = grad_dot_x - fast_dot(v, gradient)
if dual_gap >= away_gap && dual_gap >= epsilon
tt = regular
if dual_gap > away_gap && dual_gap >= epsilon
tt = gap == 0.0 ? regular : weaksep
gamma_max = one(a_lambda)
d = muladd_memory_mode(memory_mode, d, x, v)
vertex = v
Expand All @@ -469,7 +497,7 @@ function afw_step(x, gradient, lmo, active_set, epsilon, d; memory_mode::MemoryE
fw_step_taken = false
index = a_loc
end
return d, vertex, index, gamma_max, dual_gap, away_step_taken, fw_step_taken, tt
return d, vertex, index, gamma_max, dual_gap + gap, away_step_taken, fw_step_taken, tt
end

function fw_step(x, gradient, lmo, d; memory_mode::MemoryEmphasis = InplaceEmphasis())
Expand Down
2 changes: 2 additions & 0 deletions src/defs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ struct OutplaceEmphasis <: MemoryEmphasis end
away = 6
pairwise = 7
drop = 8
weaksep = 9
simplex_descent = 101
gap_step = 102
last = 1000
Expand All @@ -34,6 +35,7 @@ const st = (
away="A",
pairwise="P",
drop="D",
weaksep="W",
simplex_descent="SD",
gap_step="GS",
last="Last",
Expand Down
Loading