Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 47 additions & 53 deletions src/controller/nonlinmpc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -729,73 +729,69 @@ function get_optim_functions(mpc::NonLinMPC, ::JuMP.GenericModel{JNT}) where JNT
end

# TODO: move docstring of method above here an re-work it
function get_nonlinops(mpc::NonLinMPC, optim::JuMP.GenericModel{JNT}) where JNT<:Real
function get_nonlinops(mpc::NonLinMPC, ::JuMP.GenericModel{JNT}) where JNT<:Real
# ----------- common cache for all functions ----------------------------------------
model = mpc.estim.model
transcription = mpc.transcription
grad, jac = mpc.gradient, mpc.jacobian
nu, ny, nx̂, nϵ = model.nu, model.ny, mpc.estim.nx̂, mpc.nϵ
nk = get_nk(model, transcription)
Hp, Hc = mpc.Hp, mpc.Hc
ng, nc, neq = length(mpc.con.i_g), mpc.con.nc, mpc.con.neq
i_g = findall(mpc.con.i_g) # convert to non-logical indices for non-allocating @views
ng, ngi = length(mpc.con.i_g), sum(mpc.con.i_g)
nc, neq = mpc.con.nc, mpc.con.neq
nZ̃, nU, nŶ, nX̂, nK = length(mpc.Z̃), Hp*nu, Hp*ny, Hp*nx̂, Hp*nk
nΔŨ, nUe, nŶe = nu*Hc + nϵ, nU + nu, nŶ + ny
strict = Val(true)
myNaN, myInf = convert(JNT, NaN), convert(JNT, Inf)
J::Vector{JNT} = zeros(JNT, 1)
ΔŨ::Vector{JNT} = zeros(JNT, nΔŨ)
x̂0end::Vector{JNT} = zeros(JNT, nx̂)
K0::Vector{JNT} = zeros(JNT, nK)
Ue::Vector{JNT}, Ŷe::Vector{JNT} = zeros(JNT, nUe), zeros(JNT, nŶe)
U0::Vector{JNT}, Ŷ0::Vector{JNT} = zeros(JNT, nU), zeros(JNT, nŶ)
Û0::Vector{JNT}, X̂0::Vector{JNT} = zeros(JNT, nU), zeros(JNT, nX̂)
gc::Vector{JNT}, g::Vector{JNT} = zeros(JNT, nc), zeros(JNT, ng)
geq::Vector{JNT} = zeros(JNT, neq)
myNaN, myInf = convert(JNT, NaN), convert(JNT, Inf)
J::Vector{JNT} = zeros(JNT, 1)
ΔŨ::Vector{JNT} = zeros(JNT, nΔŨ)
x̂0end::Vector{JNT} = zeros(JNT, nx̂)
K0::Vector{JNT} = zeros(JNT, nK)
Ue::Vector{JNT}, Ŷe::Vector{JNT} = zeros(JNT, nUe), zeros(JNT, nŶe)
U0::Vector{JNT}, Ŷ0::Vector{JNT} = zeros(JNT, nU), zeros(JNT, nŶ)
Û0::Vector{JNT}, X̂0::Vector{JNT} = zeros(JNT, nU), zeros(JNT, nX̂)
gc::Vector{JNT}, g::Vector{JNT} = zeros(JNT, nc), zeros(JNT, ng)
gi::Vector{JNT}, geq::Vector{JNT} = zeros(JNT, ngi), zeros(JNT, neq)
# -------------- inequality constraint: nonlinear oracle -----------------------------
function g!(g, Z̃, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, K0, X̂0, gc, geq)
function gi!(gi, Z̃, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, K0, X̂0, gc, geq, g)
update_predictions!(ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, K0, X̂0, gc, g, geq, mpc, Z̃)
gi .= @views g[i_g]
return nothing
end
Z̃_∇g = fill(myNaN, nZ̃) # NaN to force update_predictions! at first call
g_context = (
Z̃_∇gi = fill(myNaN, nZ̃) # NaN to force update_predictions! at first call
gi_context = (
Cache(ΔŨ), Cache(x̂0end), Cache(Ue), Cache(Ŷe), Cache(U0), Cache(Ŷ0),
Cache(Û0), Cache(K0), Cache(X̂0),
Cache(gc), Cache(geq),
Cache(gc), Cache(geq), Cache(g)
)
## temporarily enable all the inequality constraints for sparsity detection:
# mpc.con.i_g[1:end-nc] .= true
∇g_prep = prepare_jacobian(g!, g, jac, Z̃_∇g, ∇g_context...; strict)
# mpc.con.i_g[1:end-nc] .= false
∇g = init_diffmat(JNT, jac, ∇g_prep, nZ̃, ng)
function update_con!(g, ∇g, Z̃_∇g, Z̃_arg)
if isdifferent(Z̃_arg, Z̃_∇g)
Z̃_∇g .= Z̃_arg
value_and_jacobian!(g!, g, ∇g, ∇g_prep, jac, Z̃_∇g, ∇g_context...)
∇gi_prep = prepare_jacobian(gi!, gi, jac, Z̃_∇gi, ∇gi_context...; strict)
∇gi = init_diffmat(JNT, jac, ∇gi_prep, nZ̃, ngi)
function update_con!(gi, ∇gi, Z̃_∇gi, Z̃_arg)
if isdifferent(Z̃_arg, Z̃_∇gi)
Z̃_∇gi .= Z̃_arg
value_and_jacobian!(gi!, gi, ∇gi, ∇gi_prep, jac, Z̃_∇gi, ∇gi_context...)
end
return nothing
end
function gfunc_oracle!(g_arg, Z̃_arg)
update_con!(g, ∇g, Z̃_∇g, Z̃_arg)
g_arg .= @views g[mpc.con.i_g]
return nothing
function gi_func!(gi_vec, Z̃_arg)
update_con!(gi, ∇gi, Z̃_∇gi, Z̃_arg)
return gi_vec .= gi
end
∇g_i_g = ∇g[mpc.con.i_g, :]
function ∇gfunc_oracle!(∇g_arg, Z̃_arg)
update_con!(g, ∇g, Z̃_∇g, Z̃_arg)
∇g_i_g .= @views ∇g[mpc.con.i_g, :]
diffmat2vec!(∇g_arg, ∇g_i_g)
return nothing
function ∇gi_func!(∇gi_vec, Z̃_arg)
update_con!(gi, ∇gi, Z̃_∇gi, Z̃_arg)
return diffmat2vec!(∇gi_vec, ∇gi)
end
g_min = fill(-myInf, sum(mpc.con.i_g))
g_max = zeros(JNT, sum(mpc.con.i_g))
g_structure = init_diffstructure(∇g[mpc.con.i_g, :])
gi_min = fill(-myInf, ngi)
gi_max = zeros(JNT, ngi)
gi_structure = init_diffstructure(∇gi)
g_oracle = Ipopt._VectorNonlinearOracle(;
dimension = nZ̃,
l = g_min,
u = g_max,
eval_f = gfunc_oracle!,
jacobian_structure = ∇g_structure,
eval_jacobian = ∇gfunc_oracle!
l = gi_min,
u = gi_max,
eval_f = gi_func!,
jacobian_structure = ∇gi_structure,
eval_jacobian = ∇gi_func!
)
# ------------- equality constraints : nonlinear oracle ------------------------------
function geq!(geq, Z̃, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, K0, X̂0, gc, g)
Expand All @@ -817,25 +813,23 @@ function get_nonlinops(mpc::NonLinMPC, optim::JuMP.GenericModel{JNT}) where JNT<
end
return nothing
end
function geq_oracle!(geq_arg, Z̃_arg)
function geq_func!(geq_vec, Z̃_arg)
update_con_eq!(geq, ∇geq, Z̃_∇geq, Z̃_arg)
geq_arg .= geq
return nothing
return geq_vec .= geq
end
function ∇geq_oracle!(∇geq_arg, Z̃_arg)
function ∇geq_func!(∇geq_vec, Z̃_arg)
update_con_eq!(geq, ∇geq, Z̃_∇geq, Z̃_arg)
diffmat2vec!(∇geq_arg, ∇geq)
return nothing
return diffmat2vec!(∇geq_vec, ∇geq)
end
geq_min = geq_max = zeros(JNT, neq)
∇geq_structure = init_diffstructure(∇geq)
geq_oracle = Ipopt._VectorNonlinearOracle(;
dimension = nZ̃,
l = geq_min,
u = geq_max,
eval_f = geq_oracle!,
eval_f = geq_func!,
jacobian_structure = ∇geq_structure,
eval_jacobian = ∇geq_oracle!
eval_jacobian = ∇geq_func!
)
# ------------- objective function: splatting syntax ---------------------------------
function J!(Z̃, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, K0, X̂0, gc, g, geq)
Expand Down Expand Up @@ -909,8 +903,8 @@ function set_nonlincon_exp!(
optim, JuMP.Vector{JuMP.VariableRef}, Ipopt._VectorNonlinearOracle
)
map(con_ref -> JuMP.delete(optim, con_ref), nonlin_constraints)
@constraint(optim, Z̃var in g_oracle)
mpc.con.neq > 0 && @constraint(optim, Z̃var in geq_oracle)
any(mpc.con.i_g) && @constraint(optim, Z̃var in g_oracle)
mpc.con.neq > 0 && @constraint(optim, Z̃var in geq_oracle)
return nothing
end

Expand Down