Skip to content

Commit 1c8914c

Browse files
committed
add Li-Lin method, add nonconvex tests
1 parent c24479e commit 1c8914c

4 files changed

Lines changed: 309 additions & 0 deletions

File tree

src/ProximalAlgorithms.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,5 +20,6 @@ include("algorithms/panoc.jl")
2020
include("algorithms/douglasrachford.jl")
2121
include("algorithms/primaldual.jl")
2222
include("algorithms/davisyin.jl")
23+
include("algorithms/lilin.jl")
2324

2425
end # module

src/algorithms/lilin.jl

Lines changed: 208 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,208 @@
1+
# Li, Lin, "Accelerated Proximal Gradient Methods for Nonconvex Programming",
2+
# Proceedings of NIPS 2015 (2015).
3+
4+
using Base.Iterators
5+
using ProximalAlgorithms.IterationTools
6+
using ProximalOperators: Zero
7+
using LinearAlgebra
8+
using Printf
9+
10+
# Immutable problem description for the Li-Lin nonconvex accelerated
# proximal gradient method: minimize f(A x) + g(x).
struct LiLin_iterable{R <: Real, C <: Union{R, Complex{R}}, Tx <: AbstractArray{C}, Tf, TA, Tg}
    f::Tf           # smooth term
    A::TA           # matrix/linear operator
    g::Tg           # (possibly) nonsmooth, proximable term
    x0::Tx          # initial point
    gamma::Maybe{R} # stepsize parameter of forward and backward steps
    adaptive::Bool  # enforce adaptive stepsize even if L is provided
    delta::R        # sufficient-decrease margin: z is accepted when
                    # F(z) <= F_average - delta * ||res||^2
    eta::R          # forgetting factor of the moving average of objective
                    # values (q_{k+1} = eta*q_k + 1)
end
21+
# Mutable state carried across iterations of the Li-Lin method.
# Array fields double as reusable buffers (updated in place with .=/mul!/prox!).
mutable struct LiLin_state{R <: Real, Tx, TAx}
    x::Tx            # iterate
    y::Tx            # extrapolated point
    Ay::TAx          # A times y
    f_Ay::R          # value of smooth term at y
    grad_f_Ay::TAx   # gradient of f at Ay
    At_grad_f_Ay::Tx # gradient of smooth term at y
    # TODO: *two* gammas should be used in general, one for y and one for x
    gamma::R         # stepsize parameter of forward and backward steps
    y_forward::Tx    # forward point at y
    z::Tx            # forward-backward point
    g_z::R           # value of nonsmooth term at z
    res::Tx          # fixed-point-residual (at y)
    theta::R         # auxiliary sequence to compute extrapolated points
    F_average::R     # moving average of objective values
    q::R             # auxiliary sequence to compute moving average
end
39+
"""
    Base.iterate(iter::LiLin_iterable)

Perform the initial forward-backward step of the Li-Lin method and return the
initial `LiLin_state` (used both as item and as iterator state).

Throws `ArgumentError` if no stepsize is available, or if the initial point is
infeasible (objective not finite there).
"""
function Base.iterate(iter::LiLin_iterable{R}) where R
    # TODO: initialize gamma if not provided
    # TODO: authors suggest Barzilai-Borwein rule?
    # TODO: *two* gammas should be used in general, one for y and one for x
    # Until a gamma-initialization rule is implemented, fail fast with a clear
    # message instead of a cryptic MethodError on `nothing` below.
    iter.gamma === nothing && throw(ArgumentError(
        "stepsize not set: provide `gamma`, or `L` so that gamma = 1/L"
    ))

    # Work on a copy of x0: `state.y` is updated in place (with .=) by later
    # iterations, so it must never alias the caller's initial point.
    y = copy(iter.x0)
    Ay = iter.A * y
    grad_f_Ay, f_Ay = gradient(iter.f, Ay)

    # compute initial forward-backward step
    At_grad_f_Ay = iter.A' * grad_f_Ay
    y_forward = y - iter.gamma .* At_grad_f_Ay
    z, g_z = prox(iter.g, y_forward, iter.gamma)

    # objective value at the starting point; must be finite for the
    # nonmonotone acceptance test to make sense
    Fy = f_Ay + iter.g(y)
    isfinite(Fy) || throw(ArgumentError("initial point must be feasible"))

    # compute initial fixed-point residual
    res = y - z

    # theta and q both start at 1, F_average at the initial objective value
    state = LiLin_state(
        copy(iter.x0), y, Ay, f_Ay, grad_f_Ay, At_grad_f_Ay,
        iter.gamma, y_forward, z, g_z, res, R(1), Fy, R(1)
    )

    return state, state
