Skip to content

Commit 71910a4

Browse files
committed
finished implementation
1 parent 192286b commit 71910a4

3 files changed

Lines changed: 17 additions & 8 deletions

File tree

src/afw.jl

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -423,16 +423,17 @@ function lazy_afw_step(x, gradient, lmo, active_set, phi, epsilon, d; use_extra_
423423
# compute new vertex with normal or weak oracle
424424
if weak_separation
425425
lazy_threshold = fast_dot(gradient, x) - phi / lazy_tolerance
426-
(v, _) = compute_weak_separation_point(lmo, gradient, lazy_threshold)
427-
tt = weaksep
426+
(v, gap) = compute_weak_separation_point(lmo, gradient, lazy_threshold)
427+
tt = gap == 0.0 ? regular : weaksep
428428
else
429429
v = compute_extreme_point(lmo, gradient)
430+
gap = zero(eltype(v))
430431
tt = regular
431432
end
432433
end
433-
# Real dual gap promises enough progress.
434434
grad_dot_fw_vertex = fast_dot(v, gradient)
435435
dual_gap = grad_dot_x - grad_dot_fw_vertex
436+
# Real dual gap promises enough progress.
436437
if dual_gap >= phi / lazy_tolerance
437438
gamma_max = one(a_lambda)
438439
d = muladd_memory_mode(memory_mode, d, x, v)
@@ -441,6 +442,7 @@ function lazy_afw_step(x, gradient, lmo, active_set, phi, epsilon, d; use_extra_
441442
fw_step_taken = true
442443
index = -1
443444
else # lower our expectation for progress.
445+
@assert tt != weaksep
444446
tt = dualstep
445447
phi = min(dual_gap, phi / 2.0)
446448
gamma_max = zero(a_lambda)
@@ -460,15 +462,15 @@ function afw_step(x, gradient, lmo, active_set, epsilon, d; memory_mode::MemoryE
460462
away_gap = fast_dot(a, gradient) - grad_dot_x
461463
(v, gap) = if weak_separation
462464
# Condition for taking a FW step
463-
# ⟨∇f, x-v⟩ ≥ gₐ
465+
# ⟨∇f, x-v⟩ ≥ gₐ <=>
464466
# ⟨∇f, v⟩ ≤ ⟨∇f, x⟩ - gₐ
465-
# We ask for a bit more on the FW step
467+
# We ask for a bit more progress on the FW step
466468
# to promote away steps when we can (and therefore sparsity)
467469
# ⟨∇f, v⟩ ≤ ⟨∇f, x⟩ - K gₐ
468470
lazy_threshold = grad_dot_x - lazy_tolerance * away_gap
469471
compute_weak_separation_point(lmo, gradient, lazy_threshold)
470472
else
471-
(compute_extreme_point(lmo, gradient), 0.0)
473+
(compute_extreme_point(lmo, gradient), zero(away_gap))
472474
end
473475
dual_gap = grad_dot_x - fast_dot(v, gradient)
474476
if dual_gap > away_gap && dual_gap >= epsilon

src/pairwise.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -376,7 +376,7 @@ function blended_pairwise_conditional_gradient(
376376
else # dual step
377377
# set to the computed dual_gap for consistency between the lazy and non-lazy runs.
378378
# That is ok, as we scale with the K = 2.0 default anyway.
379-
# we only update the dual gap if the step was regular (not lazy from discarded set)
379+
# we only update the dual gap if the step was regular or weaksep (not lazy from discarded set)
380380
if tt != lazylazy
381381
@assert dual_gap + gap < phi
382382
phi = dual_gap + gap

test/weak_separation.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ end
7171
@test w == v
7272
end
7373

74-
@testset "AFW with weak separation" begin
74+
@testset "AFW and BPCG with weak separation" begin
7575
n = 1000
7676
# reference point to get an optimum on a face
7777
ref_point = [0.6 + mod(idx, 2) for idx in 1:n]
@@ -88,5 +88,12 @@ end
8888
tracking_weak = FrankWolfe.TrackingLMO(Hypercube())
8989
x, v, primal, dual_gap, trajectory_weak, active_set_weak = FrankWolfe.away_frank_wolfe(f, grad!, tracking_weak, x0, verbose=false, weak_separation=true, lazy=lazy)
9090
@test tracking_lmo.counter <= tracking_weak.counter
91+
92+
tracking_lmo = FrankWolfe.TrackingLMO(Hypercube())
93+
x, v, primal, dual_gap, trajectory_exact, active_set_exact = FrankWolfe.blended_pairwise_conditional_gradient(f, grad!, tracking_lmo, x0, verbose=false, weak_separation=false, lazy=lazy)
94+
tracking_weak = FrankWolfe.TrackingLMO(Hypercube())
95+
x, v, primal, dual_gap, trajectory_weak, active_set_weak = FrankWolfe.blended_pairwise_conditional_gradient(f, grad!, tracking_weak, x0, verbose=false, weak_separation=true, lazy=lazy)
96+
@test tracking_lmo.counter <= tracking_weak.counter
97+
9198
end
9299
end

0 commit comments

Comments
 (0)