Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
208 changes: 140 additions & 68 deletions src/C_wrapper.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,36 @@
# Use of this source code is governed by an MIT-style license that can be found
# in the LICENSE.md file or at https://opensource.org/licenses/MIT.

mutable struct IpoptProblem
mutable struct IntermediateCallbackWrapper
f::Function
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this one okay? Is it just that you didn't test JuliaC with an intermediate callback?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be fixed in latest commit.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I still don't understand: why is the non-type-stable f::Function okay here, but it wasn't okay in IpoptProblem for eval_f::Function etc?

end
(cb::IntermediateCallbackWrapper)(
alg_mod,
iter_count,
obj_value,
inf_pr,
inf_du,
mu,
d_norm,
regularization_size,
alpha_du,
alpha_pr,
ls_trials,
) = cb.f(
alg_mod,
iter_count,
obj_value,
inf_pr,
inf_du,
mu,
d_norm,
regularization_size,
alpha_du,
alpha_pr,
ls_trials,
)

mutable struct IpoptProblem{F,G,GF,JG,H,I}
ipopt_problem::Ptr{Cvoid} # Reference to the internal data structure
n::Int # Num vars
m::Int # Num cons
Expand All @@ -15,12 +44,12 @@ mutable struct IpoptProblem
obj_val::Float64 # Final objective
status::Cint # Final status
# Callbacks
eval_f::Function
eval_g::Function
eval_grad_f::Function
eval_jac_g::Function
eval_h::Union{Function,Nothing}
intermediate::Union{Function,Nothing}
eval_f::F
eval_g::G
eval_grad_f::GF
eval_jac_g::JG
eval_h::H
intermediate::I
end

