|
| 1 | +# Li, Lin, "Accelerated Proximal Gradient Methods for Nonconvex Programming", |
| 2 | +# Proceedings of NIPS 2015 (2015). |
| 3 | + |
| 4 | +using Base.Iterators |
| 5 | +using ProximalAlgorithms.IterationTools |
| 6 | +using ProximalOperators: Zero |
| 7 | +using LinearAlgebra |
| 8 | +using Printf |
| 9 | + |
# Iterable implementing the nonconvex accelerated proximal gradient method of
# Li & Lin (Algorithm 2 in [1]), for problems of the form: minimize f(Ax) + g(x).
struct LiLin_iterable{R <: Real, C <: Union{R, Complex{R}}, Tx <: AbstractArray{C}, Tf, TA, Tg}
    f::Tf # smooth term
    A::TA # matrix/linear operator
    g::Tg # (possibly) nonsmooth, proximable term
    x0::Tx # initial point
    gamma::Maybe{R} # stepsize parameter of forward and backward steps
    adaptive::Bool # enforce adaptive stepsize even if L is provided
    delta::R # sufficient-decrease parameter: an extrapolated step z is accepted
             # when F(z) <= F_average - delta * ||res||^2
    eta::R # averaging parameter (in (0, 1)) of the moving average of objective values
end
| 20 | + |
# Mutable algorithm state carried between iterations; also used as the item
# yielded by the iteration protocol.
mutable struct LiLin_state{R <: Real, Tx, TAx}
    x::Tx # iterate
    y::Tx # extrapolated point
    Ay::TAx # A times y
    f_Ay::R # value of smooth term at Ay
    grad_f_Ay::TAx # gradient of f at Ay
    At_grad_f_Ay::Tx # gradient of the smooth term x ↦ f(Ax) at y
    # TODO: *two* gammas should be used in general, one for y and one for x
    gamma::R # stepsize parameter of forward and backward steps
    y_forward::Tx # forward (gradient) point at y
    z::Tx # forward-backward point
    g_z::R # value of nonsmooth term at z
    res::Tx # fixed-point residual (at y), i.e. y - z
    theta::R # auxiliary sequence to compute extrapolated points
    F_average::R # moving average of objective values
    q::R # auxiliary sequence to compute the moving average
end
| 38 | + |
"""
    Base.iterate(iter::LiLin_iterable)

Perform the initial iteration: evaluate the smooth term and its gradient at
the starting point, take one forward-backward step, and build the state.
Requires `iter.gamma` to be set (see TODOs below) and the starting point to
be feasible (finite objective value).
"""
function Base.iterate(iter::LiLin_iterable{R}) where R
    # TODO: initialize gamma if not provided
    # TODO: authors suggest Barzilai-Borwein rule?
    # TODO: *two* gammas should be used in general, one for y and one for x
    # Until the above is implemented, fail early with an informative error
    # instead of a cryptic MethodError in the forward step below.
    iter.gamma === nothing && throw(ArgumentError(
        "stepsize `gamma` was not set: provide `gamma` or the Lipschitz constant `L`"
    ))

    y = iter.x0
    Ay = iter.A * y
    grad_f_Ay, f_Ay = gradient(iter.f, Ay)

    # compute initial forward-backward step
    At_grad_f_Ay = iter.A' * grad_f_Ay
    y_forward = y - iter.gamma .* At_grad_f_Ay
    z, g_z = prox(iter.g, y_forward, iter.gamma)

    # objective value at the starting point; must be finite for the moving
    # average F_average to be meaningful
    Fy = f_Ay + iter.g(y)
    isfinite(Fy) || throw(ArgumentError("initial point must be feasible"))

    # compute initial fixed-point residual
    res = y - z

    state = LiLin_state(
        copy(iter.x0), y, Ay, f_Ay, grad_f_Ay, At_grad_f_Ay,
        iter.gamma, y_forward, z, g_z, res, R(1), Fy, R(1)
    )

    return state, state
