Skip to content
Merged
149 changes: 149 additions & 0 deletions ext/MadNLPSolverExt/MadNLPSolverExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,162 @@ include("utils.jl")
@test opts.max_iter == 3000
@test opts.print_level == 3
@test opts.hessian_approximation == "exact"
@test opts.intermediate_callback === nothing
@test opts.fixed_variable_treatment === nothing

opts2 = DirectTrajOpt.MadNLPOptions(max_iter = 100, tol = 1e-6)
@test opts2.max_iter == 100
@test opts2.tol == 1e-6
@test opts isa Solvers.AbstractSolverOptions
end

@testitem "MadNLP intermediate_callback (raw MadNLP callback) fires per iter" setup=[
DTOTestHelpers,
] begin
import MadNLP

mutable struct _IterCounter <: MadNLP.AbstractUserCallback
count::Base.RefValue{Int}
end
(cb::_IterCounter)(::MadNLP.AbstractMadNLPSolver, _) = (cb.count[] += 1; true)

cb = _IterCounter(Ref(0))
prob, _ = make_standard_prob()
solve!(
prob;
options = DirectTrajOpt.MadNLPOptions(
max_iter = 5,
intermediate_callback = cb,
fixed_variable_treatment = MadNLP.RelaxBound,
),
verbose = false,
)
@test cb.count[] > 0
end

@testitem "MadNLP intermediate_callback (AbstractIntermediateCallback) fires per iter" setup=[
DTOTestHelpers,
] begin
import MadNLP

mutable struct _AgnosticCounter <: DirectTrajOpt.AbstractIntermediateCallback
count::Base.RefValue{Int}
last_primal_len::Base.RefValue{Int}
end
function (cb::_AgnosticCounter)(primal::AbstractVector, iter::Integer)
cb.count[] += 1
cb.last_primal_len[] = length(primal)
return true
end

cb = _AgnosticCounter(Ref(0), Ref(0))
prob, _ = make_standard_prob()
solve!(
prob;
options = DirectTrajOpt.MadNLPOptions(
max_iter = 5,
intermediate_callback = cb,
fixed_variable_treatment = MadNLP.RelaxBound,
),
verbose = false,
)
@test cb.count[] > 0
# With RelaxBound, the primal vector matches the full NLP variable count.
@test cb.last_primal_len[] ==
length(prob.trajectory.datavec) + prob.trajectory.global_dim
end

@testitem "MadNLP intermediate_callback auto-couples RelaxBound" setup=[DTOTestHelpers] begin
import MadNLP

mutable struct _AutoCoupleProbe <: DirectTrajOpt.AbstractIntermediateCallback
last_primal_len::Base.RefValue{Int}
end
function (cb::_AutoCoupleProbe)(primal::AbstractVector, _)
cb.last_primal_len[] = length(primal)
return true
end

cb = _AutoCoupleProbe(Ref(0))
prob, _ = make_standard_prob()
# Note: NOT passing fixed_variable_treatment. set_options! should auto-set it.
solve!(
prob;
options = DirectTrajOpt.MadNLPOptions(max_iter = 5, intermediate_callback = cb),
verbose = false,
)
# If RelaxBound auto-coupled correctly, the primal includes fixed variables.
@test cb.last_primal_len[] ==
length(prob.trajectory.datavec) + prob.trajectory.global_dim
end

@testitem "MadNLP auto-couple respects MadNLP's conditional default" setup=[DTOTestHelpers] begin
import MadNLP

mutable struct _PassthroughProbe <: DirectTrajOpt.AbstractIntermediateCallback
len::Base.RefValue{Int}
end
(cb::_PassthroughProbe)(primal, _) = (cb.len[] = length(primal); true)

cb = _PassthroughProbe(Ref(0))
prob, _ = make_standard_prob()
# With `kkt_system = SparseCondensedKKTSystem`, MadNLP's own conditional
# default for `fixed_variable_treatment` is already `RelaxBound`, so the
# auto-couple should not fire. Capture logs and assert our @info is absent.
logs, _ = Test.collect_test_logs() do
solve!(
prob;
options = DirectTrajOpt.MadNLPOptions(
max_iter = 5,
intermediate_callback = cb,
kkt_system = MadNLP.SparseCondensedKKTSystem,
),
verbose = false,
)
end
@test !any(l -> occursin("Setting fixed_variable_treatment", l.message), logs)
# MadNLP's untouched conditional default still yields the full primal.
@test cb.len[] == length(prob.trajectory.datavec) + prob.trajectory.global_dim
end

@testitem "MadNLP intermediate_callback early termination via return false" setup=[
DTOTestHelpers,
] begin
import MadNLP

