diff --git a/Project.toml b/Project.toml
index 7611dea..9d82b4c 100644
--- a/Project.toml
+++ b/Project.toml
@@ -17,7 +17,7 @@ MathOptInterface = "b8f27783-ece8-5eb3-8dc8-9495eed66fee"
 
 [sources]
 MadNLP = {rev = "mk/moi_param", url = "https://github.com/klamike/MadNLP.jl.git"}
-NLPModels = {rev = "mk/param_api", url = "https://github.com/klamike/NLPModels.jl"}
+NLPModels = {rev = "mk/paramnlp", url = "https://github.com/klamike/NLPModels.jl"}
 
 [extensions]
 DiffOptExt = ["DiffOpt", "MathOptInterface"]
diff --git a/ext/MadIPMExt/MadIPMExt.jl b/ext/MadIPMExt/MadIPMExt.jl
index 088c4d6..9fbd6d1 100644
--- a/ext/MadIPMExt/MadIPMExt.jl
+++ b/ext/MadIPMExt/MadIPMExt.jl
@@ -2,12 +2,23 @@ module MadIPMExt
 
 using LinearAlgebra: mul!
 import MadDiff
+import MadNLP
 import MadNLP: AbstractKKTVector, primal, dual, dual_lb, dual_ub, solve_linear_system!
-import MadIPM: NormalKKTSystem, MPCSolver, factorize_regularized_system!
+import MadIPM: NormalKKTSystem, MPCSolver, factorize_regularized_system!,
+               AbstractBatchMPCSolver, UniformBatchMPCSolver, UniformBatchCallback,
+               BatchUnreducedKKTVector, BatchPrimalVector, BatchVector,
+               SparseUniformBatchKKTSystem, xp_lr, xp_ur,
+               lower, upper
+import MadIPM
 import MadDiff: MadDiffSolver, refactorize_kkt!, _SensitivitySolverShim,
                 _solve_with_refine!, _adjoint_solve_with_refine!,
                 adjoint_solve_kkt!, adjoint_mul!,
-                _adjoint_kktmul!, _adjoint_finish_bounds!, _adjoint_reduce_rhs!
+                _adjoint_kktmul!, _adjoint_finish_bounds!, _adjoint_reduce_rhs!,
+                has_hess_param, has_jac_param, has_lvar_param, has_uvar_param,
+                has_lcon_param, has_ucon_param, has_grad_param
+import NLPModels: hpprod!, jpprod!, lvar_jpprod!, uvar_jpprod!, lcon_jpprod!, ucon_jpprod!,
+                  grad!, grad_param!, hptprod!, jptprod!,
+                  lvar_jptprod!, uvar_jptprod!, lcon_jptprod!, ucon_jptprod!
 
 function _adjoint_normal_solve!(kkt::NormalKKTSystem{T}, w::AbstractKKTVector) where {T}
     r1 = kkt.buffer_n
@@ -60,22 +71,31 @@ function refactorize_kkt!(kkt, solver::MPCSolver)
     return nothing
 end
 
-# function _solve_with_refine!(
-#     sens::MadDiffSolver{T, KKT, MPCSolver, VI, VB, FC, RC, F},
-#     w::AbstractKKTVector,
-#     cache,
-# ) where {T, KKT, VI, VB, FC, RC, F}
-#     solve!(sens.kkt, w)
-#     return nothing
-# end
+# Solvers wrapping an `MPCSolver`: one direct KKT solve; `cache` is unused.
+function _solve_with_refine!(
+    sens::MadDiffSolver{T, KKT, Solver, VB, FC, RC, JC, TC},
+    w::MadNLP.AbstractKKTVector,
+    cache,
+) where {T, KKT<:MadNLP.AbstractKKTSystem{T}, Solver<:MPCSolver{T}, VB, FC, RC, JC, TC}
+    MadNLP.solve_kkt!(sens.kkt, w)
+    return nothing
+end
+
+# Adjoint counterpart of the method above, via `adjoint_solve_kkt!`.
+function _adjoint_solve_with_refine!(
+    sens::MadDiffSolver{T, KKT, Solver, VB, FC, RC, JC, TC},
+    w::MadNLP.AbstractKKTVector,
+    cache,
+) where {T, KKT<:MadNLP.AbstractKKTSystem{T}, Solver<:MPCSolver{T}, VB, FC, RC, JC, TC}
+    adjoint_solve_kkt!(sens.kkt, w)
+    return nothing
+end
 
-# function _adjoint_solve_with_refine!(
-#     sens::MadDiffSolver{T, KKT, MPCSolver, VI, VB, FC, RC, F},
-#     w::AbstractKKTVector,
-#     cache,
-# ) where {T, KKT, VI, VB, FC, RC, F}
-#     adjoint_solve_kkt!(sens.kkt, w)
-#     return nothing
-# end
+include("batch_api.jl")
+include("batch_cache.jl")
+include("batch_packing.jl")
+include("batch_kkt.jl")
+include("batch_jvp.jl")
+include("batch_vjp.jl")
 
 end # module
