Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ MathOptInterface = "b8f27783-ece8-5eb3-8dc8-9495eed66fee"

[sources]
MadNLP = {rev = "mk/moi_param", url = "https://github.com/klamike/MadNLP.jl.git"}
NLPModels = {rev = "mk/param_api", url = "https://github.com/klamike/NLPModels.jl"}
NLPModels = {rev = "mk/paramnlp", url = "https://github.com/klamike/NLPModels.jl"}

[extensions]
DiffOptExt = ["DiffOpt", "MathOptInterface"]
Expand Down
54 changes: 36 additions & 18 deletions ext/MadIPMExt/MadIPMExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,23 @@ module MadIPMExt

using LinearAlgebra: mul!
import MadDiff
import MadNLP
import MadNLP: AbstractKKTVector, primal, dual, dual_lb, dual_ub, solve_linear_system!
import MadIPM: NormalKKTSystem, MPCSolver, factorize_regularized_system!
import MadIPM: NormalKKTSystem, MPCSolver, factorize_regularized_system!,
AbstractBatchMPCSolver, UniformBatchMPCSolver, UniformBatchCallback,
BatchUnreducedKKTVector, BatchPrimalVector, BatchVector,
SparseUniformBatchKKTSystem, xp_lr, xp_ur,
lower, upper
import MadIPM
import MadDiff: MadDiffSolver, refactorize_kkt!, _SensitivitySolverShim,
_solve_with_refine!, _adjoint_solve_with_refine!,
adjoint_solve_kkt!, adjoint_mul!,
_adjoint_kktmul!, _adjoint_finish_bounds!, _adjoint_reduce_rhs!
_adjoint_kktmul!, _adjoint_finish_bounds!, _adjoint_reduce_rhs!,
has_hess_param, has_jac_param, has_lvar_param, has_uvar_param,
has_lcon_param, has_ucon_param, has_grad_param
import NLPModels: hpprod!, jpprod!, lvar_jpprod!, uvar_jpprod!, lcon_jpprod!, ucon_jpprod!,
grad!, grad_param!, hptprod!, jptprod!,
lvar_jptprod!, uvar_jptprod!, lcon_jptprod!, ucon_jptprod!

function _adjoint_normal_solve!(kkt::NormalKKTSystem{T}, w::AbstractKKTVector) where {T}
r1 = kkt.buffer_n
Expand Down Expand Up @@ -60,22 +71,29 @@ function refactorize_kkt!(kkt, solver::MPCSolver)
return nothing
end

# function _solve_with_refine!(
# sens::MadDiffSolver{T, KKT, MPCSolver, VI, VB, FC, RC, F},
# w::AbstractKKTVector,
# cache,
# ) where {T, KKT, VI, VB, FC, RC, F}
# solve!(sens.kkt, w)
# return nothing
# end
function _solve_with_refine!(
sens::MadDiffSolver{T, KKT, Solver, VB, FC, RC, JC, TC},
w::MadNLP.AbstractKKTVector,
cache,
) where {T, KKT<:MadNLP.AbstractKKTSystem{T}, Solver<:MPCSolver{T}, VB, FC, RC, JC, TC}
MadNLP.solve_kkt!(sens.kkt, w)
return nothing
end

function _adjoint_solve_with_refine!(
sens::MadDiffSolver{T, KKT, Solver, VB, FC, RC, JC, TC},
w::MadNLP.AbstractKKTVector,
cache,
) where {T, KKT<:MadNLP.AbstractKKTSystem{T}, Solver<:MPCSolver{T}, VB, FC, RC, JC, TC}
adjoint_solve_kkt!(sens.kkt, w)
return nothing
end

# function _adjoint_solve_with_refine!(
# sens::MadDiffSolver{T, KKT, MPCSolver, VI, VB, FC, RC, F},
# w::AbstractKKTVector,
# cache,
# ) where {T, KKT, VI, VB, FC, RC, F}
# adjoint_solve_kkt!(sens.kkt, w)
# return nothing
# end
include("batch_api.jl")
include("batch_cache.jl")
include("batch_packing.jl")
include("batch_kkt.jl")
include("batch_jvp.jl")
include("batch_vjp.jl")

end # module
24 changes: 24 additions & 0 deletions ext/MadIPMExt/batch_api.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
"""
BatchMadDiffSolver(batch_solver::AbstractBatchMPCSolver)

Create a batch sensitivity solver from a solved `UniformBatchMPCSolver`.
Supports `jacobian_vector_product!` and `vector_jacobian_product!` for
computing sensitivities across all batch instances simultaneously.
"""
mutable struct BatchMadDiffSolver{
T,
BatchSolver <: AbstractBatchMPCSolver{T},
FC, RC,
}
solver::BatchSolver
n_p::Int
is_eq::Vector{Bool}
jvp_cache::Union{Nothing, FC}
vjp_cache::Union{Nothing, RC}
end