mutable struct _Stopper <: DirectTrajOpt.AbstractIntermediateCallback
max_iters::Int
count::Base.RefValue{Int}
end
function (cb::_Stopper)(_, _)
cb.count[] += 1
return cb.count[] < cb.max_iters
end

cb = _Stopper(3, Ref(0))
prob, _ = make_standard_prob()
solve!(
prob;
options = DirectTrajOpt.MadNLPOptions(max_iter = 100, intermediate_callback = cb),
verbose = false,
)
# Callback stopped the solve well before max_iter=100.
@test cb.count[] <= 5
end

@testitem "MadNLP intermediate_callback rejects invalid type" setup=[DTOTestHelpers] begin
prob, _ = make_standard_prob()
bogus_cb(args...) = true # bare Function — neither abstract nor MadNLP subtype
@test_throws ArgumentError solve!(
prob;
options = DirectTrajOpt.MadNLPOptions(
max_iter = 5,
intermediate_callback = bogus_cb,
),
verbose = false,
)
end

@testitem "MadNLP basic solve" setup=[DTOTestHelpers] begin
prob, _ = make_standard_prob()
traj_before = deepcopy(prob.trajectory.data)
Expand Down
62 changes: 62 additions & 0 deletions ext/MadNLPSolverExt/solver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -200,9 +200,55 @@ end
# ----------------------------------------------------------------------------


"""
_MadNLPCallbackAdapter(inner)

Wrap a `DirectTrajOpt.AbstractIntermediateCallback` so MadNLP can call it
with its native `(solver, mode)` signature.

The adapter:
- **Filters mode to `UserCallbackRegular`.** MadNLP also invokes user
callbacks during feasibility restoration and robust mode; those phases
surface intermediate IPM state that's typically not meaningful to a
trajectory-level callback, so they're silently skipped (return `true`).
This makes `solver.cnt.k` monotonic from the callback's point of view.
- **Translates `solver.x` → `MadNLP.variable(solver.x)`** to strip the
slack tail and hand back just the NLP primal.
- **Forwards `solver.cnt.k`** as the iteration index.
"""
struct _MadNLPCallbackAdapter <: MadNLP.AbstractUserCallback
inner::DirectTrajOpt.AbstractIntermediateCallback
end

function (a::_MadNLPCallbackAdapter)(
solver::MadNLP.AbstractMadNLPSolver,
mode::MadNLP.AbstractUserCallbackStatus,
)
mode isa MadNLP.UserCallbackRegular || return true
return a.inner(MadNLP.variable(solver.x), solver.cnt.k)
end

function DirectTrajOpt.set_options!(optimizer::AbstractOptimizer, options::MadNLPOptions)
ignored_options = [:eval_hessian]

# Auto-couple: an AbstractIntermediateCallback needs the full primal vector,
# which requires fixed_variable_treatment = RelaxBound. We only override when
# MadNLP's own conditional default (`kkt_system <: SparseCondensedKKTSystem ?
# RelaxBound : MakeParameter`) would otherwise pick `MakeParameter` and break
# the callback. When the user has selected a kkt_system whose default is
# already `RelaxBound`, MadNLP's if-one-liner gets to do its job untouched.
# Raw MadNLP callbacks are presumed to manage this themselves.
if options.intermediate_callback isa DirectTrajOpt.AbstractIntermediateCallback &&
options.fixed_variable_treatment === nothing
madnlp_default_is_relax_bound =
options.kkt_system isa Type &&
options.kkt_system <: MadNLP.SparseCondensedKKTSystem
if !madnlp_default_is_relax_bound
@info "Setting fixed_variable_treatment = MadNLP.RelaxBound for AbstractIntermediateCallback (MadNLP's kkt_system default would otherwise eliminate fixed vars from solver.x)"
optimizer.options[:fixed_variable_treatment] = MadNLP.RelaxBound
end
end

