Skip to content

Commit bd6a561

Browse files
committed
fix constraint
1 parent b866d16 commit bd6a561

2 files changed

Lines changed: 124 additions & 43 deletions

File tree

src/constraints.jl

Lines changed: 51 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -342,9 +342,14 @@ end
342342
"""
343343
$(TYPEDSIGNATURES)
344344
345-
Get a labelled constraint from the model.
345+
Get a labelled constraint from the model. Returns a tuple of the form
346+
`(type, f, lb, ub)` where `type` is the type of the constraint, `f` is the function
347+
of the constraint, `lb` is the lower bound of the constraint and `ub` is the upper
348+
bound of the constraint.
349+
350+
The function returns an exception if the label is not found in the model.
346351
"""
347-
function constraint(model::Model, label::Symbol)::Function # not type stable
352+
function constraint(model::Model, label::Symbol)::Tuple # not type stable
348353

349354
# check if the label is in the path constraints
350355
cp = path_constraints_nl(model)
@@ -357,7 +362,11 @@ function constraint(model::Model, label::Symbol)::Function # not type stable
357362
cp[2](r_, t, x, u, v)
358363
r .= r_[indices]
359364
end
360-
return to_out_of_place(fc!, length(indices))
365+
return (:path, # type of the constraint
366+
to_out_of_place(fc!, length(indices)), # function
367+
length(indices)==1 ? cp[1][indices[1]] : cp[1][indices], # lower bound
368+
length(indices)==1 ? cp[3][indices[1]] : cp[3][indices], # upper bound
369+
)
361370
end
362371

363372
# check if the label is in the boundary constraints
@@ -371,58 +380,83 @@ function constraint(model::Model, label::Symbol)::Function # not type stable
371380
cp[2](r_, x0, xf, v)
372381
r .= r_[indices]
373382
end
374-
return to_out_of_place(fc!, length(indices))
383+
return (:boundary, # type of the constraint
384+
to_out_of_place(fc!, length(indices)),
385+
length(indices)==1 ? cp[1][indices[1]] : cp[1][indices], # lower bound
386+
length(indices)==1 ? cp[3][indices[1]] : cp[3][indices], # upper bound
387+
)
375388
end
376389

377390
# check if the label is in the state constraints
378391
cp = state_constraints_box(model)
379392
labels = cp[4] # vector of labels
380393
if label in labels
381394
# get all the indices of the label
382-
indices = Int[]
395+
indices_state = Int[]
396+
indices_bound = Int[]
383397
for i in eachindex(labels)
384398
if labels[i] == label
385-
push!(indices, cp[2][i])
399+
push!(indices_state, cp[2][i])
400+
push!(indices_bound, i)
386401
end
387402
end
388403
fc = (t, x, u, v) -> begin
389-
length(indices) == 1 ? x[indices[1]] : x[indices]
404+
length(indices_state) == 1 ? x[indices_state[1]] : x[indices_state]
390405
end
391-
return fc
406+
return (:state, # type of the constraint
407+
fc,
408+
length(indices_bound)==1 ? cp[1][indices_bound[1]] : cp[1][indices_bound], # lower bound
409+
length(indices_bound)==1 ? cp[3][indices_bound[1]] : cp[3][indices_bound], # upper bound
410+
)
392411
end
393412

394413
# check if the label is in the control constraints
395414
cp = control_constraints_box(model)
396415
labels = cp[4] # vector of labels
397416
if label in labels
398417
# get all the indices of the label
399-
indices = Int[]
418+
indices_state = Int[]
419+
indices_bound = Int[]
400420
for i in eachindex(labels)
401421
if labels[i] == label
402-
push!(indices, cp[2][i])
422+
push!(indices_state, cp[2][i])
423+
push!(indices_bound, i)
403424
end
404425
end
405426
fc = (t, x, u, v) -> begin
406-
length(indices) == 1 ? u[indices[1]] : u[indices]
427+
length(indices_state) == 1 ? u[indices_state[1]] : u[indices_state]
407428
end
408-
return fc
429+
return (:control, # type of the constraint
430+
fc,
431+
length(indices_bound)==1 ? cp[1][indices_bound[1]] : cp[1][indices_bound], # lower bound
432+
length(indices_bound)==1 ? cp[3][indices_bound[1]] : cp[3][indices_bound], # upper bound
433+
)
409434
end
410435

