@@ -18,50 +18,27 @@ OptimizationBase.supports_sense(::AbstractManoptOptimizer) = true
1818
1919function __map_optimizer_args! (
2020 cache:: OptimizationBase.OptimizationCache ,
21- opt:: AbstractManoptOptimizer ;
21+ opt:: AbstractManoptOptimizer ,
22+ manifold;
2223 callback = nothing ,
2324 maxiters:: Union{Number, Nothing} = nothing ,
2425 maxtime:: Union{Number, Nothing} = nothing ,
2526 abstol:: Union{Number, Nothing} = nothing ,
2627 reltol:: Union{Number, Nothing} = nothing ,
2728 kwargs...
2829 )
29- solver_kwargs = (; kwargs ... )
30+ criteria = Manopt . StoppingCriterion[]
3031
3132 if ! isnothing (maxiters)
32- solver_kwargs = (;
33- solver_kwargs... , stopping_criterion = [Manopt. StopAfterIteration (maxiters)],
34- )
33+ push! (criteria, Manopt. StopAfterIteration (maxiters))
3534 end
3635
3736 if ! isnothing (maxtime)
38- if haskey (solver_kwargs, :stopping_criterion )
39- solver_kwargs = (;
40- solver_kwargs... ,
41- stopping_criterion = push! (
42- solver_kwargs. stopping_criterion, Manopt. StopAfterTime (maxtime)
43- ),
44- )
45- else
46- solver_kwargs = (;
47- solver_kwargs... , stopping_criterion = [Manopt. StopAfter (maxtime)],
48- )
49- end
37+ push! (criteria, Manopt. StopAfter (maxtime))
5038 end
5139
5240 if ! isnothing (abstol)
53- if haskey (solver_kwargs, :stopping_criterion )
54- solver_kwargs = (;
55- solver_kwargs... ,
56- stopping_criterion = push! (
57- solver_kwargs. stopping_criterion, Manopt. StopWhenChangeLess (abstol)
58- ),
59- )
60- else
61- solver_kwargs = (;
62- solver_kwargs... , stopping_criterion = [Manopt. StopWhenChangeLess (abstol)],
63- )
64- end
41+ push! (criteria, _default_convergence_criterion (opt, manifold, abstol))
6542 end
6643
6744 if ! isnothing (reltol)
@@ -70,6 +47,11 @@ function __map_optimizer_args!(
7047 cache. verbose, :unsupported_kwargs
7148 )
7249 end
50+
51+ solver_kwargs = (; kwargs... )
52+ if ! isempty (criteria)
53+ solver_kwargs = (; solver_kwargs... , stopping_criterion = criteria)
54+ end
7355 return solver_kwargs
7456end
7557
@@ -283,6 +265,20 @@ function SciMLBase.requireshessian(
283265 return true
284266end
285267
268+ const GradientBasedManoptOptimizer = Union{
269+ GradientDescentOptimizer, ConjugateGradientDescentOptimizer,
270+ QuasiNewtonOptimizer, ConvexBundleOptimizer, FrankWolfeOptimizer,
271+ AdaptiveRegularizationCubicOptimizer, TrustRegionsOptimizer,
272+ }
273+
274+ function _default_convergence_criterion (:: GradientBasedManoptOptimizer , M, abstol)
275+ return Manopt. StopWhenGradientNormLess (abstol)
276+ end
277+
278+ function _default_convergence_criterion (:: AbstractManoptOptimizer , M, abstol)
279+ return Manopt. StopWhenChangeLess (M, abstol)
280+ end
281+
286282function build_loss (f:: OptimizationFunction , prob, cb)
287283 # TODO : I do not understand this. Why is the manifold not used?
288284 # Either this is an Euclidean cost, then we should probably still call `embed`,
@@ -353,11 +349,11 @@ function SciMLBase.__solve(cache::OptimizationCache{O}) where {O <: AbstractMano
353349 return cb_call
354350 end
355351 solver_kwarg = __map_optimizer_args! (
356- cache, cache. opt, callback = _cb,
352+ cache, cache. opt, manifold; callback = _cb,
357353 maxiters = cache. solver_args. maxiters,
358354 maxtime = cache. solver_args. maxtime,
359355 abstol = cache. solver_args. abstol,
360- reltol = cache. solver_args. reltol;
356+ reltol = cache. solver_args. reltol,
361357 cache. solver_args...
362358 )
363359
@@ -371,19 +367,32 @@ function SciMLBase.__solve(cache::OptimizationCache{O}) where {O <: AbstractMano
371367 hessF = build_hessF (cache. f)
372368 end
373369
374- if haskey (solver_kwarg, :stopping_criterion )
375- stopping_criterion = Manopt. StopWhenAny (solver_kwarg. stopping_criterion... )
370+ stopping_kwarg = if haskey (solver_kwarg, :stopping_criterion )
371+ (; stopping_criterion = Manopt. StopWhenAny (solver_kwarg. stopping_criterion... ) )
376372 else
377- stopping_criterion = Manopt . StopAfterIteration ( 500 )
373+ (; )
378374 end
379375
380376 opt_res = call_manopt_optimizer (
381377 manifold, cache. opt, _loss, gradF, cache. u0;
382- solver_kwarg... , stopping_criterion = stopping_criterion , hessF
378+ solver_kwarg... , stopping_kwarg ... , hessF
383379 )
384380
385381 asc = get_stopping_criterion (opt_res. options)
386- opt_ret = Manopt. has_converged (asc) ? ReturnCode. Success : ReturnCode. Failure
382+ active = Manopt. get_active_stopping_criteria (asc)
383+ opt_ret = if Manopt. has_converged (asc)
384+ ReturnCode. Success
385+ elseif any (c -> c isa Manopt. StopAfterIteration, active)
386+ ReturnCode. MaxIters
387+ elseif any (c -> c isa Manopt. StopAfter, active)
388+ ReturnCode. MaxTime
389+ elseif any (c -> c isa Union{Manopt. StopWhenCostNaN, Manopt. StopWhenIterateNaN}, active)
390+ ReturnCode. Unstable
391+ elseif any (c -> c isa Manopt. StopWhenStepsizeLess, active)
392+ ReturnCode. Stalled
393+ else
394+ ReturnCode. Failure
395+ end
387396
388397 return SciMLBase. build_solution (
389398 cache,
0 commit comments