Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 24 additions & 24 deletions ext/DualizationJuMPExt/DualizationJuMPExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,16 @@ import JuMP
import MathOptInterface as MOI

function Dualization.dualize(
model::JuMP.Model,
model::JuMP.GenericModel{T},
optimizer_constructor = nothing;
kwargs...,
)
) where {T}
mode = JuMP.mode(model)
if mode != JuMP.AUTOMATIC
error("Dualization does not support solvers in $(mode) mode")
end
dual_model = JuMP.Model()
dual_problem = Dualization.DualProblem(JuMP.backend(dual_model))
dual_model = JuMP.GenericModel{T}()
dual_problem = Dualization.DualProblem{T}(JuMP.backend(dual_model))
Dualization.dualize(JuMP.backend(model), dual_problem; kwargs...)
_fill_obj_dict_with_variables!(dual_model)
_fill_obj_dict_with_constraints!(dual_model)
Expand All @@ -31,21 +31,21 @@ function Dualization.dualize(
return dual_model
end

function _fill_obj_dict_with_variables!(model::JuMP.Model)
function _fill_obj_dict_with_variables!(model::JuMP.GenericModel)
list = MOI.get(model, MOI.ListOfVariableAttributesSet())
if !(MOI.VariableName() in list)
return
end
for vi in MOI.get(model, MOI.ListOfVariableIndices())
name = MOI.get(JuMP.backend(model), MOI.VariableName(), vi)
if !isempty(name)
model[Symbol(name)] = JuMP.VariableRef(model, vi)
model[Symbol(name)] = JuMP.GenericVariableRef(model, vi)
end
end
return
end

function _fill_obj_dict_with_constraints!(model::JuMP.Model)
function _fill_obj_dict_with_constraints!(model::JuMP.GenericModel)
con_types = MOI.get(model, MOI.ListOfConstraintTypesPresent())
for (F, S) in con_types
_fill_obj_dict_with_constraints!(model, F, S)
Expand All @@ -54,7 +54,7 @@ function _fill_obj_dict_with_constraints!(model::JuMP.Model)
end

function _fill_obj_dict_with_constraints!(
model::JuMP.Model,
model::JuMP.GenericModel,
::Type{F},
::Type{S},
) where {F,S}
Expand All @@ -71,13 +71,13 @@ function _fill_obj_dict_with_constraints!(
return
end

function _get_primal_dual_map(model::JuMP.Model)
function _get_primal_dual_map(model::JuMP.GenericModel)
return model.ext[:_Dualization_jl_PrimalDualMap]
end

function Dualization._get_dual_constraint(
dual_model,
primal_ref::JuMP.VariableRef,
primal_ref::JuMP.GenericVariableRef,
)
map = _get_primal_dual_map(dual_model)
moi_primal_vi = JuMP.index(primal_ref)
Expand All @@ -91,8 +91,8 @@ function Dualization._get_dual_constraint(
end

function Dualization._get_primal_constraint(
dual_model::JuMP.Model,
primal_vi::JuMP.VariableRef,
dual_model::JuMP.GenericModel,
primal_vi::JuMP.GenericVariableRef,
)
primal_model = JuMP.owner_model(primal_vi)
map = _get_primal_dual_map(dual_model)
Expand All @@ -105,22 +105,22 @@ function Dualization._get_primal_constraint(
end

function Dualization._get_dual_variables(
dual_model::JuMP.Model,
dual_model::JuMP.GenericModel{T},
primal_ref::JuMP.ConstraintRef,
)
) where {T}
map = _get_primal_dual_map(dual_model)
moi_primal_ci = JuMP.index(primal_ref)
moi_dual_vis = Dualization._get_dual_variables(map, moi_primal_ci)
if moi_dual_vis === nothing
# main constraint of a constrained variable
return nothing
end
return [JuMP.VariableRef(dual_model, vi) for vi in moi_dual_vis]
return [JuMP.GenericVariableRef{T}(dual_model, vi) for vi in moi_dual_vis]
end

# this is a constrained variable constraint
function Dualization._get_dual_constraint(
dual_model::JuMP.Model,
dual_model::JuMP.GenericModel,
primal_ref::JuMP.ConstraintRef,
)
map = _get_primal_dual_map(dual_model)
Expand All @@ -136,24 +136,24 @@ function Dualization._get_dual_constraint(
end

function Dualization._get_dual_parameter(
dual_model::JuMP.Model,
primal_ref::JuMP.VariableRef,
)
dual_model::JuMP.GenericModel{T},
primal_ref::JuMP.GenericVariableRef,
) where {T}
map = _get_primal_dual_map(dual_model)
moi_primal_vi = JuMP.index(primal_ref)
moi_dual_vi = Dualization._get_dual_parameter(map, moi_primal_vi)
# the above line might error
return JuMP.VariableRef(dual_model, moi_dual_vi)
return JuMP.GenericVariableRef{T}(dual_model, moi_dual_vi)
end

function Dualization._get_dual_slack_variable(
dual_model::JuMP.Model,
primal_ref::JuMP.VariableRef,
)
dual_model::JuMP.GenericModel{T},
primal_ref::JuMP.GenericVariableRef,
) where {T}
map = _get_primal_dual_map(dual_model)
moi_primal_vi = JuMP.index(primal_ref)
moi_dual_vi = Dualization._get_dual_slack_variable(map, moi_primal_vi)
return JuMP.VariableRef(dual_model, moi_dual_vi)
return JuMP.GenericVariableRef{T}(dual_model, moi_dual_vi)
end

end # module DualizationJuMPExt
18 changes: 18 additions & 0 deletions test/Tests/test_JuMP_dualize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,24 @@ end
con = Dualization._get_dual_constraint(dual_model, cv)
@test con isa ConstraintRef
end
@testset "GenericModel{$T}" for T in (Float32, BigFloat)
model = JuMP.GenericModel{T}()
JuMP.@variable(model, x >= zero(T))
JuMP.@constraint(model, c, x <= one(T) + one(T))
JuMP.@objective(model, Max, T(2) * x + one(T))
dual_model = Dualization.dualize(
model;
dual_names = DualNames("dual_", "dual_"),
consider_constrained_variables = false,
)
@test dual_model isa JuMP.GenericModel{T}
@test num_variables(dual_model) == 2
con = Dualization._get_dual_constraint(dual_model, x)
@test con[1] isa ConstraintRef
var = Dualization._get_dual_variables(dual_model, c)
@test length(var) == 1
@test var[] isa JuMP.GenericVariableRef{T}
end
@testset "JuMP parametric quadratic" begin
model = Model()
@variable(model, x)
Expand Down
Loading