411436
# check if the label is in the variable constraints
412437
cp = variable_constraints_box(model)
413438
labels = cp[4] # vector of labels
414439
if label in labels
415440
# get all the indices of the label
416-
indices = Int[]
441+
indices_state = Int[]
442+
indices_bound = Int[]
417443
for i in eachindex(labels)
418444
if labels[i] == label
419-
push!(indices, cp[2][i])
445+
push!(indices_state, cp[2][i])
446+
push!(indices_bound, i)
420447
end
421448
end
422-
fc = (t, x, u, v) -> begin
423-
length(indices) == 1 ? v[indices[1]] : v[indices]
449+
fc = (x0, xf, v) -> begin
450+
length(indices_state) == 1 ? v[indices_state[1]] : v[indices_state]
424451
end
425-
return fc
452+
return (:variable, # type of the constraint
453+
fc,
454+
length(indices_bound)==1 ? cp[1][indices_bound[1]] : cp[1][indices_bound], # lower bound
455+
length(indices_bound)==1 ? cp[3][indices_bound[1]] : cp[3][indices_bound], # upper bound
456+
)
426457
end
458+
459+
# return an exception if the label is not found
460+
return CTBase.IncorrectArgument("Label $label not found in the model.")
427461

428462
end

test/test_model.jl

Lines changed: 73 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -58,22 +58,22 @@ function test_model()
5858
f_path(r, t, x, u, v) = r .= x .+ u .+ v .+ t
5959
f_boundary(r, x0, xf, v) = r .= x0 .+ v .* (xf .- x0)
6060

61-
CTModels.constraint!(pre_ocp, :path; f=f_path, lb=[0, 1], ub=[1, 2], label=:path)
62-
CTModels.constraint!(pre_ocp, :boundary; f=f_boundary, lb=[0, 1], ub=[1, 2], label=:boundary)
63-
CTModels.constraint!(pre_ocp, :state; rg=1:2, lb=[0, 1], ub=[1, 2], label=:state)
64-
CTModels.constraint!(pre_ocp, :control; rg=1:2, lb=[0, 1], ub=[1, 2], label=:control)
65-
CTModels.constraint!(pre_ocp, :variable; rg=1:2, lb=[0, 1], ub=[1, 2], label=:variable)
61+
CTModels.constraint!(pre_ocp, :path; f=f_path, lb=[-0, -1], ub=[1, 2], label=:path)
62+
CTModels.constraint!(pre_ocp, :boundary; f=f_boundary, lb=[-2, -3], ub=[3, 4], label=:boundary)
63+
CTModels.constraint!(pre_ocp, :state; rg=1:2, lb=[-4, -5], ub=[5, 6], label=:state)
64+
CTModels.constraint!(pre_ocp, :control; rg=1:2, lb=[-6, -7], ub=[7, 8], label=:control)
65+
CTModels.constraint!(pre_ocp, :variable; rg=1:2, lb=[-8, -9], ub=[9, 10], label=:variable)
6666