Base.unsafe_convert(::Type{Ptr{Cvoid}}, p::IpoptProblem) = p.ipopt_problem
Expand All @@ -30,9 +59,9 @@ function _Eval_F_CB(
x_ptr::Ptr{Float64},
x_new::Cint,
obj_value::Ptr{Float64},
user_data::Ptr{Cvoid},
)
prob = unsafe_pointer_to_objref(user_data)::IpoptProblem
user_data::Ptr{IpoptProblem{F,G,GF,JG,H,I}},
) where {F,G,GF,JG,H,I}
prob = unsafe_pointer_to_objref(Ptr{Cvoid}(user_data))::IpoptProblem{F,G,GF,JG,H,I}
x = unsafe_wrap(Array, x_ptr, Int(n))
if x_new == Cint(1)
prob.x .= x
Expand All @@ -48,9 +77,9 @@ function _Eval_Grad_F_CB(
# A Bool indicating if `x` is a new point. We don't make use of this.
::Cint,
grad_f::Ptr{Float64},
user_data::Ptr{Cvoid},
)
prob = unsafe_pointer_to_objref(user_data)::IpoptProblem
user_data::Ptr{IpoptProblem{F,G,GF,JG,H,I}},
) where {F,G,GF,JG,H,I}
prob = unsafe_pointer_to_objref(Ptr{Cvoid}(user_data))::IpoptProblem{F,G,GF,JG,H,I}
new_grad_f = unsafe_wrap(Array, grad_f, Int(n))
x = unsafe_wrap(Array, x_ptr, Int(n))
prob.eval_grad_f(x, new_grad_f)
Expand All @@ -63,9 +92,9 @@ function _Eval_G_CB(
x_new::Cint,
m::Cint,
g_ptr::Ptr{Float64},
user_data::Ptr{Cvoid},
)
prob = unsafe_pointer_to_objref(user_data)::IpoptProblem
user_data::Ptr{IpoptProblem{F,G,GF,JG,H,I}},
) where {F,G,GF,JG,H,I}
prob = unsafe_pointer_to_objref(Ptr{Cvoid}(user_data))::IpoptProblem{F,G,GF,JG,H,I}
new_g = unsafe_wrap(Array, g_ptr, Int(m))
x = unsafe_wrap(Array, x_ptr, Int(n))
if x_new == Cint(1)
Expand All @@ -84,9 +113,9 @@ function _Eval_Jac_G_CB(
iRow::Ptr{Cint},
jCol::Ptr{Cint},
values_ptr::Ptr{Float64},
user_data::Ptr{Cvoid},
)
prob = unsafe_pointer_to_objref(user_data)::IpoptProblem
user_data::Ptr{IpoptProblem{F,G,GF,JG,H,I}},
) where {F,G,GF,JG,H,I}
prob = unsafe_pointer_to_objref(Ptr{Cvoid}(user_data))::IpoptProblem{F,G,GF,JG,H,I}
x = unsafe_wrap(Array, x_ptr, Int(n))
rows = unsafe_wrap(Array, iRow, Int(nele_jac))
cols = unsafe_wrap(Array, jCol, Int(nele_jac))
Expand All @@ -111,9 +140,9 @@ function _Eval_H_CB(
iRow::Ptr{Cint},
jCol::Ptr{Cint},
values_ptr::Ptr{Float64},
user_data::Ptr{Cvoid},
)
prob = unsafe_pointer_to_objref(user_data)::IpoptProblem
user_data::Ptr{IpoptProblem{F,G,GF,JG,H,I}},
) where {F,G,GF,JG,H,I}
prob = unsafe_pointer_to_objref(Ptr{Cvoid}(user_data))::IpoptProblem{F,G,GF,JG,H,I}
if prob.eval_h === nothing
# No hessian. Return FALSE for failure.
return Cint(0)
Expand Down Expand Up @@ -143,11 +172,11 @@ function _Intermediate_CB(
alpha_du::Float64,
alpha_pr::Float64,
ls_trials::Cint,
user_data::Ptr{Cvoid},
)::Cint
user_data::Ptr{IpoptProblem{F,G,GF,JG,H,I}},
) where {F,G,GF,JG,H,I}
try
return reenable_sigint() do
prob = unsafe_pointer_to_objref(user_data)::IpoptProblem
ret = reenable_sigint() do
prob = unsafe_pointer_to_objref(Ptr{Cvoid}(user_data))::IpoptProblem{F,G,GF,JG,H,I}
return prob.intermediate(
alg_mod,
iter_count,
Expand All @@ -162,11 +191,12 @@ function _Intermediate_CB(
ls_trials,
)
end
return Cint(ret)
catch err
if !(err isa InterruptException)
rethrow(err)
end
return false # optimization should stop
return Cint(0) # optimization should stop
end
end

Expand All @@ -179,28 +209,30 @@ function CreateIpoptProblem(
g_U::Vector{Float64},
nele_jac::Int,
nele_hess::Int,
eval_f,
eval_g,
eval_grad_f,
eval_jac_g,
eval_h,
)
eval_f::F,
eval_g::G,
eval_grad_f::GF,
eval_jac_g::JG,
eval_h::H,
intermediate::I,
) where {F,G,GF,JG,H,I}

