Skip to content

Commit 96ed8e5

Browse files
authored
Merge pull request #88 from control-toolbox/87-dev-handling-dual-variables
return dual variables
2 parents 5f4db4f + 3bdf848 commit 96ed8e5

3 files changed

Lines changed: 17 additions & 35 deletions

File tree

src/dual_model.jl

Lines changed: 14 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
$(TYPEDSIGNATURES)
88
99
"""
10-
function dual(sol::Solution, model::Model, label::Symbol; bound::Symbol=:lower)
10+
function dual(sol::Solution, model::Model, label::Symbol)
1111

1212
# check if the label is in the path constraints
1313
cp = path_constraints_nl(model)
@@ -45,18 +45,13 @@ function dual(sol::Solution, model::Model, label::Symbol; bound::Symbol=:lower)
4545
if label in labels
4646
# get all the indices of the label
4747
indices = findall(x -> x == label, labels)
48-
# get the corresponding dual values, either lower or upper bound
49-
duals = if bound == :lower
50-
state_constraints_lb_dual(sol)
51-
elseif bound == :upper
52-
state_constraints_ub_dual(sol)
53-
else
54-
return CTBase.IncorrectArgument("bound must be :lower or :upper")
55-
end
48+
# get the corresponding dual values
49+
duals_lb = state_constraints_lb_dual(sol)
50+
duals_ub = state_constraints_ub_dual(sol)
5651
if length(indices) == 1
57-
return t -> duals(t)[indices[1]]
52+
return t -> ( duals_lb(t)[indices[1]] - duals_ub(t)[indices[1]] )
5853
else
59-
return t -> duals(t)[indices]
54+
return t -> ( duals_lb(t)[indices] - duals_ub(t)[indices] )
6055
end
6156
end
6257

@@ -67,17 +62,12 @@ function dual(sol::Solution, model::Model, label::Symbol; bound::Symbol=:lower)
6762
# get all the indices of the label
6863
indices = findall(x -> x == label, labels)
6964
# get the corresponding dual values, either lower or upper bound
70-
duals = if bound == :lower
71-
control_constraints_lb_dual(sol)
72-
elseif bound == :upper
73-
control_constraints_ub_dual(sol)
74-
else
75-
return CTBase.IncorrectArgument("bound must be :lower or :upper")
76-
end
65+
duals_lb = control_constraints_lb_dual(sol)
66+
duals_ub = control_constraints_ub_dual(sol)
7767
if length(indices) == 1
78-
return t -> duals(t)[indices[1]]
68+
return t -> ( duals_lb(t)[indices[1]] - duals_ub(t)[indices[1]] )
7969
else
80-
return t -> duals(t)[indices]
70+
return t -> ( duals_lb(t)[indices] - duals_ub(t)[indices] )
8171
end
8272
end
8373

@@ -88,17 +78,12 @@ function dual(sol::Solution, model::Model, label::Symbol; bound::Symbol=:lower)
8878
# get all the indices of the label
8979
indices = findall(x -> x == label, labels)
9080
# get the corresponding dual values, either lower or upper bound
91-
duals = if bound == :lower
92-
variable_constraints_lb_dual(sol)
93-
elseif bound == :upper
94-
variable_constraints_ub_dual(sol)
95-
else
96-
return CTBase.IncorrectArgument("bound must be :lower or :upper")
97-
end
81+
duals_lb = variable_constraints_lb_dual(sol)
82+
duals_ub = variable_constraints_ub_dual(sol)
9883
if length(indices) == 1
99-
return duals[indices[1]]
84+
return duals_lb[indices[1]] - duals_ub[indices[1]]
10085
else
101-
return duals[indices]
86+
return duals_lb[indices] - duals_ub[indices]
10287
end
10388
end
10489

test/solution_test.jld2

0 Bytes
Binary file not shown.

test/test_solution.jl

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -211,12 +211,9 @@ function test_solution()
211211
)
212212
@test CTModels.dual(sol_, ocp, :path)(1) == [5.0, 6.0]
213213
@test CTModels.dual(sol_, ocp, :boundary) == [3.0, 2.0]
214-
@test CTModels.dual(sol_, ocp, :state_rg, bound=:lower)(1) == [5.0, 6.0]
215-
@test CTModels.dual(sol_, ocp, :state_rg, bound=:upper)(1) == -[5.0, 6.0]
216-
@test CTModels.dual(sol_, ocp, :control_rg, bound=:lower)(1) == 3.0
217-
@test CTModels.dual(sol_, ocp, :control_rg, bound=:upper)(1) == -3.0
218-
@test CTModels.dual(sol_, ocp, :variable_rg, bound=:lower) == [1.0, 2.0]
219-
@test CTModels.dual(sol_, ocp, :variable_rg, bound=:upper) == -[1.0, 2.0]
214+
@test CTModels.dual(sol_, ocp, :state_rg)(1) == [5.0, 6.0] - (-[5.0, 6.0])
215+
@test CTModels.dual(sol_, ocp, :control_rg)(1) == 3.0 - (-3.0)
216+
@test CTModels.dual(sol_, ocp, :variable_rg) == [1.0, 2.0] - (-[1.0, 2.0])
220217
end
221218

222219
end

0 commit comments

Comments
 (0)