6767
f_path_scalar(r, t, x, u, v) = r .= x[1] + u[1] + v[1] + t
6868
f_boundary_scalar(r, x0, xf, v) = r .= x0[1] + v[1] * (xf[1] - x0[1])
69-
CTModels.constraint!(pre_ocp, :path; f=f_path_scalar, lb=0, ub=1, label=:path_scalar)
70-
CTModels.constraint!(pre_ocp, :boundary; f=f_boundary_scalar, lb=0, ub=1, label=:boundary_scalar)
71-
CTModels.constraint!(pre_ocp, :state; rg=1, lb=0, ub=1, label=:state_scalar)
72-
CTModels.constraint!(pre_ocp, :control; rg=1, lb=0, ub=1, label=:control_scalar)
73-
CTModels.constraint!(pre_ocp, :variable; rg=1, lb=0, ub=1, label=:variable_scalar)
74-
CTModels.constraint!(pre_ocp, :state; rg=2, lb=0, ub=1, label=:state_scalar_2)
75-
CTModels.constraint!(pre_ocp, :control; rg=2, lb=0, ub=1, label=:control_scalar_2)
76-
CTModels.constraint!(pre_ocp, :variable; rg=2, lb=0, ub=1, label=:variable_scalar_2)
69+
CTModels.constraint!(pre_ocp, :path; f=f_path_scalar, lb=-10, ub=11, label=:path_scalar)
70+
CTModels.constraint!(pre_ocp, :boundary; f=f_boundary_scalar, lb=-11, ub=12, label=:boundary_scalar)
71+
CTModels.constraint!(pre_ocp, :state; rg=1, lb=-12, ub=13, label=:state_scalar)
72+
CTModels.constraint!(pre_ocp, :control; rg=1, lb=-13, ub=14, label=:control_scalar)
73+
CTModels.constraint!(pre_ocp, :variable; rg=1, lb=-14, ub=15, label=:variable_scalar)
74+
CTModels.constraint!(pre_ocp, :state; rg=2, lb=-15, ub=16, label=:state_scalar_2)
75+
CTModels.constraint!(pre_ocp, :control; rg=2, lb=-16, ub=17, label=:control_scalar_2)
76+
CTModels.constraint!(pre_ocp, :variable; rg=2, lb=-17, ub=18, label=:variable_scalar_2)
7777