function MadDiff.reset_sensitivity_cache!(sens::BatchMadDiffSolver)
sens.jvp_cache = nothing
sens.vjp_cache = nothing
return sens
end
201 changes: 201 additions & 0 deletions ext/MadIPMExt/batch_cache.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
struct BatchJVPCache{MT, BPV}
kkt_rhs::BatchUnreducedKKTVector
d2L_dxdp::MT # (nvar × batch_size) Hessian-param product (packed)
dg_dp::MT # (m × batch_size) Jacobian-param product (packed)
dlvar_dp::BPV # BatchPrimalVector: lower var bound param product
duvar_dp::BPV # BatchPrimalVector: upper var bound param product
dlcon_dp::MT # (m × batch_size) lower con bound param product
ducon_dp::MT # (m × batch_size) upper con bound param product
bx::MT # (nvar_nlp × batch_size) current x in NLP space
by::MT # (m × batch_size) current y in NLP space
hpv_nlp::MT # (nvar_nlp × batch_size) hpprod output
jpv_nlp::MT # (m × batch_size) jpprod output
dlvar_nlp::MT # (nvar_nlp × batch_size) lvar_jpprod output
duvar_nlp::MT # (nvar_nlp × batch_size) uvar_jpprod output
dlcon_nlp::MT # (m × batch_size) lcon_jpprod output
ducon_nlp::MT # (m × batch_size) ucon_jpprod output
grad_x::MT # (nvar_nlp × batch_size) grad_f for dobj
grad_p::MT # (nparam × batch_size) grad_param for dobj
end

function _zeros_like(proto::AbstractMatrix{T}, dims::Int...) where {T}
fill!(similar(proto, T, dims...), zero(T))
end

function get_batch_jvp_cache!(sens::BatchMadDiffSolver{T}) where {T}
if isnothing(sens.jvp_cache)
solver = sens.solver
bcb = solver.bcb
nlp = solver.nlp
bs = solver.batch_size

nvar_nlp = nlp.meta.nvar
n_con = bcb.ncon
n_p = sens.n_p
nx = bcb.nvar
ns = length(bcb.ind_ineq)
n_tot = nx + ns
m = bcb.ncon
nlb = length(bcb.ind_lb)
nub = length(bcb.ind_ub)

proto = solver.workspace.bx
MT = typeof(proto)
VT = typeof(similar(proto, T, 0))

sens.jvp_cache = BatchJVPCache{MT, BatchPrimalVector{T, MT, typeof(bcb.ind_lb)}}(
BatchUnreducedKKTVector(MT, VT, n_tot, m, nlb, nub, bs, bcb.ind_lb, bcb.ind_ub),
_zeros_like(proto, nx, bs), # d2L_dxdp
_zeros_like(proto, m, bs), # dg_dp
BatchPrimalVector(MT, VT, nx, ns, bs, bcb.ind_lb, bcb.ind_ub), # dlvar_dp
BatchPrimalVector(MT, VT, nx, ns, bs, bcb.ind_lb, bcb.ind_ub), # duvar_dp
_zeros_like(proto, m, bs), # dlcon_dp
_zeros_like(proto, m, bs), # ducon_dp
_zeros_like(proto, nvar_nlp, bs), # bx
_zeros_like(proto, n_con, bs), # by
_zeros_like(proto, nvar_nlp, bs), # hpv_nlp
_zeros_like(proto, n_con, bs), # jpv_nlp
_zeros_like(proto, nvar_nlp, bs), # dlvar_nlp
_zeros_like(proto, nvar_nlp, bs), # duvar_nlp
_zeros_like(proto, n_con, bs), # dlcon_nlp
_zeros_like(proto, n_con, bs), # ducon_nlp
_zeros_like(proto, nvar_nlp, bs), # grad_x
_zeros_like(proto, n_p, bs), # grad_p
)
end
return sens.jvp_cache
end

struct BatchJVPResult{MT, VT}
dx::MT # (nvar_nlp × batch_size)
dy::MT # (ncon × batch_size)
dzl::MT # (nvar_nlp × batch_size)
dzu::MT # (nvar_nlp × batch_size)
dobj::VT # (batch_size,)
end

function BatchJVPResult(sens::BatchMadDiffSolver{T}) where {T}
solver = sens.solver
bcb = solver.bcb
nlp = solver.nlp
bs = solver.batch_size
nvar_nlp = nlp.meta.nvar
n_con = bcb.ncon
proto = solver.workspace.bx
BatchJVPResult(
_zeros_like(proto, nvar_nlp, bs),
_zeros_like(proto, n_con, bs),
_zeros_like(proto, nvar_nlp, bs),
_zeros_like(proto, nvar_nlp, bs),
fill!(similar(proto, T, bs), zero(T)),
)
end