end
68+
"""
    Base.iterate(iter::LiLin_iterable, state::LiLin_state)

One iteration of Algorithm 2 in Li & Lin (NIPS 2015): accept the extrapolated
forward-backward point `z` if it sufficiently decreases a moving average of
objective values, otherwise fall back to comparing it against a plain
forward-backward step from `x`. Mutates `state` in place and returns it.
"""
function Base.iterate(iter::LiLin_iterable{R}, state::LiLin_state{R, Tx, TAx}) where {R, Tx, TAx}
    # TODO: backtrack gamma at y

    # objective value at the candidate point z
    Fz = iter.f(state.z) + state.g_z

    # Nesterov-type update of the auxiliary sequence theta
    theta1 = (R(1)+sqrt(R(1)+4*state.theta^2))/R(2)

    if Fz <= state.F_average - iter.delta * norm(state.res)^2
        # z sufficiently decreases the moving average of objective values
        case = 1
    else
        # otherwise compare z against a plain forward-backward step v from x
        # TODO: re-use available space in state?
        # TODO: backtrack gamma at x
        Ax = iter.A * state.x
        grad_f_Ax, _ = gradient(iter.f, Ax)
        At_grad_f_Ax = iter.A' * grad_f_Ax
        x_forward = state.x - state.gamma .* At_grad_f_Ax
        v, g_v = prox(iter.g, x_forward, state.gamma)
        Fv = iter.f(v) + g_v
        case = Fz <= Fv ? 1 : 2
    end

    if case == 1
        # accept z: extrapolate from it, then make it the new iterate
        state.y .= state.z .+ ((state.theta - R(1)) / theta1) .* (state.z .- state.x)
        # swap buffers: the old x array is recycled as the z buffer and will be
        # overwritten by prox! below
        state.x, state.z = state.z, state.x
        Fx = Fz
    elseif case == 2
        # accept the plain forward-backward point v instead
        state.y .= state.z .+ (state.theta / theta1) .* (state.z .- v) .+ ((state.theta - R(1)) / theta1) .* (v .- state.x)
        state.x = v
        Fx = Fv
    end

    # forward-backward step at the new extrapolated point y (all in place)
    mul!(state.Ay, iter.A, state.y)
    state.f_Ay = gradient!(state.grad_f_Ay, iter.f, state.Ay)
    mul!(state.At_grad_f_Ay, adjoint(iter.A), state.grad_f_Ay)
    state.y_forward .= state.y .- state.gamma .* state.At_grad_f_Ay
    state.g_z = prox!(state.z, iter.g, state.y_forward, state.gamma)

    # fixed-point residual at y; fully fused broadcast (the previous
    # `state.y - state.z` RHS allocated a temporary array every iteration)
    state.res .= state.y .- state.z

    state.theta = theta1

    # NOTE: the following can be simplified
    # update the moving average of objective values with forgetting factor eta
    q1 = iter.eta * state.q + 1
    state.F_average = (iter.eta * state.q * state.F_average + Fx)/q1
    state.q = q1

    return state, state
end
117+
# Solver
118+
119+
# Solver

# User-facing solver object holding the algorithm's options; calling it on an
# initial point (see the functor method below) runs the iterations.
struct LiLin{R <: Real}
    gamma::Maybe{R}  # stepsize, or `nothing` to derive it from `L` at call time
    adaptive::Bool   # enforce adaptive stepsize even if L is provided
    delta::R         # sufficient-decrease margin of the nonmonotone test
    eta::R           # forgetting factor of the objective moving average
    maxit::Int       # maximum number of iterations
    tol::R           # absolute tolerance on the fixed-point residual
    verbose::Bool    # whether to print progress
    freq::Int        # print every `freq` iterations

    function LiLin{R}(; gamma::Maybe{R}=nothing, adaptive::Bool=false,
        delta::R=R(1e-3), eta::R=R(0.8), maxit::Int=10000, tol::R=R(1e-8),
        verbose::Bool=false, freq::Int=100
    ) where R
        # Validate keyword arguments with explicit errors: @assert is meant for
        # internal invariants and may be disabled at higher optimization levels.
        gamma === nothing || gamma > 0 || throw(ArgumentError("gamma must be positive"))
        delta > 0 || throw(ArgumentError("delta must be positive"))
        0 < eta < 1 || throw(ArgumentError("eta must be in (0, 1)"))
        maxit > 0 || throw(ArgumentError("maxit must be positive"))
        tol > 0 || throw(ArgumentError("tol must be positive"))
        freq > 0 || throw(ArgumentError("freq must be positive"))
        new(gamma, adaptive, delta, eta, maxit, tol, verbose, freq)
    end
