Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 45 additions & 36 deletions lib/OptimizationManopt/src/OptimizationManopt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,50 +18,27 @@ OptimizationBase.supports_sense(::AbstractManoptOptimizer) = true

function __map_optimizer_args!(
cache::OptimizationBase.OptimizationCache,
opt::AbstractManoptOptimizer;
opt::AbstractManoptOptimizer,
manifold;
callback = nothing,
maxiters::Union{Number, Nothing} = nothing,
maxtime::Union{Number, Nothing} = nothing,
abstol::Union{Number, Nothing} = nothing,
reltol::Union{Number, Nothing} = nothing,
kwargs...
)
solver_kwargs = (; kwargs...)
criteria = Manopt.StoppingCriterion[]

if !isnothing(maxiters)
solver_kwargs = (;
solver_kwargs..., stopping_criterion = [Manopt.StopAfterIteration(maxiters)],
)
push!(criteria, Manopt.StopAfterIteration(maxiters))
end

if !isnothing(maxtime)
if haskey(solver_kwargs, :stopping_criterion)
solver_kwargs = (;
solver_kwargs...,
stopping_criterion = push!(
solver_kwargs.stopping_criterion, Manopt.StopAfterTime(maxtime)
),
)
else
solver_kwargs = (;
solver_kwargs..., stopping_criterion = [Manopt.StopAfter(maxtime)],
)
end
push!(criteria, Manopt.StopAfter(maxtime))
end

if !isnothing(abstol)
if haskey(solver_kwargs, :stopping_criterion)
solver_kwargs = (;
solver_kwargs...,
stopping_criterion = push!(
solver_kwargs.stopping_criterion, Manopt.StopWhenChangeLess(abstol)
),
)
else
solver_kwargs = (;
solver_kwargs..., stopping_criterion = [Manopt.StopWhenChangeLess(abstol)],
)
end
push!(criteria, _default_convergence_criterion(opt, manifold, abstol))
end

if !isnothing(reltol)
Expand All @@ -70,6 +47,11 @@ function __map_optimizer_args!(
cache.verbose, :unsupported_kwargs
)
end

solver_kwargs = (; kwargs...)
if !isempty(criteria)
solver_kwargs = (; solver_kwargs..., stopping_criterion = criteria)
end
return solver_kwargs
end

Expand Down Expand Up @@ -283,6 +265,20 @@ function SciMLBase.requireshessian(
return true
end

const GradientBasedManoptOptimizer = Union{
GradientDescentOptimizer, ConjugateGradientDescentOptimizer,
QuasiNewtonOptimizer, ConvexBundleOptimizer, FrankWolfeOptimizer,
AdaptiveRegularizationCubicOptimizer, TrustRegionsOptimizer,
}

function _default_convergence_criterion(::GradientBasedManoptOptimizer, M, abstol)
return Manopt.StopWhenGradientNormLess(abstol)
end

function _default_convergence_criterion(::AbstractManoptOptimizer, M, abstol)
return Manopt.StopWhenChangeLess(M, abstol)
end

function build_loss(f::OptimizationFunction, prob, cb)
# TODO: I do not understand this. Why is the manifold not used?
# Either this is an Euclidean cost, then we should probably still call `embed`,
Expand Down Expand Up @@ -353,11 +349,11 @@ function SciMLBase.__solve(cache::OptimizationCache{O}) where {O <: AbstractMano
return cb_call
end
solver_kwarg = __map_optimizer_args!(
cache, cache.opt, callback = _cb,
cache, cache.opt, manifold; callback = _cb,
maxiters = cache.solver_args.maxiters,
maxtime = cache.solver_args.maxtime,
abstol = cache.solver_args.abstol,
reltol = cache.solver_args.reltol;
reltol = cache.solver_args.reltol,
cache.solver_args...
)

Expand All @@ -371,19 +367,32 @@ function SciMLBase.__solve(cache::OptimizationCache{O}) where {O <: AbstractMano
hessF = build_hessF(cache.f)
end

if haskey(solver_kwarg, :stopping_criterion)
stopping_criterion = Manopt.StopWhenAny(solver_kwarg.stopping_criterion...)
stopping_kwarg = if haskey(solver_kwarg, :stopping_criterion)
(; stopping_criterion = Manopt.StopWhenAny(solver_kwarg.stopping_criterion...))
else
stopping_criterion = Manopt.StopAfterIteration(500)
(;)
end

opt_res = call_manopt_optimizer(
manifold, cache.opt, _loss, gradF, cache.u0;
solver_kwarg..., stopping_criterion = stopping_criterion, hessF
solver_kwarg..., stopping_kwarg..., hessF
)

asc = get_stopping_criterion(opt_res.options)
opt_ret = Manopt.has_converged(asc) ? ReturnCode.Success : ReturnCode.Failure
active = Manopt.get_active_stopping_criteria(asc)
opt_ret = if Manopt.has_converged(asc)
ReturnCode.Success
elseif any(c -> c isa Manopt.StopAfterIteration, active)
ReturnCode.MaxIters
elseif any(c -> c isa Manopt.StopAfter, active)
ReturnCode.MaxTime
elseif any(c -> c isa Union{Manopt.StopWhenCostNaN, Manopt.StopWhenIterateNaN}, active)
ReturnCode.Unstable
elseif any(c -> c isa Manopt.StopWhenStepsizeLess, active)
ReturnCode.Stalled
else
ReturnCode.Failure
end

return SciMLBase.build_solution(
cache,
Expand Down
4 changes: 2 additions & 2 deletions lib/OptimizationManopt/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ R2 = Euclidean(2)

sol = OptimizationBase.solve(prob, opt)
@test sol.objective < 0.1
@test SciMLBase.successful_retcode(sol) broken = true
@test SciMLBase.successful_retcode(sol)
end

@testset "TrustRegions" begin
Expand All @@ -178,7 +178,7 @@ R2 = Euclidean(2)

sol = OptimizationBase.solve(prob, opt)
@test sol.objective < 0.1
@test SciMLBase.successful_retcode(sol) broken = true
@test SciMLBase.successful_retcode(sol)
end

@testset "Custom constraints" begin
Expand Down
Loading