end
| 67 | + |
"""
    Base.iterate(iter::LiLin_iterable, state::LiLin_state)

Perform one iteration of Algorithm 2 from Li & Lin: accept the extrapolated
forward-backward point `z` when it sufficiently decreases the moving average
of objective values; otherwise compare it against a plain forward-backward
step `v` from `x` and keep the better of the two. Then take a new
forward-backward step at the updated extrapolated point `y`.
"""
function Base.iterate(iter::LiLin_iterable{R}, state::LiLin_state{R, Tx, TAx}) where {R, Tx, TAx}
    # TODO: backtrack gamma at y

    Fz = iter.f(state.z) + state.g_z

    # Nesterov-type update of the auxiliary sequence theta
    theta1 = (R(1)+sqrt(R(1)+4*state.theta^2))/R(2)

    if Fz <= state.F_average - iter.delta * norm(state.res)^2
        # sufficient decrease: accept the extrapolated point z
        case = 1
    else
        # TODO: re-use available space in state?
        # TODO: backtrack gamma at x
        # compute a plain forward-backward step v from x as a fallback
        Ax = iter.A * state.x
        grad_f_Ax, f_Ax = gradient(iter.f, Ax)
        At_grad_f_Ax = iter.A' * grad_f_Ax
        x_forward = state.x - state.gamma .* At_grad_f_Ax
        v, g_v = prox(iter.g, x_forward, state.gamma)
        Fv = iter.f(v) + g_v
        case = Fz <= Fv ? 1 : 2
    end

    if case == 1
        # extrapolate from z, then swap x and z so z's storage is recycled
        state.y .= state.z .+ ((state.theta - R(1)) / theta1) .* (state.z .- state.x)
        state.x, state.z = state.z, state.x
        Fx = Fz
    elseif case == 2
        # extrapolate through both z and the fallback point v
        state.y .= state.z .+ (state.theta / theta1) .* (state.z .- v) .+ ((state.theta - R(1)) / theta1) .* (v .- state.x)
        state.x = v
        Fx = Fv
    end

    # forward-backward step at the new extrapolated point y (in place)
    mul!(state.Ay, iter.A, state.y)
    state.f_Ay = gradient!(state.grad_f_Ay, iter.f, state.Ay)
    mul!(state.At_grad_f_Ay, adjoint(iter.A), state.grad_f_Ay)
    state.y_forward .= state.y .- state.gamma .* state.At_grad_f_Ay
    state.g_z = prox!(state.z, iter.g, state.y_forward, state.gamma)

    # fully fused broadcast (was `state.y - state.z`, which allocated a temporary)
    state.res .= state.y .- state.z

    state.theta = theta1

    # update the moving average of objective values
    # NOTE: the following can be simplified
    q1 = iter.eta * state.q + 1
    state.F_average = (iter.eta * state.q * state.F_average + Fx)/q1
    state.q = q1

    return state, state
end
| 116 | + |
| 117 | +# Solver |
| 118 | + |
# Solver settings for the Li-Lin accelerated proximal gradient method.
struct LiLin{R <: Real}
    gamma::Maybe{R}   # stepsize; `nothing` means "derive it from `L` at call time"
    adaptive::Bool    # enforce adaptive stepsize even if L is provided
    delta::R          # sufficient-decrease parameter
    eta::R            # averaging parameter, must lie in (0, 1)
    maxit::Int        # maximum number of iterations
    tol::R            # absolute tolerance on the fixed-point residual
    verbose::Bool     # whether to print progress information
    freq::Int         # frequency (in iterations) of verbose output

    function LiLin{R}(; gamma::Maybe{R}=nothing, adaptive::Bool=false,
        delta::R=R(1e-3), eta::R=R(0.8), maxit::Int=10000, tol::R=R(1e-8),
        verbose::Bool=false, freq::Int=100
    ) where R
        # Validate keyword arguments with informative ArgumentErrors
        # (rather than @assert, which may be disabled at higher optimization levels).
        gamma === nothing || gamma > 0 || throw(ArgumentError("gamma must be positive"))
        delta > 0 || throw(ArgumentError("delta must be positive"))
        0 < eta < 1 || throw(ArgumentError("eta must be in (0, 1)"))
        maxit > 0 || throw(ArgumentError("maxit must be positive"))
        tol > 0 || throw(ArgumentError("tol must be positive"))
        freq > 0 || throw(ArgumentError("freq must be positive"))
        new(gamma, adaptive, delta, eta, maxit, tol, verbose, freq)
    end