@assert n == length(x_L) == length(x_U)
@assert m == length(g_L) == length(g_U)
eval_f_cb = @cfunction(
_Eval_F_CB,
Cint,
(Cint, Ptr{Float64}, Cint, Ptr{Float64}, Ptr{Cvoid}),
(Cint, Ptr{Float64}, Cint, Ptr{Float64}, Ptr{IpoptProblem{F,G,GF,JG,H,I}}),
)
eval_g_cb = @cfunction(
_Eval_G_CB,
Cint,
(Cint, Ptr{Float64}, Cint, Cint, Ptr{Float64}, Ptr{Cvoid}),
(Cint, Ptr{Float64}, Cint, Cint, Ptr{Float64}, Ptr{IpoptProblem{F,G,GF,JG,H,I}}),
)
eval_grad_f_cb = @cfunction(
_Eval_Grad_F_CB,
Cint,
(Cint, Ptr{Float64}, Cint, Ptr{Float64}, Ptr{Cvoid}),
(Cint, Ptr{Float64}, Cint, Ptr{Float64}, Ptr{IpoptProblem{F,G,GF,JG,H,I}}),
)
eval_jac_g_cb = @cfunction(
_Eval_Jac_G_CB,
Expand All @@ -214,7 +246,7 @@ function CreateIpoptProblem(
Ptr{Cint},
Ptr{Cint},
Ptr{Float64},
Ptr{Cvoid},
Ptr{IpoptProblem{F,G,GF,JG,H,I}},
),
)
eval_h_cb = @cfunction(
Expand All @@ -232,7 +264,7 @@ function CreateIpoptProblem(
Ptr{Cint},
Ptr{Cint},
Ptr{Float64},
Ptr{Cvoid},
Ptr{IpoptProblem{F,G,GF,JG,H,I}},
),
)
ipopt_problem = @ccall libipopt.CreateIpoptProblem(
Expand Down Expand Up @@ -262,7 +294,29 @@ function CreateIpoptProblem(
error("IPOPT: Failed to construct problem for some unknown reason.")
end
end
prob = IpoptProblem(
intermediate_cb = @cfunction(
_Intermediate_CB,
Cint,
(
Cint,
Cint,
Float64,
Float64,
Float64,
Float64,
Float64,
Float64,
Float64,
Float64,
Cint,
Ptr{IpoptProblem{F,G,GF,JG,H,I}},
),
)
@ccall libipopt.SetIntermediateCallback(
ipopt_problem::Ptr{Cvoid},
intermediate_cb::Ptr{Cvoid},
)::Bool
prob = IpoptProblem{F,G,GF,JG,H,I}(
ipopt_problem,
n,
m,
Expand All @@ -278,12 +332,58 @@ function CreateIpoptProblem(
eval_grad_f,
eval_jac_g,
eval_h,
nothing,
intermediate,
)
finalizer(FreeIpoptProblem, prob)
return prob
end

function CreateIpoptProblem(
n::Int,
x_L::Vector{Float64},
x_U::Vector{Float64},
m::Int,
g_L::Vector{Float64},
g_U::Vector{Float64},
nele_jac::Int,
nele_hess::Int,
eval_f::F,
eval_g::G,
eval_grad_f::GF,
eval_jac_g::JG,
eval_h::H,
) where {F,G,GF,JG,H}
return CreateIpoptProblem(
n,
x_L,
x_U,
m,
g_L,
g_U,
nele_jac,
nele_hess,
eval_f,
eval_g,
eval_grad_f,
eval_jac_g,
eval_h,
IntermediateCallbackWrapper((args...) -> Cint(1)),
)
end

function SetIntermediateCallback(
prob::IpoptProblem{F,G,GF,JG,H,IntermediateCallbackWrapper},
f::Function,
) where {F,G,GF,JG,H}
prob.intermediate.f = f
return
end

function SetIntermediateCallback(prob, f)
error("Cannot SetIntermediateCallback if intermediate was set in the initial call to CreateIpoptProblem")
return
end

function FreeIpoptProblem(prob::IpoptProblem)
@ccall libipopt.FreeIpoptProblem(prob::Ptr{Cvoid})::Cvoid
return
Expand Down Expand Up @@ -373,34 +473,6 @@ function SetIpoptProblemScaling(
return
end

function SetIntermediateCallback(prob::IpoptProblem, intermediate::Function)
intermediate_cb = @cfunction(
_Intermediate_CB,
Cint,
(
Cint,
Cint,
Float64,
Float64,
Float64,
Float64,
Float64,
Float64,
Float64,
Float64,
Cint,
Ptr{Cvoid},
),
)
ret = @ccall libipopt.SetIntermediateCallback(
prob::Ptr{Cvoid},
intermediate_cb::Ptr{Cvoid},
)::Bool
@assert ret # The C++ code has `return true`
prob.intermediate = intermediate
return
end

function IpoptSolve(prob::IpoptProblem)
p_objval = Ref{Cdouble}(0.0)
disable_sigint() do
Expand Down
1 change: 1 addition & 0 deletions test/MOI_wrapper.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ function runtests()
for name in names(@__MODULE__; all = true)
if startswith("$(name)", "test_")
@testset "$(name)" begin
@info "$(name)"
getfield(@__MODULE__, name)()
end
end
Expand Down
Loading