end
143+
"""
    (solver::LiLin)(x0; f=Zero(), A=I, g=Zero(), L=nothing)

Run the Li-Lin method from `x0` on the problem `minimize f(A x) + g(x)`.
Returns the final forward-backward point and the number of iterations.

Throws `ArgumentError` when no stepsize can be determined (neither `gamma` at
construction time nor `L` here).
"""
function (solver::LiLin{R})(
    x0::AbstractArray{C}; f=Zero(), A=I, g=Zero(), L::Maybe{R}=nothing
) where {R, C <: Union{R, Complex{R}}}

    # stopping criterion: scaled fixed-point residual below tolerance
    stop(state::LiLin_state) = norm(state.res, Inf)/state.gamma <= solver.tol
    # progress line: iteration count, stepsize, scaled residual
    disp((it, state)) = @printf(
        "%5d | %.3e | %.3e\n",
        it, state.gamma, norm(state.res, Inf)/state.gamma
    )

    if solver.gamma === nothing && L !== nothing
        gamma = R(1)/L
    elseif solver.gamma !== nothing
        gamma = solver.gamma
    else
        # Adaptive stepsize selection is not implemented yet (see TODOs in the
        # iterable); previously `nothing` was passed on and failed deep inside
        # the iterations with a cryptic error. Fail here with a clear message.
        throw(ArgumentError("missing stepsize: provide either `gamma` or `L`"))
    end

    iter = LiLin_iterable(f, A, g, x0, gamma, solver.adaptive, solver.delta, solver.eta)
    iter = take(halt(iter, stop), solver.maxit)
    iter = enumerate(iter)
    if solver.verbose iter = tee(sample(iter, solver.freq), disp) end

    num_iters, state_final = loop(iter)

    return state_final.z, num_iters
end
172+
# Outer constructors
173+
174+
"""
175+
LiLin([gamma, adaptive, fast, maxit, tol, verbose, freq])
176+
177+
Instantiate the nonconvex accelerated proximal gradient method by Li and Lin
178+
(see Algorithm 2 in [1]) for solving optimization problems of the form
179+
180+
minimize f(Ax) + g(x),
181+
182+
where `f` is smooth and `A` is a linear mapping (for example, a matrix).
183+
If `solver = LiLin(args...)`, then the above problem is solved with
184+
185+
solver(x0, [f, A, g, L])
186+
187+
Optional keyword arguments:
188+
189+
* `gamma::Real` (default: `nothing`), the stepsize to use; defaults to `1/L` if not set (but `L` is).
190+
* `adaptive::Bool` (default: `false`), if true, forces the method stepsize to be adaptively adjusted.
191+
* `delta::Real` (default: `1e-3`), parameter determinining when extrapolated steps are to be accepted.
192+
* `maxit::Integer` (default: `10000`), maximum number of iterations to perform.
193+
* `tol::Real` (default: `1e-8`), absolute tolerance on the fixed-point residual.
194+
* `verbose::Bool` (default: `true`), whether or not to print information during the iterations.
195+
* `freq::Integer` (default: `10`), frequency of verbosity.
196+
197+
If `gamma` is not specified at construction time, the following keyword
198+
argument can be used to set the stepsize parameter:
199+
200+
* `L::Real` (default: `nothing`), the Lipschitz constant of the gradient of x ↦ f(Ax).
201+
202+
References:
203+
204+
[1] Li, Lin, "Accelerated Proximal Gradient Methods for Nonconvex Programming",
205+
Proceedings of NIPS 2015 (2015).
206+
"""
207+
LiLin(::Type{R}; kwargs...) where R = LiLin{R}(; kwargs...)
208+
LiLin(; kwargs...) = LiLin(Float64; kwargs...)