end
| 142 | + |
"""
    (solver::LiLin)(x0; f, A, g, L)

Run the Li-Lin method from `x0` on the problem `minimize f(Ax) + g(x)`,
returning the final forward-backward point and the number of iterations
performed. `L`, when given and `solver.gamma` is unset, provides the
stepsize as `1/L`.
"""
function (solver::LiLin{R})(
    x0::AbstractArray{C}; f=Zero(), A=I, g=Zero(), L::Maybe{R}=nothing
) where {R, C <: Union{R, Complex{R}}}
    # stopping criterion: scaled fixed-point residual below tolerance
    stop(state::LiLin_state) = norm(state.res, Inf)/state.gamma <= solver.tol

    # progress line: iteration count, stepsize, scaled residual
    disp((it, state)) = @printf(
        "%5d | %.3e | %.3e\n",
        it, state.gamma, norm(state.res, Inf)/state.gamma
    )

    # Pick the stepsize: an explicitly provided gamma wins, otherwise fall
    # back to 1/L when the Lipschitz constant is given, else leave it unset.
    gamma = if solver.gamma !== nothing
        solver.gamma
    elseif L !== nothing
        R(1)/L
    else
        nothing
    end

    # assemble the iterator pipeline: halt on convergence, cap iterations,
    # number the items, and optionally print progress every `freq` iterations
    iter = LiLin_iterable(f, A, g, x0, gamma, solver.adaptive, solver.delta, solver.eta)
    iter = take(halt(iter, stop), solver.maxit)
    iter = enumerate(iter)
    if solver.verbose iter = tee(sample(iter, solver.freq), disp) end

    num_iters, state_final = loop(iter)

    return state_final.z, num_iters
end
| 171 | + |
| 172 | +# Outer constructors |
| 173 | + |
"""
    LiLin([gamma, adaptive, delta, eta, maxit, tol, verbose, freq])

Instantiate the nonconvex accelerated proximal gradient method by Li and Lin
(see Algorithm 2 in [1]) for solving optimization problems of the form

    minimize f(Ax) + g(x),

where `f` is smooth and `A` is a linear mapping (for example, a matrix).
If `solver = LiLin(args...)`, then the above problem is solved with

    solver(x0, [f, A, g, L])

Optional keyword arguments:

* `gamma::Real` (default: `nothing`), the stepsize to use; defaults to `1/L` if not set (but `L` is).
* `adaptive::Bool` (default: `false`), if true, forces the method stepsize to be adaptively adjusted.
* `delta::Real` (default: `1e-3`), parameter determining when extrapolated steps are to be accepted.
* `eta::Real` (default: `0.8`), averaging parameter for the moving average of objective values.
* `maxit::Integer` (default: `10000`), maximum number of iterations to perform.
* `tol::Real` (default: `1e-8`), absolute tolerance on the fixed-point residual.
* `verbose::Bool` (default: `false`), whether or not to print information during the iterations.
* `freq::Integer` (default: `100`), frequency of verbosity.

If `gamma` is not specified at construction time, the following keyword
argument can be used to set the stepsize parameter:

* `L::Real` (default: `nothing`), the Lipschitz constant of the gradient of x ↦ f(Ax).

References:

[1] Li, Lin, "Accelerated Proximal Gradient Methods for Nonconvex Programming",
Proceedings of NIPS 2015 (2015).
"""
LiLin(::Type{R}; kwargs...) where R = LiLin{R}(; kwargs...)

# Default to Float64 precision when no element type is given.
LiLin(; kwargs...) = LiLin(Float64; kwargs...)
0 commit comments