7878
# build the model
7979
model = CTModels.build_model(pre_ocp)
@@ -88,19 +88,66 @@ function test_model()
8888
v = [6, 7]
8989
x0 = [1, 2]
9090
xf = [3, 4]
91-
@test CTModels.constraint(model, :path)(t, x, u, v) == x .+ u .+ v .+ t
92-
@test CTModels.constraint(model, :boundary)(x0, xf, v) == x0 .+ v .* (xf .- x0)
93-
@test CTModels.constraint(model, :state)(t, x, u, v) == x
94-
@test CTModels.constraint(model, :control)(t, x, u, v) == u
95-
@test CTModels.constraint(model, :variable)(t, x, u, v) == v
96-
@test CTModels.constraint(model, :path_scalar)(t, x, u, v) == x[1] + u[1] + v[1] + t
97-
@test CTModels.constraint(model, :boundary_scalar)(x0, xf, v) == x0[1] + v[1] * (xf[1] - x0[1])
98-
@test CTModels.constraint(model, :state_scalar)(t, x, u, v) == x[1]
99-
@test CTModels.constraint(model, :control_scalar)(t, x, u, v) == u[1]
100-
@test CTModels.constraint(model, :variable_scalar)(t, x, u, v) == v[1]
101-
@test CTModels.constraint(model, :state_scalar_2)(t, x, u, v) == x[2]
102-
@test CTModels.constraint(model, :control_scalar_2)(t, x, u, v) == u[2]
103-
@test CTModels.constraint(model, :variable_scalar_2)(t, x, u, v) == v[2]
91+
92+
# test the functions
93+
@test CTModels.constraint(model, :path)[2](t, x, u, v) == x .+ u .+ v .+ t
94+
@test CTModels.constraint(model, :boundary)[2](x0, xf, v) == x0 .+ v .* (xf .- x0)
95+
@test CTModels.constraint(model, :state)[2](t, x, u, v) == x
96+
@test CTModels.constraint(model, :control)[2](t, x, u, v) == u
97+
@test CTModels.constraint(model, :variable)[2](x0, xf, v) == v
98+
@test CTModels.constraint(model, :path_scalar)[2](t, x, u, v) == x[1] + u[1] + v[1] + t
99+
@test CTModels.constraint(model, :boundary_scalar)[2](x0, xf, v) == x0[1] + v[1] * (xf[1] - x0[1])
100+
@test CTModels.constraint(model, :state_scalar)[2](t, x, u, v) == x[1]
101+
@test CTModels.constraint(model, :control_scalar)[2](t, x, u, v) == u[1]
102+
@test CTModels.constraint(model, :variable_scalar)[2](x0, xf, v) == v[1]
103+
@test CTModels.constraint(model, :state_scalar_2)[2](t, x, u, v) == x[2]
104+
@test CTModels.constraint(model, :control_scalar_2)[2](t, x, u, v) == u[2]
105+
@test CTModels.constraint(model, :variable_scalar_2)[2](x0, xf, v) == v[2]
106+
107+
# test the type of the constraints
108+
@test CTModels.constraint(model, :path)[1] == :path
109+
@test CTModels.constraint(model, :boundary)[1] == :boundary
110+
@test CTModels.constraint(model, :state)[1] == :state
111+
@test CTModels.constraint(model, :control)[1] == :control
112+
@test CTModels.constraint(model, :variable)[1] == :variable
113+
@test CTModels.constraint(model, :path_scalar)[1] == :path
114+
@test CTModels.constraint(model, :boundary_scalar)[1] == :boundary
115+
@test CTModels.constraint(model, :state_scalar)[1] == :state
116+
@test CTModels.constraint(model, :control_scalar)[1] == :control
117+
@test CTModels.constraint(model, :variable_scalar)[1] == :variable
118+
@test CTModels.constraint(model, :state_scalar_2)[1] == :state
119+
@test CTModels.constraint(model, :control_scalar_2)[1] == :control
120+
@test CTModels.constraint(model, :variable_scalar_2)[1] == :variable
121+
122+
# test the lower bounds
123+
@test CTModels.constraint(model, :path)[3] == [-0, -1]
124+
@test CTModels.constraint(model, :boundary)[3] == [-2, -3]
125+
@test CTModels.constraint(model, :state)[3] == [-4, -5]
126+
@test CTModels.constraint(model, :control)[3] == [-6, -7]
127+
@test CTModels.constraint(model, :variable)[3] == [-8, -9]
128+
@test CTModels.constraint(model, :path_scalar)[3] == -10
129+
@test CTModels.constraint(model, :boundary_scalar)[3] == -11
130+
@test CTModels.constraint(model, :state_scalar)[3] == -12
131+
@test CTModels.constraint(model, :control_scalar)[3] == -13
132+
@test CTModels.constraint(model, :variable_scalar)[3] == -14
133+
@test CTModels.constraint(model, :state_scalar_2)[3] == -15
134+
@test CTModels.constraint(model, :control_scalar_2)[3] == -16
135+
@test CTModels.constraint(model, :variable_scalar_2)[3] == -17
136+
137+
# test the upper bounds
138+
@test CTModels.constraint(model, :path)[4] == [1, 2]
139+
@test CTModels.constraint(model, :boundary)[4] == [3, 4]
140+
@test CTModels.constraint(model, :state)[4] == [5, 6]
141+
@test CTModels.constraint(model, :control)[4] == [7, 8]
142+
@test CTModels.constraint(model, :variable)[4] == [9, 10]
143+
@test CTModels.constraint(model, :path_scalar)[4] == 11
144+
@test CTModels.constraint(model, :boundary_scalar)[4] == 12
145+
@test CTModels.constraint(model, :state_scalar)[4] == 13
146+
@test CTModels.constraint(model, :control_scalar)[4] == 14
147+
@test CTModels.constraint(model, :variable_scalar)[4] == 15
148+
@test CTModels.constraint(model, :state_scalar_2)[4] == 16
149+
@test CTModels.constraint(model, :control_scalar_2)[4] == 17
150+
@test CTModels.constraint(model, :variable_scalar_2)[4] == 18
104151

105152
# print the premodel
106153
display(pre_ocp)

0 commit comments

Comments
 (0)