for name in fieldnames(typeof(options))
value = getfield(options, name)
if name in ignored_options
Expand All @@ -221,6 +267,22 @@ function DirectTrajOpt.set_options!(optimizer::AbstractOptimizer, options::MadNL
hessian_approximation =
((value == "compact_lbfgs") ? MadNLP.CompactLBFGS : hessian_approximation)
optimizer.options[name] = hessian_approximation
elseif name == :intermediate_callback
if value isa DirectTrajOpt.AbstractIntermediateCallback
# Wrap solver-agnostic callbacks in the MadNLP-shaped adapter.
optimizer.options[name] = _MadNLPCallbackAdapter(value)
elseif value isa MadNLP.AbstractUserCallback
# Raw MadNLP callbacks pass through unwrapped.
optimizer.options[name] = value
else
throw(
ArgumentError(
"intermediate_callback must be a subtype of " *
"`DirectTrajOpt.AbstractIntermediateCallback` or " *
"`MadNLP.AbstractUserCallback`, got $(typeof(value))",
),
)
end
else
optimizer.options[name] = value
end
Expand Down
41 changes: 41 additions & 0 deletions src/solvers/_solvers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ module Solvers

export AbstractOptimizer
export AbstractSolverOptions, DefaultSolverOptions, _DefaultSolverOptions
export AbstractIntermediateCallback
export _solve
export _solve_with_kwargs
export solve!
Expand All @@ -17,6 +18,46 @@ using TestItemRunner
const AbstractOptimizer = MOI.AbstractOptimizer
abstract type AbstractSolverOptions end

"""
AbstractIntermediateCallback

Solver-agnostic per-iteration callback for trajectory optimization.

Subtypes implement a callable with signature

(cb::SubType)(primal::AbstractVector, iter::Integer) -> Bool

where `primal` is the current full NLP primal vector and `iter` is the
iteration index from the solver's main optimization loop. Return `true` to
continue solving, `false` to stop early (the solver will report a
user-requested termination).

Each solver extension wraps an `AbstractIntermediateCallback` instance in a
solver-specific adapter at solve time, so the same callback object works
with every backend (MadNLP, Ipopt, …).

# Contract

- **`primal` may alias the solver's internal vector.** Copy it (e.g.
`collect(primal)`) if you need to retain the data past the callback
invocation — its contents may shift on the next iteration.
- **`iter` is monotonic.** The callback is invoked only from the solver's
main IPM loop; auxiliary phases (e.g. MadNLP's feasibility restoration
or robust modes) do not fire it.

# Required MadNLP setup

When using MadNLP, the callback must receive the **full** primal vector
to reconstruct trajectories correctly. MadNLP's default
`fixed_variable_treatment = MakeParameter` eliminates variables with
`lb == ub` from the working primal, so any subtype that maps `primal`
back onto a `NamedTrajectory` needs `fixed_variable_treatment =
MadNLP.RelaxBound`. When an `AbstractIntermediateCallback` is installed
via `MadNLPOptions.intermediate_callback`, DTO sets this automatically
(with an `@info` log) unless the user has provided a value.
"""
abstract type AbstractIntermediateCallback end

struct DefaultSolverOptions <: AbstractSolverOptions end
const _DefaultSolverOptions::Ref{Type{<:AbstractSolverOptions}} =
Ref{Type{<:AbstractSolverOptions}}(DefaultSolverOptions)
Expand Down
28 changes: 28 additions & 0 deletions src/solvers/madnlp_solver/options.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,34 @@ export MadNLPOptions
kkt_system::Any = nothing # e.g. MadNLP.SparseUnreducedKKTSystem
cudss_ordering::Any = nothing # e.g. MadNLPGPU.AMD_ORDERING

# Per-iteration user callback. Two accepted forms:
#
# 1. A subtype of `DirectTrajOpt.AbstractIntermediateCallback` (solver-agnostic).
# Signature: `(cb)(primal::AbstractVector, iter::Integer) -> Bool`.
# The MadNLP extension wraps it in an internal adapter at solve time.
#
# 2. A raw `MadNLP.AbstractUserCallback` subtype with native MadNLP signature
# `(cb)(solver::MadNLP.AbstractMadNLPSolver, mode) -> Bool` — passed through
# unwrapped for users who want full access to the IPM state.
#
# Return `false` to stop the solver (yields `USER_REQUESTED_STOP`).
intermediate_callback::Any = nothing

# Controls how MadNLP handles variables with `lb == ub`. Mirrors MadNLP's
# own `fixed_variable_treatment::Type` field — must be a `Type` (typically
# `MadNLP.MakeParameter` or `MadNLP.RelaxBound`). Default (`nothing`) defers
# to MadNLP's kkt_system-aware conditional default:
#
# kkt_system <: SparseCondensedKKTSystem ? RelaxBound : MakeParameter
#
# When an `AbstractIntermediateCallback` is installed and this field is
# left at `nothing`, `set_options!` only overrides to `RelaxBound` if
# MadNLP's conditional default would otherwise be `MakeParameter` (which
# eliminates fixed boundary vars from `solver.x` and breaks trajectory
# reconstruction). The conditional default's `RelaxBound` branch is left
# untouched.
fixed_variable_treatment::Union{Type,Nothing} = nothing

# # Only supported by DirectTrajOpt._solve, as an optional kwarg override of `hessian_approximation`;
# # `hessian_approximation = eval_hessian ? "exact" : "compact_lbfgs"`
# eval_hessian::Bool = true
Expand Down
Loading