diff --git a/ext/MadIPMExt/batch_api.jl b/ext/MadIPMExt/batch_api.jl
new file mode 100644
index 0000000..63e4bbe
--- /dev/null
+++ b/ext/MadIPMExt/batch_api.jl
@@ -0,0 +1,38 @@
+"""
+    BatchMadDiffSolver(batch_solver::AbstractBatchMPCSolver)
+
+Create a batch sensitivity solver from a solved `UniformBatchMPCSolver`.
+Supports `jacobian_vector_product!` and `vector_jacobian_product!` for
+computing sensitivities across all batch instances simultaneously.
+
+# Fields
+- `solver`: the wrapped batch solver.
+- `n_p`: number of parameters (`solver.nlp.meta.nparam`).
+- `is_eq`: `is_eq[i]` is `true` iff constraint `i` is listed in `solver.bcb.ind_eq`.
+- `jvp_cache`, `vjp_cache`: work buffers; `nothing` until first use.
+"""
+mutable struct BatchMadDiffSolver{
+    T,
+    BatchSolver <: AbstractBatchMPCSolver{T},
+    FC, RC,
+}
+    solver::BatchSolver
+    n_p::Int
+    # NOTE(review): host-side `Vector{Bool}` that is later broadcast against
+    # solver-typed matrices; confirm this is valid for GPU-backed batches.
+    is_eq::Vector{Bool}
+    jvp_cache::Union{Nothing, FC}
+    vjp_cache::Union{Nothing, RC}
+end
+
+"""
+    MadDiff.reset_sensitivity_cache!(sens::BatchMadDiffSolver)
+
+Drop the JVP and VJP caches so that they are rebuilt on the next use.
+Returns `sens`.
+"""
+function MadDiff.reset_sensitivity_cache!(sens::BatchMadDiffSolver)
+    sens.jvp_cache = nothing
+    sens.vjp_cache = nothing
+    return sens
+end
diff --git a/ext/MadIPMExt/batch_cache.jl b/ext/MadIPMExt/batch_cache.jl
new file mode 100644
index 0000000..fc9b1f3
--- /dev/null
+++ b/ext/MadIPMExt/batch_cache.jl
@@ -0,0 +1,242 @@
+"""
+    BatchJVPCache{MT, BPV}
+
+Work buffers for the batch Jacobian-vector product. The `*_nlp` product buffers
+receive the raw `NLPModels` parametric products; the `*_dp` buffers hold the
+same quantities after packing through the batch callback.
+"""
+struct BatchJVPCache{MT, BPV}
+    kkt_rhs::BatchUnreducedKKTVector
+    d2L_dxdp::MT      # (nvar × batch_size) Hessian-param product (packed)
+    dg_dp::MT         # (m × batch_size) Jacobian-param product (packed)
+    dlvar_dp::BPV     # BatchPrimalVector: lower var bound param product
+    duvar_dp::BPV     # BatchPrimalVector: upper var bound param product
+    dlcon_dp::MT      # (m × batch_size) lower con bound param product
+    ducon_dp::MT      # (m × batch_size) upper con bound param product
+    bx::MT            # (nvar_nlp × batch_size) current x in NLP space
+    by::MT            # (m × batch_size) current y in NLP space
+    hpv_nlp::MT       # (nvar_nlp × batch_size) hpprod output
+    jpv_nlp::MT       # (m × batch_size) jpprod output
+    dlvar_nlp::MT     # (nvar_nlp × batch_size) lvar_jpprod output
+    duvar_nlp::MT     # (nvar_nlp × batch_size) uvar_jpprod output
+    dlcon_nlp::MT     # (m × batch_size) lcon_jpprod output
+    ducon_nlp::MT     # (m × batch_size) ucon_jpprod output
+    grad_x::MT        # (nvar_nlp × batch_size) grad_f for dobj
+    grad_p::MT        # (nparam × batch_size) grad_param for dobj
+end
+
+"""
+    _zeros_like(proto::AbstractMatrix{T}, dims::Int...)
+
+Return a zero-filled array of size `dims` and element type `T`, allocated with
+`similar(proto, T, dims...)`.
+"""
+function _zeros_like(proto::AbstractMatrix{T}, dims::Int...) where {T}
+    return fill!(similar(proto, T, dims...), zero(T))
+end
+
+"""
+    get_batch_jvp_cache!(sens::BatchMadDiffSolver)
+
+Return `sens.jvp_cache`, allocating it first if it is `nothing`.
+"""
+function get_batch_jvp_cache!(sens::BatchMadDiffSolver{T}) where {T}
+    if isnothing(sens.jvp_cache)
+        solver = sens.solver
+        bcb = solver.bcb
+        nlp = solver.nlp
+        bs = solver.batch_size
+
+        nvar_nlp = nlp.meta.nvar
+        n_con = bcb.ncon
+        n_p = sens.n_p
+        nx = bcb.nvar
+        ns = length(bcb.ind_ineq)
+        n_tot = nx + ns
+        m = bcb.ncon
+        nlb = length(bcb.ind_lb)
+        nub = length(bcb.ind_ub)
+
+        proto = solver.workspace.bx
+        MT = typeof(proto)
+        VT = typeof(similar(proto, T, 0))
+
+        sens.jvp_cache = BatchJVPCache{MT, BatchPrimalVector{T, MT, typeof(bcb.ind_lb)}}(
+            BatchUnreducedKKTVector(MT, VT, n_tot, m, nlb, nub, bs, bcb.ind_lb, bcb.ind_ub),
+            _zeros_like(proto, nx, bs),        # d2L_dxdp
+            _zeros_like(proto, m, bs),         # dg_dp
+            BatchPrimalVector(MT, VT, nx, ns, bs, bcb.ind_lb, bcb.ind_ub),  # dlvar_dp
+            BatchPrimalVector(MT, VT, nx, ns, bs, bcb.ind_lb, bcb.ind_ub),  # duvar_dp
+            _zeros_like(proto, m, bs),         # dlcon_dp
+            _zeros_like(proto, m, bs),         # ducon_dp
+            _zeros_like(proto, nvar_nlp, bs),  # bx
+            _zeros_like(proto, n_con, bs),     # by
+            _zeros_like(proto, nvar_nlp, bs),  # hpv_nlp
+            _zeros_like(proto, n_con, bs),     # jpv_nlp
+            _zeros_like(proto, nvar_nlp, bs),  # dlvar_nlp
+            _zeros_like(proto, nvar_nlp, bs),  # duvar_nlp
+            _zeros_like(proto, n_con, bs),     # dlcon_nlp
+            _zeros_like(proto, n_con, bs),     # ducon_nlp
+            _zeros_like(proto, nvar_nlp, bs),  # grad_x
+            _zeros_like(proto, n_p, bs),       # grad_p
+        )
+    end
+    return sens.jvp_cache
+end
+
+"""
+    BatchJVPResult(sens::BatchMadDiffSolver)
+
+Output container of the batch JVP. This constructor allocates zero-filled
+buffers sized from `sens.solver`.
+"""
+struct BatchJVPResult{MT, VT}
+    dx::MT     # (nvar_nlp × batch_size)
+    dy::MT     # (ncon × batch_size)
+    dzl::MT    # (nvar_nlp × batch_size)
+    dzu::MT    # (nvar_nlp × batch_size)
+    dobj::VT   # (batch_size,)
+end
+
+function BatchJVPResult(sens::BatchMadDiffSolver{T}) where {T}
+    solver = sens.solver
+    bcb = solver.bcb
+    nlp = solver.nlp
+    bs = solver.batch_size
+    nvar_nlp = nlp.meta.nvar
+    n_con = bcb.ncon
+    proto = solver.workspace.bx
+    BatchJVPResult(
+        _zeros_like(proto, nvar_nlp, bs),
+        _zeros_like(proto, n_con, bs),
+        _zeros_like(proto, nvar_nlp, bs),
+        _zeros_like(proto, nvar_nlp, bs),
+        fill!(similar(proto, T, bs), zero(T)),
+    )
+end
+
+"""
+    BatchVJPCache{MT, BPV}
+
+Work buffers for the batch vector-Jacobian product.
+"""
+struct BatchVJPCache{MT, BPV}
+    kkt_rhs::BatchUnreducedKKTVector
+    dL_dx::MT         # (nvar × batch_size)
+    dL_dy::MT         # (m × batch_size)
+    dL_dzl::MT        # (nlb × batch_size)
+    dL_dzu::MT        # (nub × batch_size)
+    dzl_full::BPV     # BatchPrimalVector: work buffer for bound unpacking
+    dzu_full::BPV     # BatchPrimalVector: work buffer for bound unpacking
+    bx::MT            # (nvar_nlp × batch_size)
+    by::MT            # (m × batch_size) y scaled for pullback
+    dy_scaled::MT     # (m × batch_size) adjoint dy scaled for jptprod
+    tmp_p::MT         # (nparam × batch_size)
+    grad_x::MT        # (nvar × batch_size)
+end
+
+"""
+    get_batch_vjp_cache!(sens::BatchMadDiffSolver)
+
+Return `sens.vjp_cache`, allocating it first if it is `nothing`.
+"""
+function get_batch_vjp_cache!(sens::BatchMadDiffSolver{T}) where {T}
+    if isnothing(sens.vjp_cache)
+        solver = sens.solver
+        bcb = solver.bcb
+        nlp = solver.nlp
+        bs = solver.batch_size
+
+        nvar_nlp = nlp.meta.nvar
+        n_con = bcb.ncon
+        n_p = sens.n_p
+        nx = bcb.nvar
+        ns = length(bcb.ind_ineq)
+        n_tot = nx + ns
+        m = bcb.ncon
+        nlb = length(bcb.ind_lb)
+        nub = length(bcb.ind_ub)
+
+        proto = solver.workspace.bx
+        MT = typeof(proto)
+        VT = typeof(similar(proto, T, 0))
+
+        sens.vjp_cache = BatchVJPCache{MT, BatchPrimalVector{T, MT, typeof(bcb.ind_lb)}}(
+            BatchUnreducedKKTVector(MT, VT, n_tot, m, nlb, nub, bs, bcb.ind_lb, bcb.ind_ub),
+            _zeros_like(proto, nx, bs),        # dL_dx
+            _zeros_like(proto, m, bs),         # dL_dy
+            _zeros_like(proto, nlb, bs),       # dL_dzl
+            _zeros_like(proto, nub, bs),       # dL_dzu
+            BatchPrimalVector(MT, VT, nx, ns, bs, bcb.ind_lb, bcb.ind_ub),  # dzl_full
+            BatchPrimalVector(MT, VT, nx, ns, bs, bcb.ind_lb, bcb.ind_ub),  # dzu_full
+            _zeros_like(proto, nvar_nlp, bs),  # bx
+            _zeros_like(proto, n_con, bs),     # by
+            _zeros_like(proto, n_con, bs),     # dy_scaled
+            _zeros_like(proto, n_p, bs),       # tmp_p
+            _zeros_like(proto, nx, bs),        # grad_x
+        )
+    end
+    return sens.vjp_cache
+end
+
+"""
+    BatchVJPResult(sens::BatchMadDiffSolver)
+
+Output container of the batch VJP. This constructor allocates zero-filled
+buffers sized from `sens.solver`.
+"""
+struct BatchVJPResult{MT}
+    dx::MT       # (nvar_nlp × batch_size)
+    dy::MT       # (ncon × batch_size)
+    dzl::MT      # (nvar_nlp × batch_size)
+    dzu::MT      # (nvar_nlp × batch_size)
+    grad_p::MT   # (nparam × batch_size)
+end
+
+function BatchVJPResult(sens::BatchMadDiffSolver{T}) where {T}
+    solver = sens.solver
+    bcb = solver.bcb
+    nlp = solver.nlp
+    bs = solver.batch_size
+    nvar_nlp = nlp.meta.nvar
+    n_con = bcb.ncon
+    proto = solver.workspace.bx
+    BatchVJPResult(
+        _zeros_like(proto, nvar_nlp, bs),
+        _zeros_like(proto, n_con, bs),
+        _zeros_like(proto, nvar_nlp, bs),
+        _zeros_like(proto, nvar_nlp, bs),
+        _zeros_like(proto, sens.n_p, bs),
+    )
+end
+
+# Each predicate compares the matching `nnz*` counter of `meta` with zero.
+has_hess_param(::Union{BatchJVPCache, BatchVJPCache}, meta) = meta.nnzhp != 0
+has_jac_param(::Union{BatchJVPCache, BatchVJPCache}, meta) = meta.nnzjp != 0
+has_lvar_param(::Union{BatchJVPCache, BatchVJPCache}, meta) = meta.nnzjplvar != 0
+has_uvar_param(::Union{BatchJVPCache, BatchVJPCache}, meta) = meta.nnzjpuvar != 0
+has_lcon_param(::Union{BatchJVPCache, BatchVJPCache}, meta) = meta.nnzjplcon != 0
+has_ucon_param(::Union{BatchJVPCache, BatchVJPCache}, meta) = meta.nnzjpucon != 0
+has_grad_param(::Union{BatchJVPCache, BatchVJPCache}, meta) = meta.nnzgp != 0
+
+# Outer constructor: refactorizes the KKT system of `batch_solver` and returns a
+# solver with empty caches.
+function BatchMadDiffSolver(batch_solver::AbstractBatchMPCSolver{T}) where {T}
+    bcb = batch_solver.bcb
+    m = bcb.ncon
+    is_eq = fill(false, m)
+    is_eq[bcb.ind_eq] .= true
+
+    n_p = batch_solver.nlp.meta.nparam
+    MT = typeof(batch_solver.workspace.bx)
+
+    FC = BatchJVPCache{MT, BatchPrimalVector{T, MT, typeof(bcb.ind_lb)}}
+    RC = BatchVJPCache{MT, BatchPrimalVector{T, MT, typeof(bcb.ind_lb)}}
+
+    # FIXME: BatchMadDiffConfig
+    refactorize_kkt!(batch_solver.kkt, batch_solver)
+
+    return BatchMadDiffSolver{T, typeof(batch_solver), FC, RC}(
+        batch_solver, n_p, is_eq, nothing, nothing,
+    )
+end
diff --git a/ext/MadIPMExt/batch_jvp.jl b/ext/MadIPMExt/batch_jvp.jl
new file mode 100644
index 0000000..b1f4928
--- /dev/null
+++ b/ext/MadIPMExt/batch_jvp.jl
@@ -0,0 +1,147 @@
+"""
+    jacobian_vector_product!(sens::BatchMadDiffSolver, Δp::AbstractMatrix)
+    jacobian_vector_product!(result::BatchJVPResult, sens::BatchMadDiffSolver, Δp::AbstractMatrix)
+
+Compute batch JVP of the optimal solution with respect to parameters.
+`Δp` is `(nparam × batch_size)`.
+
+The two-argument form allocates a fresh `BatchJVPResult`; the three-argument
+form overwrites `result` in place.
+
+Returns a `BatchJVPResult` with per-instance solution sensitivities.
+"""
+function MadDiff.jacobian_vector_product!(sens::BatchMadDiffSolver, Δp::AbstractMatrix)
+    return MadDiff.jacobian_vector_product!(BatchJVPResult(sens), sens, Δp)
+end
+
+function MadDiff.jacobian_vector_product!(
+    result::BatchJVPResult, sens::BatchMadDiffSolver{T}, Δp::AbstractMatrix,
+) where {T}
+    solver = sens.solver
+    bcb = solver.bcb
+    nlp = solver.nlp
+    meta = nlp.meta
+    cache = get_batch_jvp_cache!(sens)
+
+    MadNLP.unpack_x!(cache.bx, bcb, solver.x)
+    MadNLP.unpack_y!(cache.by, bcb, MadNLP.full(solver.y))
+
+    bx = cache.bx
+    by = cache.by
+
+    fill!(cache.hpv_nlp, zero(T))
+    fill!(cache.jpv_nlp, zero(T))
+    fill!(cache.dlvar_nlp, zero(T))
+    fill!(cache.duvar_nlp, zero(T))
+    fill!(cache.dlcon_nlp, zero(T))
+    fill!(cache.ducon_nlp, zero(T))
+
+    bobj_sign = vec(bcb.obj_sign)
+    has_hess_param(cache, meta) && hpprod!(nlp, bx, by, Δp, bobj_sign, cache.hpv_nlp)
+    has_jac_param(cache, meta) && jpprod!(nlp, bx, Δp, cache.jpv_nlp)
+    has_lvar_param(cache, meta) && lvar_jpprod!(nlp, Δp, cache.dlvar_nlp)
+    has_uvar_param(cache, meta) && uvar_jpprod!(nlp, Δp, cache.duvar_nlp)
+    has_lcon_param(cache, meta) && lcon_jpprod!(nlp, Δp, cache.dlcon_nlp)
+    has_ucon_param(cache, meta) && ucon_jpprod!(nlp, Δp, cache.ducon_nlp)
+
+    _batch_pack_jvp!(sens, cache)
+    _batch_solve_jvp!(sens)
+    _batch_unpack_jvp!(result, sens, cache)
+    _batch_compute_objective_sensitivity!(result, sens, cache, Δp)
+
+    return result
+end
+
+# Pack the raw `*_nlp` products into the `*_dp` buffers of `cache`.
+function _batch_pack_jvp!(sens::BatchMadDiffSolver{T}, cache::BatchJVPCache) where {T}
+    bcb = sens.solver.bcb
+    ind_ineq = bcb.ind_ineq
+
+    fill!(cache.d2L_dxdp, zero(T))
+    fill!(cache.dg_dp, zero(T))
+    fill!(cache.dlcon_dp, zero(T))
+    fill!(cache.ducon_dp, zero(T))
+    fill!(MadNLP.full(cache.dlvar_dp), zero(T))
+    fill!(MadNLP.full(cache.duvar_dp), zero(T))
+
+    MadDiff.pack_hess!(cache.d2L_dxdp, bcb, cache.hpv_nlp)
+    MadDiff.pack_cons!(cache.dg_dp, bcb, cache.jpv_nlp)
+    MadDiff.pack_cons!(cache.dlcon_dp, bcb, cache.dlcon_nlp)
+    MadDiff.pack_cons!(cache.ducon_dp, bcb, cache.ducon_nlp)
+
+    MadDiff.pack_dx!(MadNLP.variable(cache.dlvar_dp), bcb, cache.dlvar_nlp)
+    MadDiff.pack_dx!(MadNLP.variable(cache.duvar_dp), bcb, cache.duvar_nlp)
+
+    ns = length(ind_ineq)
+    if ns > 0
+        MadDiff.pack_slack!(MadNLP.slack(cache.dlvar_dp), bcb, cache.dlcon_nlp)
+        MadDiff.pack_slack!(MadNLP.slack(cache.duvar_dp), bcb, cache.ducon_nlp)
+    end
+
+    return nothing
+end
+
+# Assemble the right-hand side and solve the KKT system in place.
+function _batch_solve_jvp!(sens::BatchMadDiffSolver{T}) where {T}
+    cache = get_batch_jvp_cache!(sens)
+    w = cache.kkt_rhs
+    _batch_assemble_jvp_rhs!(sens, w, cache)
+    _solve_with_refine!(sens, w, cache)
+    return nothing
+end
+
+# Fill `w` from the packed buffers: the negated Hessian product, the negated constraint
+# product plus the equality-row average of the bound products, and the bound blocks.
+function _batch_assemble_jvp_rhs!(
+    sens::BatchMadDiffSolver{T}, w::BatchUnreducedKKTVector, cache::BatchJVPCache,
+) where {T}
+    bcb = sens.solver.bcb
+    nx = bcb.nvar
+
+    fill!(MadNLP.full(w), zero(T))
+    view(MadNLP.primal(w), 1:nx, :) .= .-cache.d2L_dxdp
+    MadNLP.dual(w) .= .-cache.dg_dp .+ sens.is_eq .* (cache.dlcon_dp .+ cache.ducon_dp) ./ 2
+    MadDiff.jvp_set_bound_rhs!(sens.solver.kkt, w, cache.dlvar_dp, cache.duvar_dp)
+
+    return w
+end
+
+# Map the solved KKT vector back to NLP space and write it into `result`.
+function _batch_unpack_jvp!(
+    result::BatchJVPResult, sens::BatchMadDiffSolver, cache::BatchJVPCache,
+)
+    bcb = sens.solver.bcb
+    w = cache.kkt_rhs
+
+    MadDiff.unpack_dx!(result.dx, bcb, MadNLP.primal(w))
+    MadDiff.set_fixed_sensitivity!(result.dx, bcb, cache.dlvar_nlp, cache.duvar_nlp)
+    MadNLP.unpack_y!(result.dy, bcb, MadNLP.dual(w))
+
+    MadDiff.unpack_dzl!(result.dzl, bcb, MadNLP.dual_lb(w), cache.dlvar_dp)
+    MadDiff.unpack_dzu!(result.dzu, bcb, MadNLP.dual_ub(w), cache.duvar_dp)
+
+    return result
+end
+
+# Per-instance objective sensitivity, written into `result.dobj`.
+function _batch_compute_objective_sensitivity!(
+    result::BatchJVPResult, sens::BatchMadDiffSolver{T}, cache::BatchJVPCache, Δp::AbstractMatrix,
+) where {T}
+    solver = sens.solver
+    nlp = solver.nlp
+    meta = nlp.meta
+    bx = cache.bx
+
+    grad!(nlp, bx, cache.grad_x)
+    if has_grad_param(cache, meta)
+        grad_param!(nlp, bx, cache.grad_p)
+    else
+        fill!(cache.grad_p, zero(T))
+    end
+
+    # dobj[j] = dot(grad_x[:,j], dx[:,j]) + dot(grad_p[:,j], Δp[:,j])
+    result.dobj .= vec(sum(cache.grad_x .* result.dx, dims=1)) .+
+                   vec(sum(cache.grad_p .* Δp, dims=1))
+
+    return nothing
+end
diff --git a/ext/MadIPMExt/batch_kkt.jl b/ext/MadIPMExt/batch_kkt.jl
new file mode 100644
index 0000000..18c03e8
--- /dev/null
+++ b/ext/MadIPMExt/batch_kkt.jl
@@ -0,0 +1,82 @@
+# Converged instances are marked inactive (batch_map=0); reset the map to the identity.
+# FIXME: filter by termination status
+function _reset_batch_map!(bkkt::SparseUniformBatchKKTSystem)
+    bs = bkkt.batch_size
+    for i in 1:bs
+        bkkt.batch_map[i] = i
+        bkkt.batch_map_rev[i] = i
+    end
+    bkkt.active_batch_size[] = bs
+    return
+end
+
+# Reset the batch map, then re-evaluate the Jacobian and Lagrangian Hessian, rebuild the
+# KKT system and factorize it.
+function refactorize_kkt!(kkt::SparseUniformBatchKKTSystem, solver::AbstractBatchMPCSolver)
+    _reset_batch_map!(kkt)
+    MadIPM.set_aug_diagonal_reg!(kkt, solver)
+    MadNLP.eval_jac_wrapper!(solver, kkt)
+    MadNLP.eval_lag_hess_wrapper!(solver, kkt)
+    MadNLP.build_kkt!(kkt)
+    MadNLP.factorize_kkt!(kkt)
+    return nothing
+end
+
+# Forward solve: reduce the RHS, solve the primal-dual block through `rhs_buffer`, finish.
+function MadNLP.solve_kkt!(bkkt::SparseUniformBatchKKTSystem, d::BatchUnreducedKKTVector)
+    MadNLP.reduce_rhs!(bkkt, d)
+    rhs = bkkt.rhs_buffer
+    pd = MadNLP.primal_dual(d)
+    copyto!(rhs, pd)
+    MadNLP.solve_linear_system!(bkkt, rhs)
+    copyto!(pd, rhs)
+    MadNLP.finish_aug_solve!(bkkt, d)
+    return d
+end
+
+# Adjoint solve: same linear system, with the adjoint bound step applied before the solve
+# and the adjoint reduce step applied after it.
+function adjoint_solve_kkt!(bkkt::SparseUniformBatchKKTSystem, d::BatchUnreducedKKTVector)
+    lb_off = d.n + d.m
+    MadIPM._adjoint_finish_bounds_batch!(
+        d.values, d.ind_lb, lb_off, bkkt.l_lower, bkkt.l_diag,
+        d.ind_ub, lb_off + d.nlb, bkkt.u_lower, bkkt.u_diag,
+    )
+    rhs = bkkt.rhs_buffer
+    pd = MadNLP.primal_dual(d)
+    copyto!(rhs, pd)
+    MadNLP.solve_linear_system!(bkkt, rhs)
+    copyto!(pd, rhs)
+    MadIPM._adjoint_reduce_rhs_batch!(
+        d.values, d.ind_lb, lb_off, bkkt.l_diag,
+        d.ind_ub, lb_off + d.nlb, bkkt.u_diag,
+    )
+    return d
+end
+
+# Bound blocks of the JVP right-hand side: `dual_lb = lower(dlvar_dp)`, `dual_ub = -upper(duvar_dp)`.
+function MadDiff.jvp_set_bound_rhs!(::SparseUniformBatchKKTSystem, w::BatchUnreducedKKTVector, dlvar_dp::BatchPrimalVector, duvar_dp::BatchPrimalVector)
+    MadNLP.dual_lb(w) .= lower(dlvar_dp)
+    MadNLP.dual_ub(w) .= .-upper(duvar_dp)
+    return nothing
+end
+
+# Copy the bound duals of `w` into zeroed primal vectors (`pvu` gets the negated upper duals).
+function MadDiff.vjp_fill_pv!(::SparseUniformBatchKKTSystem, pvl::BatchPrimalVector, pvu::BatchPrimalVector, w::BatchUnreducedKKTVector)
+    fill!(MadNLP.full(pvl), zero(eltype(MadNLP.full(pvl))))
+    fill!(MadNLP.full(pvu), zero(eltype(MadNLP.full(pvu))))
+    lower(pvl) .= MadNLP.dual_lb(w)
+    upper(pvu) .= .-MadNLP.dual_ub(w)
+    return nothing
+end
+
+# The batch solver performs plain (adjoint) KKT solves here; `cache` is unused.
+function _solve_with_refine!(sens::BatchMadDiffSolver{T}, w, cache) where {T}
+    MadNLP.solve_kkt!(sens.solver.kkt, w)
+    return nothing
+end
+
+function _adjoint_solve_with_refine!(sens::BatchMadDiffSolver{T}, w, cache) where {T}
+    adjoint_solve_kkt!(sens.solver.kkt, w)
+    return nothing
+end
diff --git a/ext/MadIPMExt/batch_packing.jl b/ext/MadIPMExt/batch_packing.jl
new file mode 100644
index 0000000..8a0b701
--- /dev/null
+++ b/ext/MadIPMExt/batch_packing.jl
@@ -0,0 +1,52 @@
+# Pack like `pack_dx!`, then multiply by `bcb.obj_scale`.
+function MadDiff.pack_hess!(x::AbstractMatrix, bcb::UniformBatchCallback, x_full::AbstractMatrix)
+    MadDiff.pack_dx!(x, bcb, x_full)
+    x .*= bcb.obj_scale
+    return nothing
+end
+
+# Multiply constraint-space values by `bcb.con_scale`.
+function MadDiff.pack_cons!(c::AbstractMatrix, bcb::UniformBatchCallback, c_full::AbstractMatrix)
+    c .= c_full .* bcb.con_scale
+    return nothing
+end
+
+# Divide bound-dual values by `bcb.obj_scale`.
+function MadDiff.pack_z!(z::AbstractMatrix, bcb::UniformBatchCallback, z_full::AbstractMatrix)
+    z .= z_full ./ bcb.obj_scale
+    return nothing
+end
+
+# Scale constraint-dual seeds by `obj_sign / obj_scale` and by `con_scale`.
+function MadDiff.pack_dy!(y::AbstractMatrix, bcb::UniformBatchCallback, y_full::AbstractMatrix)
+    y .= (y_full .* (bcb.obj_sign ./ bcb.obj_scale)) .* bcb.con_scale
+    return nothing
+end
+
+# Place `rhs` in the lower-bound entries of the zeroed `pv`, unpack, divide by `obj_scale`.
+function MadDiff.unpack_dzl!(dz::AbstractMatrix, bcb::UniformBatchCallback, rhs::AbstractMatrix, pv::BatchPrimalVector)
+    fill!(MadNLP.full(pv), zero(eltype(MadNLP.full(pv))))
+    lower(pv) .= rhs
+    MadDiff.unpack_dx!(dz, bcb, MadNLP.variable(pv))
+    dz ./= bcb.obj_scale
+    return nothing
+end
+
+# Same as `unpack_dzl!`, using the upper-bound entries of `pv`.
+function MadDiff.unpack_dzu!(dz::AbstractMatrix, bcb::UniformBatchCallback, rhs::AbstractMatrix, pv::BatchPrimalVector)
+    fill!(MadNLP.full(pv), zero(eltype(MadNLP.full(pv))))
+    upper(pv) .= rhs
+    MadDiff.unpack_dx!(dz, bcb, MadNLP.variable(pv))
+    dz ./= bcb.obj_scale
+    return nothing
+end
+
+# Equality rows take half of `dy`; inequality rows additionally add the slack part of `dz`.
+function MadDiff.unpack_slack!(out::AbstractMatrix, bcb::UniformBatchCallback, dz::BatchPrimalVector, is_eq, dy::AbstractMatrix)
+    out .= (is_eq .* dy ./ 2) .* bcb.con_scale
+    ns = length(bcb.ind_ineq)
+    if ns > 0
+        out[bcb.ind_ineq, :] .+= MadNLP.slack(dz) .* bcb.con_scale[bcb.ind_ineq, :]
+    end
+    return nothing
+end
diff --git a/ext/MadIPMExt/batch_vjp.jl b/ext/MadIPMExt/batch_vjp.jl
new file mode 100644
index 0000000..adbe2ab
--- /dev/null
+++ b/ext/MadIPMExt/batch_vjp.jl
@@ -0,0 +1,175 @@
+"""
+    vector_jacobian_product!(sens::BatchMadDiffSolver; dL_dx, dL_dy, dL_dzl, dL_dzu, dobj)
+    vector_jacobian_product!(result::BatchVJPResult, sens::BatchMadDiffSolver; kwargs...)
+
+Compute batch VJP to backpropagate a per-instance loss through the optimal solution
+with respect to the parameters.
+
+Keyword arguments are matrices `(dim × batch_size)` providing loss sensitivities,
+except `dobj` which is a vector of length `batch_size`. At least one must be provided;
+otherwise an `ArgumentError` is thrown.
+
+Returns a `BatchVJPResult` containing per-instance parameter gradient `grad_p`.
+"""
+function MadDiff.vector_jacobian_product!(
+    sens::BatchMadDiffSolver;
+    dL_dx = nothing, dL_dy = nothing, dL_dzl = nothing, dL_dzu = nothing, dobj = nothing,
+)
+    return MadDiff.vector_jacobian_product!(
+        BatchVJPResult(sens), sens;
+        dL_dx, dL_dy, dL_dzl, dL_dzu, dobj,
+    )
+end
+
+function MadDiff.vector_jacobian_product!(
+    result::BatchVJPResult, sens::BatchMadDiffSolver{T};
+    dL_dx = nothing, dL_dy = nothing, dL_dzl = nothing, dL_dzu = nothing, dobj = nothing,
+) where {T}
+    all(isnothing, (dL_dx, dL_dy, dL_dzl, dL_dzu, dobj)) &&
+        throw(ArgumentError("At least one of dL_dx, dL_dy, dL_dzl, dL_dzu, dobj must be provided"))
+
+    _batch_pack_vjp!(sens; dL_dx, dL_dy, dL_dzl, dL_dzu, dobj)
+    _batch_solve_vjp!(sens)
+    _batch_unpack_vjp!(result, sens)
+    _batch_vjp_pullback!(result, sens; dobj)
+    return result
+end
+
+# Pack the provided seeds into the cache; a `dobj` seed is folded into `dL_dx`.
+function _batch_pack_vjp!(
+    sens::BatchMadDiffSolver{T};
+    dL_dx = nothing, dL_dy = nothing, dL_dzl = nothing, dL_dzu = nothing, dobj = nothing,
+) where {T}
+    cache = get_batch_vjp_cache!(sens)
+    bcb = sens.solver.bcb
+
+    fill!(cache.dL_dx, zero(T))
+    fill!(cache.dL_dy, zero(T))
+    fill!(cache.dL_dzl, zero(T))
+    fill!(cache.dL_dzu, zero(T))
+    fill!(MadNLP.full(cache.dzl_full), zero(T))
+    fill!(MadNLP.full(cache.dzu_full), zero(T))
+
+    isnothing(dL_dx) || MadDiff.pack_dx!(cache.dL_dx, bcb, dL_dx)
+    isnothing(dL_dy) || MadDiff.pack_dy!(cache.dL_dy, bcb, dL_dy)
+    isnothing(dL_dzl) || MadDiff.pack_dzl!(cache.dL_dzl, bcb, dL_dzl, cache.dzl_full)
+    isnothing(dL_dzu) || MadDiff.pack_dzu!(cache.dL_dzu, bcb, dL_dzu, cache.dzu_full)
+
+    if !isnothing(dobj)
+        MadNLP.unpack_x!(cache.bx, bcb, sens.solver.x)
+        MadNLP._eval_grad_f_wrapper!(bcb, cache.bx, cache.grad_x)
+        cache.dL_dx .+= cache.grad_x .* reshape(dobj, 1, :)
+    end
+
+    return nothing
+end
+
+# Load the packed seeds into the KKT vector and run the adjoint solve in place.
+function _batch_solve_vjp!(sens::BatchMadDiffSolver{T}) where {T}
+    cache = get_batch_vjp_cache!(sens)
+    w = cache.kkt_rhs
+    nx = size(cache.dL_dx, 1)
+
+    fill!(MadNLP.full(w), zero(T))
+    view(MadNLP.primal(w), 1:nx, :) .= cache.dL_dx
+    MadNLP.dual(w) .= cache.dL_dy
+    MadNLP.dual_lb(w) .= cache.dL_dzl
+    MadNLP.dual_ub(w) .= cache.dL_dzu
+
+    _adjoint_solve_with_refine!(sens, w, cache)
+    return nothing
+end
+
+# Map the adjoint KKT solution back to NLP space and write it into `result`.
+function _batch_unpack_vjp!(result::BatchVJPResult, sens::BatchMadDiffSolver)
+    cache = get_batch_vjp_cache!(sens)
+    bcb = sens.solver.bcb
+    w = cache.kkt_rhs
+
+    MadDiff.unpack_dx!(result.dx, bcb, MadNLP.primal(w))
+    MadNLP.unpack_y!(result.dy, bcb, MadNLP.dual(w))
+    MadDiff.unpack_dzl!(result.dzl, bcb, MadNLP.dual_lb(w), cache.dzl_full)
+    MadDiff.unpack_dzu!(result.dzu, bcb, MadNLP.dual_ub(w), cache.dzu_full)
+
+    return result
+end
+
+# Accumulate `result.grad_p` from the transposed parametric products. `cache.bx` and
+# `cache.by` are reused as scratch once the terms that need `x` and `y` are done.
+function _batch_vjp_pullback!(
+    result::BatchVJPResult, sens::BatchMadDiffSolver{T};
+    dobj = nothing,
+) where {T}
+    solver = sens.solver
+    nlp = solver.nlp
+    meta = nlp.meta
+    bcb = solver.bcb
+    cache = get_batch_vjp_cache!(sens)
+    w = cache.kkt_rhs
+    bx = cache.bx
+    by = cache.by
+    dx = result.dx
+    dy = cache.dy_scaled
+    pvl = cache.dzl_full
+    pvu = cache.dzu_full
+    tmp = cache.tmp_p
+
+    bobj_scale = vec(bcb.obj_scale)
+    bobj_sign = vec(bcb.obj_sign)
+    bσ_scaled = bobj_sign .* bobj_scale
+
+    grad_p = result.grad_p
+    fill!(grad_p, zero(T))
+
+    MadNLP.unpack_x!(bx, bcb, solver.x)
+
+    if has_hess_param(cache, meta)
+        MadNLP.unpack_y!(by, bcb, MadNLP.full(solver.y))
+        by .*= reshape(bσ_scaled, 1, :)
+        hptprod!(nlp, bx, by, dx, bσ_scaled, grad_p)
+    end
+
+    if !isnothing(dobj) && has_grad_param(cache, meta)
+        grad_param!(nlp, bx, tmp)
+        # grad_p -= dobj .* grad_param
+        tmp .*= reshape(dobj, 1, :)
+        grad_p .-= tmp
+    end
+
+    if has_jac_param(cache, meta)
+        dy .= result.dy .* reshape(bσ_scaled, 1, :)
+        jptprod!(nlp, bx, dy, tmp)
+        grad_p .+= tmp
+    end
+
+    MadDiff.vjp_fill_pv!(solver.kkt, pvl, pvu, w)
+
+    if has_lvar_param(cache, meta)
+        MadDiff.unpack_dx!(bx, bcb, MadNLP.variable(pvl))
+        lvar_jptprod!(nlp, bx, tmp)
+        grad_p .-= tmp
+    end
+
+    if has_uvar_param(cache, meta)
+        MadDiff.unpack_dx!(bx, bcb, MadNLP.variable(pvu))
+        uvar_jptprod!(nlp, bx, tmp)
+        grad_p .+= tmp
+    end
+
+    if has_lcon_param(cache, meta)
+        MadDiff.unpack_slack!(by, bcb, pvl, sens.is_eq, MadNLP.dual(w))
+        lcon_jptprod!(nlp, by, tmp)
+        grad_p .-= tmp
+    end
+
+    if has_ucon_param(cache, meta)
+        MadDiff.unpack_slack!(by, bcb, pvu, sens.is_eq, MadNLP.dual(w))
+        ucon_jptprod!(nlp, by, tmp)
+        grad_p .+= tmp
+    end
+
+    grad_p .*= -one(T)
+
+    return result
+end
+
diff --git a/test/Project.toml b/test/Project.toml
index 335fbbb..a9bc3c4 100644
--- a/test/Project.toml
+++ b/test/Project.toml
@@ -17,6 +17,7 @@ NLPModels = "a4795742-8479-5a88-8948-cc11e1c8c1a6"
 NLPModelsJuMP = "792afdf1-32c1-5681-94e0-d7bf7a5df49e"
 QuadraticModels = "f468eda6-eac5-11e8-05a5-ff9e497bcd19"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
+SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 
 [sources]
@@ -25,7 +26,9 @@ ExaModels = {rev = "mk/param_ad", url = "https://github.com/klamike/ExaModels.jl
 HybridKKT = {rev = "master", url = "https://github.com/MadNLP/HybridKKT.jl.git"}
 MadDiff = {path = ".."}
 MadNLP = {rev = "mk/moi_param", url = "https://github.com/klamike/MadNLP.jl.git"}
-NLPModels = {rev = "mk/param_api", url = "https://github.com/klamike/NLPModels.jl"}
+MadIPM = {rev = "mk/batchperf", url = "https://github.com/klamike/MadIPM.jl"}
+NLPModels = {rev = "mk/paramnlp", url = "https://github.com/klamike/NLPModels.jl"}
+QuadraticModels = {rev = "mk/rhsbatch", url = "https://github.com/klamike/QuadraticModels.jl"}
 
 [compat]
 DiffOpt = "0.5.5"
diff --git a/test/runtests.jl b/test/runtests.jl
index 430919e..1720f46 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -274,6 +274,8 @@ end
     end
 end
 
+include("test_batch_diff.jl")
+
 @testset "MOI attributes" begin
     model = MadDiff.diff_model(MadNLP.Optimizer)
     set_silent(model)
diff --git a/test/test_batch_diff.jl b/test/test_batch_diff.jl
new file mode 100644
index 0000000..497e769
--- /dev/null
+++ b/test/test_batch_diff.jl
@@ -0,0 +1,86 @@
+using Test, LinearAlgebra, SparseArrays
+using MadDiff, MadNLP, MadIPM, NLPModels, QuadraticModels
+
+const BatchMadDiffSolver = Base.get_extension(MadDiff, :MadIPMExt).BatchMadDiffSolver
+
+@testset "Batch JVP/VJP" begin
+    T = Float64
+    # Simple parametric QP:
+    #   min  ½x₁² + ½x₂² + (F*θ)ᵀx
+    #   s.t. x₁ + x₂ ≥ 1 + B*θ
+    #        0 ≤ x₁, x₂ ≤ 10
+    nvar = 2; ncon = 1; nparam = 1; nbatch = 3
+
+    c = zeros(T, nvar)
+    H = sparse([1, 2], [1, 2], [1.0, 1.0], nvar, nvar)
+    A = sparse([1, 1], [1, 2], [1.0, 1.0], ncon, nvar)
+    F = [1.0; 0.0;;]  # nvar × nparam
+    B = [1.0;;]       # ncon × nparam
+
+    # Different parameter values per instance
+    θ_vals = [0.0, 0.1, 0.2]
+
+    # Build sequential models
+    lpqps = [LinearParametricQuadraticModel(
+        c, H, A, F, B, [θ];
+        lvar = [0.0, 0.0], uvar = [10.0, 10.0],
+        lcon = [1.0], ucon = [Inf],
+    ) for θ in θ_vals]
+
+    # Build batch model from sequential models
+    bqp = BatchLinearParametricQuadraticModel(lpqps)
+
+    # Solve batch
+    batch_solver = MadIPM.UniformBatchMPCSolver(bqp; tol=1e-8, max_iter=200)
+    MadIPM.solve!(batch_solver)
+
+    # Build batch diff solver
+    bsens = BatchMadDiffSolver(batch_solver)
+
+    # Parameter perturbation
+    Δp = ones(T, nparam, nbatch)
+
+    @testset "Batch JVP" begin
+        result = MadDiff.jacobian_vector_product!(bsens, Δp)
+
+        for j in 1:nbatch
+            seq_solver = MadIPM.MPCSolver(lpqps[j]; print_level=MadNLP.ERROR, tol=1e-8, max_iter=200)
+            MadIPM.solve!(seq_solver)
+            seq_sens = MadDiffSolver(seq_solver)
+            seq_result = MadDiff.jacobian_vector_product!(seq_sens, Δp[:, j])
+
+            @test result.dx[:, j] ≈ seq_result.dx atol=1e-6
+            @test result.dy[:, j] ≈ seq_result.dy atol=1e-6
+            @test result.dobj[j] ≈ seq_result.dobj[] atol=1e-6
+        end
+    end
+
+    @testset "Batch VJP" begin
+        dobj = ones(T, nbatch)
+        result = MadDiff.vector_jacobian_product!(bsens; dobj = dobj)
+
+        for j in 1:nbatch
+            seq_solver = MadIPM.MPCSolver(lpqps[j]; print_level=MadNLP.ERROR, tol=1e-8, max_iter=200)
+            MadIPM.solve!(seq_solver)
+            seq_sens = MadDiffSolver(seq_solver)
+            seq_result = MadDiff.vector_jacobian_product!(seq_sens; dobj = dobj[j])
+
+            @test result.grad_p[:, j] ≈ seq_result.grad_p atol=1e-6
+        end
+    end
+
+    @testset "Batch VJP with dL_dx" begin
+        nvar_nlp = bqp.meta.nvar
+        dL_dx = ones(T, nvar_nlp, nbatch)
+        result = MadDiff.vector_jacobian_product!(bsens; dL_dx = dL_dx)
+
+        for j in 1:nbatch
+            seq_solver = MadIPM.MPCSolver(lpqps[j]; print_level=MadNLP.ERROR, tol=1e-8, max_iter=200)
+            MadIPM.solve!(seq_solver)
+            seq_sens = MadDiffSolver(seq_solver)
+            seq_result = MadDiff.vector_jacobian_product!(seq_sens; dL_dx = dL_dx[:, j])
+
+            @test result.grad_p[:, j] ≈ seq_result.grad_p atol=1e-6
+        end
+    end
+end
diff --git a/test/test_batch_diff_cuda.jl b/test/test_batch_diff_cuda.jl
new file mode 100644
index 0000000..2b29a01
--- /dev/null
+++ b/test/test_batch_diff_cuda.jl
@@ -0,0 +1,118 @@
+using Test, LinearAlgebra, SparseArrays
+using MadDiff, MadNLP, MadIPM, NLPModels, QuadraticModels
+using CUDA
+
+if !CUDA.functional()
+    @info "CUDA not functional, skipping batch CUDA diff tests"
+else
+
+include("test_batch_diff.jl")  # defines `BatchMadDiffSolver`; also runs the CPU testset
+# FIXME(review): `TestBatchParametricQP` is used below but is not defined in
+# test_batch_diff.jl as it stands in this patch; define or import it before enabling this file.
+
+# Build a CuArray-backed batch model from the CPU one
+function gpu_batch_parametric_qp(c, H, A, F, B, θ_batch; lvar, uvar, lcon, ucon)
+    T = eltype(c)
+    nvar = length(c)
+    ncon = size(A, 1)
+    nparam = size(F, 2)
+    nbatch = size(θ_batch, 2)
+    MT = CuMatrix{T}
+
+    meta = NLPModels.BatchNLPModelMeta{T, MT}(
+        nbatch, nvar;
+        ncon = ncon,
+        lvar = CuMatrix(lvar), uvar = CuMatrix(uvar),
+        lcon = CuMatrix(lcon), ucon = CuMatrix(ucon),
+        nnzj = nnz(A),
+        nnzh = nnz(H),
+        islp = (nnz(H) == 0),
+        nparam = nparam,
+        nnzgp = nparam,
+        nnzjp = length(B),
+        nnzhp = length(F),
+        grad_param_available = true,
+        jac_param_available = ncon > 0,
+        hess_param_available = true,
+        jpprod_available = ncon > 0,
+        jptprod_available = ncon > 0,
+        hpprod_available = true,
+        hptprod_available = true,
+    )
+
+    c_eff = CuMatrix(repeat(c, 1, nbatch) .+ F * θ_batch)
+    Bθ = CuMatrix(B * θ_batch)
+    _HX = CUDA.zeros(T, nvar, nbatch)
+
+    return TestBatchParametricQP{T, MT}(
+        meta, CuSparseMatrixCSC(H), CuSparseMatrixCSC(A),
+        F, B,  # F, B stay on CPU (used via mul! with CuMatrix)
+        CuMatrix(θ_batch), copy(c), c_eff, Bθ, _HX,
+    )
+end
+
+@testset "Batch JVP/VJP CUDA" begin
+    T = Float64
+    nvar = 2; ncon = 1; nparam = 1; nbatch = 3
+
+    c = zeros(T, nvar)
+    H = sparse([1, 2], [1, 2], [1.0, 1.0], nvar, nvar)
+    A = sparse([1, 1], [1, 2], [1.0, 1.0], ncon, nvar)
+    F = [1.0; 0.0;;]
+    B = [1.0;;]
+    θ_batch = [0.0 0.5 1.0]
+
+    lcon = fill(1.0, ncon, nbatch) .+ B * θ_batch
+    ucon = fill(T(Inf), ncon, nbatch)
+    lvar = fill(0.0, nvar, nbatch)
+    uvar = fill(10.0, nvar, nbatch)
+
+    # CPU reference
+    bqp_cpu = TestBatchParametricQP(c, H, A, F, B, θ_batch;
+        lvar = lvar, uvar = uvar, lcon = lcon, ucon = ucon)
+    batch_solver_cpu = MadIPM.UniformBatchMPCSolver(bqp_cpu)
+    MadIPM.solve!(batch_solver_cpu)
+    bsens_cpu = BatchMadDiffSolver(batch_solver_cpu)
+
+    Δp_cpu = ones(T, nparam, nbatch)
+    jvp_cpu = MadDiff.jacobian_vector_product!(bsens_cpu, Δp_cpu)
+
+    # GPU solve
+    bqp_gpu = gpu_batch_parametric_qp(c, H, A, F, B, θ_batch;
+        lvar = lvar, uvar = uvar, lcon = lcon, ucon = ucon)
+    batch_solver_gpu = MadIPM.UniformBatchMPCSolver(bqp_gpu)
+    MadIPM.solve!(batch_solver_gpu)
+    bsens_gpu = BatchMadDiffSolver(batch_solver_gpu)
+
+    @testset "JVP" begin
+        Δp_gpu = CuMatrix(Δp_cpu)
+        jvp_gpu = MadDiff.jacobian_vector_product!(bsens_gpu, Δp_gpu)
+
+        @test Array(jvp_gpu.dx) ≈ jvp_cpu.dx atol=1e-5
+        @test Array(jvp_gpu.dy) ≈ jvp_cpu.dy atol=1e-5
+        @test Array(jvp_gpu.dobj) ≈ jvp_cpu.dobj atol=1e-5
+    end
+
+    @testset "VJP with dobj" begin
+        dobj_cpu = ones(T, nbatch)
+        vjp_cpu = MadDiff.vector_jacobian_product!(bsens_cpu; dobj = dobj_cpu)
+
+        dobj_gpu = CuVector(dobj_cpu)
+        vjp_gpu = MadDiff.vector_jacobian_product!(bsens_gpu; dobj = dobj_gpu)
+
+        @test Array(vjp_gpu.grad_p) ≈ vjp_cpu.grad_p atol=1e-5
+    end
+
+    @testset "VJP with dL_dx" begin
+        nvar_nlp = bqp_cpu.meta.nvar
+        dL_dx_cpu = ones(T, nvar_nlp, nbatch)
+        vjp_cpu = MadDiff.vector_jacobian_product!(bsens_cpu; dL_dx = dL_dx_cpu)
+
+        dL_dx_gpu = CuMatrix(dL_dx_cpu)
+        vjp_gpu = MadDiff.vector_jacobian_product!(bsens_gpu; dL_dx = dL_dx_gpu)
+
+        @test Array(vjp_gpu.grad_p) ≈ vjp_cpu.grad_p atol=1e-5
+    end
+end
+
+end # if CUDA.functional()