Skip to content

Commit 685be6c

Browse files
authored
Fix dualize for arbitrary coefficient types in JuMP (#226)
1 parent d0743cc commit 685be6c

2 files changed

Lines changed: 42 additions & 24 deletions

File tree

ext/DualizationJuMPExt/DualizationJuMPExt.jl

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -10,16 +10,16 @@ import JuMP
1010
import MathOptInterface as MOI
1111

1212
function Dualization.dualize(
13-
model::JuMP.Model,
13+
model::JuMP.GenericModel{T},
1414
optimizer_constructor = nothing;
1515
kwargs...,
16-
)
16+
) where {T}
1717
mode = JuMP.mode(model)
1818
if mode != JuMP.AUTOMATIC
1919
error("Dualization does not support solvers in $(mode) mode")
2020
end
21-
dual_model = JuMP.Model()
22-
dual_problem = Dualization.DualProblem(JuMP.backend(dual_model))
21+
dual_model = JuMP.GenericModel{T}()
22+
dual_problem = Dualization.DualProblem{T}(JuMP.backend(dual_model))
2323
Dualization.dualize(JuMP.backend(model), dual_problem; kwargs...)
2424
_fill_obj_dict_with_variables!(dual_model)
2525
_fill_obj_dict_with_constraints!(dual_model)
@@ -31,21 +31,21 @@ function Dualization.dualize(
3131
return dual_model
3232
end
3333

34-
function _fill_obj_dict_with_variables!(model::JuMP.Model)
34+
function _fill_obj_dict_with_variables!(model::JuMP.GenericModel)
3535
list = MOI.get(model, MOI.ListOfVariableAttributesSet())
3636
if !(MOI.VariableName() in list)
3737
return
3838
end
3939
for vi in MOI.get(model, MOI.ListOfVariableIndices())
4040
name = MOI.get(JuMP.backend(model), MOI.VariableName(), vi)
4141
if !isempty(name)
42-
model[Symbol(name)] = JuMP.VariableRef(model, vi)
42+
model[Symbol(name)] = JuMP.GenericVariableRef(model, vi)
4343
end
4444
end
4545
return
4646
end
4747

48-
function _fill_obj_dict_with_constraints!(model::JuMP.Model)
48+
function _fill_obj_dict_with_constraints!(model::JuMP.GenericModel)
4949
con_types = MOI.get(model, MOI.ListOfConstraintTypesPresent())
5050
for (F, S) in con_types
5151
_fill_obj_dict_with_constraints!(model, F, S)
@@ -54,7 +54,7 @@ function _fill_obj_dict_with_constraints!(model::JuMP.Model)
5454
end
5555

5656
function _fill_obj_dict_with_constraints!(
57-
model::JuMP.Model,
57+
model::JuMP.GenericModel,
5858
::Type{F},
5959
::Type{S},
6060
) where {F,S}
@@ -71,13 +71,13 @@ function _fill_obj_dict_with_constraints!(
7171
return
7272
end
7373

74-
function _get_primal_dual_map(model::JuMP.Model)
74+
function _get_primal_dual_map(model::JuMP.GenericModel)
7575
return model.ext[:_Dualization_jl_PrimalDualMap]
7676
end
7777

7878
function Dualization._get_dual_constraint(
7979
dual_model,
80-
primal_ref::JuMP.VariableRef,
80+
primal_ref::JuMP.GenericVariableRef,
8181
)
8282
map = _get_primal_dual_map(dual_model)
8383
moi_primal_vi = JuMP.index(primal_ref)
@@ -91,8 +91,8 @@ function Dualization._get_dual_constraint(
9191
end
9292

9393
function Dualization._get_primal_constraint(
94-
dual_model::JuMP.Model,
95-
primal_vi::JuMP.VariableRef,
94+
dual_model::JuMP.GenericModel,
95+
primal_vi::JuMP.GenericVariableRef,
9696
)
9797
primal_model = JuMP.owner_model(primal_vi)
9898
map = _get_primal_dual_map(dual_model)
@@ -105,22 +105,22 @@ function Dualization._get_primal_constraint(
105105
end
106106

107107
function Dualization._get_dual_variables(
108-
dual_model::JuMP.Model,
108+
dual_model::JuMP.GenericModel{T},
109109
primal_ref::JuMP.ConstraintRef,
110-
)
110+
) where {T}
111111
map = _get_primal_dual_map(dual_model)
112112
moi_primal_ci = JuMP.index(primal_ref)
113113
moi_dual_vis = Dualization._get_dual_variables(map, moi_primal_ci)
114114
if moi_dual_vis === nothing
115115
# main constraint of a constrained variable
116116
return nothing
117117
end
118-
return [JuMP.VariableRef(dual_model, vi) for vi in moi_dual_vis]
118+
return [JuMP.GenericVariableRef{T}(dual_model, vi) for vi in moi_dual_vis]
119119
end
120120

121121
# this is a constrained variable constraint
122122
function Dualization._get_dual_constraint(
123-
dual_model::JuMP.Model,
123+
dual_model::JuMP.GenericModel,
124124
primal_ref::JuMP.ConstraintRef,
125125
)
126126
map = _get_primal_dual_map(dual_model)
@@ -136,24 +136,24 @@ function Dualization._get_dual_constraint(
136136
end
137137

138138
function Dualization._get_dual_parameter(
139-
dual_model::JuMP.Model,
140-
primal_ref::JuMP.VariableRef,
141-
)
139+
dual_model::JuMP.GenericModel{T},
140+
primal_ref::JuMP.GenericVariableRef,
141+
) where {T}
142142
map = _get_primal_dual_map(dual_model)
143143
moi_primal_vi = JuMP.index(primal_ref)
144144
moi_dual_vi = Dualization._get_dual_parameter(map, moi_primal_vi)
145145
# the above line might error
146-
return JuMP.VariableRef(dual_model, moi_dual_vi)
146+
return JuMP.GenericVariableRef{T}(dual_model, moi_dual_vi)
147147
end
148148

149149
function Dualization._get_dual_slack_variable(
150-
dual_model::JuMP.Model,
151-
primal_ref::JuMP.VariableRef,
152-
)
150+
dual_model::JuMP.GenericModel{T},
151+
primal_ref::JuMP.GenericVariableRef,
152+
) where {T}
153153
map = _get_primal_dual_map(dual_model)
154154
moi_primal_vi = JuMP.index(primal_ref)
155155
moi_dual_vi = Dualization._get_dual_slack_variable(map, moi_primal_vi)
156-
return JuMP.VariableRef(dual_model, moi_dual_vi)
156+
return JuMP.GenericVariableRef{T}(dual_model, moi_dual_vi)
157157
end
158158

159159
end # module DualizationJuMPExt

test/Tests/test_JuMP_dualize.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,24 @@ end
128128
con = Dualization._get_dual_constraint(dual_model, cv)
129129
@test con isa ConstraintRef
130130
end
131+
@testset "GenericModel{$T}" for T in (Float32, BigFloat)
132+
model = JuMP.GenericModel{T}()
133+
JuMP.@variable(model, x >= zero(T))
134+
JuMP.@constraint(model, c, x <= one(T) + one(T))
135+
JuMP.@objective(model, Max, T(2) * x + one(T))
136+
dual_model = Dualization.dualize(
137+
model;
138+
dual_names = DualNames("dual_", "dual_"),
139+
consider_constrained_variables = false,
140+
)
141+
@test dual_model isa JuMP.GenericModel{T}
142+
@test num_variables(dual_model) == 2
143+
con = Dualization._get_dual_constraint(dual_model, x)
144+
@test con[1] isa ConstraintRef
145+
var = Dualization._get_dual_variables(dual_model, c)
146+
@test length(var) == 1
147+
@test var[] isa JuMP.GenericVariableRef{T}
148+
end
131149
@testset "JuMP parametric quadratic" begin
132150
model = Model()
133151
@variable(model, x)

0 commit comments

Comments
 (0)