test/problems/test_nonconvex_qp.jl

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
using ProximalAlgorithms
2+
using ProximalOperators
3+
using LinearAlgebra
4+
using Random
5+
using Test
6+
7+
# Tiny box-constrained QP with an indefinite Hessian (one negative eigenvalue,
# hence nonconvex). Each solver is checked against the projected-gradient
# fixed-point residual at its returned iterate.
@testset "Nonconvex QP (tiny, $T)" for T in [Float64]
    Q = Matrix(Diagonal(T[-0.5, 1.0]))
    q = T[0.3, 0.5]
    low = T(-1.0)
    upp = T(+1.0)

    f = Quadratic(Q, q)
    g = IndBox(low, upp)

    n = 2

    # Lipschitz constant of x -> Qx + q is the largest |eigenvalue| of Q;
    # Q is diagonal, so that is the largest |diagonal entry| (maximum(diag(Q))
    # was only correct here by accident, since |-0.5| < 1.0).
    Lip = maximum(abs, diag(Q))
    gamma = T(0.95) / Lip

    @testset "PANOC" begin
        x0 = zeros(T, n)
        solver = ProximalAlgorithms.PANOC{T}()
        x, it = solver(x0, f=f, g=g)
        # projected-gradient step from x; x is (approximately) stationary
        # iff it is (approximately) a fixed point of this map
        z = min.(upp, max.(low, x .- gamma .* (Q*x + q)))
        @test norm(x - z, Inf)/gamma <= solver.tol
    end

    @testset "ZeroFPR" begin
        x0 = zeros(T, n)
        solver = ProximalAlgorithms.ZeroFPR{T}()
        x, it = solver(x0, f=f, g=g)
        z = min.(upp, max.(low, x .- gamma .* (Q*x + q)))
        @test norm(x - z, Inf)/gamma <= solver.tol
    end

    @testset "LiLin" begin
        x0 = zeros(T, n)
        solver = ProximalAlgorithms.LiLin{T}(gamma=gamma)
        x, it = solver(x0, f=f, g=g)
        z = min.(upp, max.(low, x .- gamma .* (Q*x + q)))
        @test norm(x - z, Inf)/gamma <= solver.tol
    end
end
49+
# Random box-constrained QPs with controlled spectrum: eigenvalues are drawn
# from [-1, 1], so the Hessian is indefinite (nonconvex) and its Lipschitz
# constant is known exactly. All literal constants are kept in the element
# type T so the testset stays generic if other types are added to the loop.
@testset "Nonconvex QP (small, $T)" for T in [Float64]
    @testset "Random problem $k" for k in 1:5
        # seed per problem instance for reproducibility
        Random.seed!(k)

        n = 100
        A = randn(T, n, n)
        # orthogonal factor U from the QR decomposition of a random matrix
        U, R = qr(A)
        eigenvalues = T(2) .* rand(T, n) .- T(1)
        Q = U*Diagonal(eigenvalues)*U'
        # symmetrize in T-generic arithmetic (a bare 0.5 would promote
        # the matrix to Float64 for narrower element types)
        Q = T(0.5)*(Q + Q')
        q = randn(T, n)

        low = T(-1.0)
        upp = T(+1.0)

        f = Quadratic(Q, q)
        g = IndBox(low, upp)

        # Lipschitz constant of the gradient: largest |eigenvalue| of Q
        Lip = maximum(abs, eigenvalues)
        gamma = T(0.95) / Lip

        # tolerance in the element type T, matching the solvers' tol::T field
        TOL = T(1e-4)

        @testset "PANOC" begin
            x0 = zeros(T, n)
            solver = ProximalAlgorithms.PANOC{T}(tol=TOL)
            x, it = solver(x0, f=f, g=g)
            # projected-gradient fixed-point residual as optimality measure
            z = min.(upp, max.(low, x .- gamma .* (Q*x + q)))
            @test norm(x - z, Inf)/gamma <= solver.tol
        end

        @testset "ZeroFPR" begin
            x0 = zeros(T, n)
            solver = ProximalAlgorithms.ZeroFPR{T}(tol=TOL)
            x, it = solver(x0, f=f, g=g)
            z = min.(upp, max.(low, x .- gamma .* (Q*x + q)))
            @test norm(x - z, Inf)/gamma <= solver.tol
        end

        @testset "LiLin" begin
            x0 = zeros(T, n)
            # use the {T}-parameterized constructor, consistently with the
            # sibling solvers above
            solver = ProximalAlgorithms.LiLin{T}(gamma=gamma, tol=TOL)
            x, it = solver(x0, f=f, g=g)
            z = min.(upp, max.(low, x .- gamma .* (Q*x + q)))
            @test norm(x - z, Inf)/gamma <= solver.tol
        end
    end
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,4 @@ include("problems/test_lasso_small_h_split.jl")
1313
include("problems/test_linear_programs.jl")
1414
include("problems/test_sparse_logistic_small.jl")
1515
include("problems/test_verbose.jl")
16+
include("problems/test_nonconvex_qp.jl")

0 commit comments

Comments
 (0)