struct BatchVJPCache{MT, BPV}
kkt_rhs::BatchUnreducedKKTVector
dL_dx::MT # (nvar × batch_size)
dL_dy::MT # (m × batch_size)
dL_dzl::MT # (nlb × batch_size)
dL_dzu::MT # (nub × batch_size)
dzl_full::BPV # BatchPrimalVector: work buffer for bound unpacking
dzu_full::BPV # BatchPrimalVector: work buffer for bound unpacking
bx::MT # (nvar_nlp × batch_size)
by::MT # (m × batch_size) y scaled for pullback
dy_scaled::MT # (m × batch_size) adjoint dy scaled for jptprod
tmp_p::MT # (nparam × batch_size)
grad_x::MT # (nvar × batch_size)
end

function get_batch_vjp_cache!(sens::BatchMadDiffSolver{T}) where {T}
if isnothing(sens.vjp_cache)
solver = sens.solver
bcb = solver.bcb
nlp = solver.nlp
bs = solver.batch_size

nvar_nlp = nlp.meta.nvar
n_con = bcb.ncon
n_p = sens.n_p
nx = bcb.nvar
ns = length(bcb.ind_ineq)
n_tot = nx + ns
m = bcb.ncon
nlb = length(bcb.ind_lb)
nub = length(bcb.ind_ub)

proto = solver.workspace.bx
MT = typeof(proto)
VT = typeof(similar(proto, T, 0))

sens.vjp_cache = BatchVJPCache{MT, BatchPrimalVector{T, MT, typeof(bcb.ind_lb)}}(
BatchUnreducedKKTVector(MT, VT, n_tot, m, nlb, nub, bs, bcb.ind_lb, bcb.ind_ub),
_zeros_like(proto, nx, bs), # dL_dx
_zeros_like(proto, m, bs), # dL_dy
_zeros_like(proto, nlb, bs), # dL_dzl
_zeros_like(proto, nub, bs), # dL_dzu
BatchPrimalVector(MT, VT, nx, ns, bs, bcb.ind_lb, bcb.ind_ub), # dzl_full
BatchPrimalVector(MT, VT, nx, ns, bs, bcb.ind_lb, bcb.ind_ub), # dzu_full
_zeros_like(proto, nvar_nlp, bs), # bx
_zeros_like(proto, n_con, bs), # by
_zeros_like(proto, n_con, bs), # dy_scaled
_zeros_like(proto, n_p, bs), # tmp_p
_zeros_like(proto, nx, bs), # grad_x
)
end
return sens.vjp_cache
end

struct BatchVJPResult{MT}
dx::MT # (nvar_nlp × batch_size)
dy::MT # (ncon × batch_size)
dzl::MT # (nvar_nlp × batch_size)
dzu::MT # (nvar_nlp × batch_size)
grad_p::MT # (nparam × batch_size)
end

function BatchVJPResult(sens::BatchMadDiffSolver{T}) where {T}
solver = sens.solver
bcb = solver.bcb
nlp = solver.nlp
bs = solver.batch_size
nvar_nlp = nlp.meta.nvar
n_con = bcb.ncon
proto = solver.workspace.bx
BatchVJPResult(
_zeros_like(proto, nvar_nlp, bs),
_zeros_like(proto, n_con, bs),
_zeros_like(proto, nvar_nlp, bs),
_zeros_like(proto, nvar_nlp, bs),
_zeros_like(proto, sens.n_p, bs),
)
end

has_hess_param(::Union{BatchJVPCache, BatchVJPCache}, meta) = meta.nnzhp != 0
has_jac_param(::Union{BatchJVPCache, BatchVJPCache}, meta) = meta.nnzjp != 0
has_lvar_param(::Union{BatchJVPCache, BatchVJPCache}, meta) = meta.nnzjplvar != 0
has_uvar_param(::Union{BatchJVPCache, BatchVJPCache}, meta) = meta.nnzjpuvar != 0
has_lcon_param(::Union{BatchJVPCache, BatchVJPCache}, meta) = meta.nnzjplcon != 0
has_ucon_param(::Union{BatchJVPCache, BatchVJPCache}, meta) = meta.nnzjpucon != 0
has_grad_param(::Union{BatchJVPCache, BatchVJPCache}, meta) = meta.nnzgp != 0

function BatchMadDiffSolver(batch_solver::AbstractBatchMPCSolver{T}) where {T}
bcb = batch_solver.bcb
m = bcb.ncon
is_eq = Vector{Bool}(undef, m)
fill!(is_eq, false)
is_eq[bcb.ind_eq] .= true

n_p = batch_solver.nlp.meta.nparam
bs = batch_solver.batch_size
MT = typeof(batch_solver.workspace.bx)

FC = BatchJVPCache{MT, BatchPrimalVector{T, MT, typeof(bcb.ind_lb)}}
RC = BatchVJPCache{MT, BatchPrimalVector{T, MT, typeof(bcb.ind_lb)}}

# FIXME: BatchMadDiffConfig
refactorize_kkt!(batch_solver.kkt, batch_solver)

return BatchMadDiffSolver{T, typeof(batch_solver), FC, RC}(
batch_solver, n_p, is_eq, nothing, nothing,
)
end
Loading
Loading