Skip to content

Commit 1799708

Browse files
authored
Merge pull request #1169 from SciML/smc/manopt-retcodes
Improve return codes from OptimizationManopt
2 parents d17e99b + 14578b1 commit 1799708

2 files changed

Lines changed: 47 additions & 38 deletions

File tree

lib/OptimizationManopt/src/OptimizationManopt.jl

Lines changed: 45 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -18,50 +18,27 @@ OptimizationBase.supports_sense(::AbstractManoptOptimizer) = true
1818

1919
function __map_optimizer_args!(
2020
cache::OptimizationBase.OptimizationCache,
21-
opt::AbstractManoptOptimizer;
21+
opt::AbstractManoptOptimizer,
22+
manifold;
2223
callback = nothing,
2324
maxiters::Union{Number, Nothing} = nothing,
2425
maxtime::Union{Number, Nothing} = nothing,
2526
abstol::Union{Number, Nothing} = nothing,
2627
reltol::Union{Number, Nothing} = nothing,
2728
kwargs...
2829
)
29-
solver_kwargs = (; kwargs...)
30+
criteria = Manopt.StoppingCriterion[]
3031

3132
if !isnothing(maxiters)
32-
solver_kwargs = (;
33-
solver_kwargs..., stopping_criterion = [Manopt.StopAfterIteration(maxiters)],
34-
)
33+
push!(criteria, Manopt.StopAfterIteration(maxiters))
3534
end
3635

3736
if !isnothing(maxtime)
38-
if haskey(solver_kwargs, :stopping_criterion)
39-
solver_kwargs = (;
40-
solver_kwargs...,
41-
stopping_criterion = push!(
42-
solver_kwargs.stopping_criterion, Manopt.StopAfterTime(maxtime)
43-
),
44-
)
45-
else
46-
solver_kwargs = (;
47-
solver_kwargs..., stopping_criterion = [Manopt.StopAfter(maxtime)],
48-
)
49-
end
37+
push!(criteria, Manopt.StopAfter(maxtime))
5038
end
5139

5240
if !isnothing(abstol)
53-
if haskey(solver_kwargs, :stopping_criterion)
54-
solver_kwargs = (;
55-
solver_kwargs...,
56-
stopping_criterion = push!(
57-
solver_kwargs.stopping_criterion, Manopt.StopWhenChangeLess(abstol)
58-
),
59-
)
60-
else
61-
solver_kwargs = (;
62-
solver_kwargs..., stopping_criterion = [Manopt.StopWhenChangeLess(abstol)],
63-
)
64-
end
41+
push!(criteria, _default_convergence_criterion(opt, manifold, abstol))
6542
end
6643

6744
if !isnothing(reltol)
@@ -70,6 +47,11 @@ function __map_optimizer_args!(
7047
cache.verbose, :unsupported_kwargs
7148
)
7249
end
50+
51+
solver_kwargs = (; kwargs...)
52+
if !isempty(criteria)
53+
solver_kwargs = (; solver_kwargs..., stopping_criterion = criteria)
54+
end
7355
return solver_kwargs
7456
end
7557

@@ -283,6 +265,20 @@ function SciMLBase.requireshessian(
283265
return true
284266
end
285267

268+
const GradientBasedManoptOptimizer = Union{
269+
GradientDescentOptimizer, ConjugateGradientDescentOptimizer,
270+
QuasiNewtonOptimizer, ConvexBundleOptimizer, FrankWolfeOptimizer,
271+
AdaptiveRegularizationCubicOptimizer, TrustRegionsOptimizer,
272+
}
273+
274+
function _default_convergence_criterion(::GradientBasedManoptOptimizer, M, abstol)
275+
return Manopt.StopWhenGradientNormLess(abstol)
276+
end
277+
278+
function _default_convergence_criterion(::AbstractManoptOptimizer, M, abstol)
279+
return Manopt.StopWhenChangeLess(M, abstol)
280+
end
281+
286282
function build_loss(f::OptimizationFunction, prob, cb)
287283
# TODO: I do not understand this. Why is the manifold not used?
288284
# Either this is an Euclidean cost, then we should probably still call `embed`,
@@ -353,11 +349,11 @@ function SciMLBase.__solve(cache::OptimizationCache{O}) where {O <: AbstractMano
353349
return cb_call
354350
end
355351
solver_kwarg = __map_optimizer_args!(
356-
cache, cache.opt, callback = _cb,
352+
cache, cache.opt, manifold; callback = _cb,
357353
maxiters = cache.solver_args.maxiters,
358354
maxtime = cache.solver_args.maxtime,
359355
abstol = cache.solver_args.abstol,
360-
reltol = cache.solver_args.reltol;
356+
reltol = cache.solver_args.reltol,
361357
cache.solver_args...
362358
)
363359

@@ -371,19 +367,32 @@ function SciMLBase.__solve(cache::OptimizationCache{O}) where {O <: AbstractMano
371367
hessF = build_hessF(cache.f)
372368
end
373369

374-
if haskey(solver_kwarg, :stopping_criterion)
375-
stopping_criterion = Manopt.StopWhenAny(solver_kwarg.stopping_criterion...)
370+
stopping_kwarg = if haskey(solver_kwarg, :stopping_criterion)
371+
(; stopping_criterion = Manopt.StopWhenAny(solver_kwarg.stopping_criterion...))
376372
else
377-
stopping_criterion = Manopt.StopAfterIteration(500)
373+
(;)
378374
end
379375

380376
opt_res = call_manopt_optimizer(
381377
manifold, cache.opt, _loss, gradF, cache.u0;
382-
solver_kwarg..., stopping_criterion = stopping_criterion, hessF
378+
solver_kwarg..., stopping_kwarg..., hessF
383379
)
384380

385381
asc = get_stopping_criterion(opt_res.options)
386-
opt_ret = Manopt.has_converged(asc) ? ReturnCode.Success : ReturnCode.Failure
382+
active = Manopt.get_active_stopping_criteria(asc)
383+
opt_ret = if Manopt.has_converged(asc)
384+
ReturnCode.Success
385+
elseif any(c -> c isa Manopt.StopAfterIteration, active)
386+
ReturnCode.MaxIters
387+
elseif any(c -> c isa Manopt.StopAfter, active)
388+
ReturnCode.MaxTime
389+
elseif any(c -> c isa Union{Manopt.StopWhenCostNaN, Manopt.StopWhenIterateNaN}, active)
390+
ReturnCode.Unstable
391+
elseif any(c -> c isa Manopt.StopWhenStepsizeLess, active)
392+
ReturnCode.Stalled
393+
else
394+
ReturnCode.Failure
395+
end
387396

388397
return SciMLBase.build_solution(
389398
cache,

lib/OptimizationManopt/test/runtests.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ R2 = Euclidean(2)
162162

163163
sol = OptimizationBase.solve(prob, opt)
164164
@test sol.objective < 0.1
165-
@test SciMLBase.successful_retcode(sol) broken = true
165+
@test SciMLBase.successful_retcode(sol)
166166
end
167167

168168
@testset "TrustRegions" begin
@@ -178,7 +178,7 @@ R2 = Euclidean(2)
178178

179179
sol = OptimizationBase.solve(prob, opt)
180180
@test sol.objective < 0.1
181-
@test SciMLBase.successful_retcode(sol) broken = true
181+
@test SciMLBase.successful_retcode(sol)
182182
end
183183

184184
@testset "Custom constraints" begin

0 commit comments

Comments
 (0)