Skip to content

Commit d4db7c2

Browse files
authored
Merge pull request #72 from control-toolbox/52-dev-export-import-a-solution
json and jld split, dual and cons are saved
2 parents 728796b + 49aadd5 commit d4db7c2

14 files changed

Lines changed: 392 additions & 238 deletions

Project.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,16 @@ JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
2020
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
2121

2222
[extensions]
23-
CTModelsExportImport = ["JLD2", "JSON3"]
23+
CTModelsJLD = "JLD2"
24+
CTModelsJSON = "JSON3"
2425
CTModelsPlots = "Plots"
2526

2627
[compat]
2728
CTBase = "0.16"
2829
DocStringExtensions = "0.9"
2930
Interpolations = "0.15"
3031
JLD2 = "0.5"
31-
JSON3 = "1"
32+
JSON3 = "1.14"
3233
LinearAlgebra = "1"
3334
MLStyle = "0.4"
3435
MacroTools = "0.5"

ext/CTModelsExportImport.jl

Lines changed: 0 additions & 110 deletions
This file was deleted.

ext/CTModelsJLD.jl

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
module CTModelsJLD
2+
3+
using CTBase
4+
using CTModels
5+
using DocStringExtensions
6+
7+
using JLD2
8+
9+
"""
10+
$(TYPEDSIGNATURES)
11+
12+
Export OCP solution in JLD format
13+
"""
14+
function CTModels.export_ocp_solution(
15+
::CTModels.JLD2Tag, sol::CTModels.Solution; filename_prefix="solution"
16+
)
17+
save_object(filename_prefix * ".jld2", sol)
18+
return nothing
19+
end
20+
21+
"""
22+
$(TYPEDSIGNATURES)
23+
24+
Read OCP solution in JLD format
25+
"""
26+
function CTModels.import_ocp_solution(
27+
::CTModels.JLD2Tag, ocp::CTModels.Model; filename_prefix="solution"
28+
)
29+
return load_object(filename_prefix * ".jld2")
30+
end
31+
32+
end

ext/CTModelsJSON.jl

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
module CTModelsJSON
2+
3+
using CTBase
4+
using CTModels
5+
using DocStringExtensions
6+
7+
using JSON3
8+
9+
"""
10+
$(TYPEDSIGNATURES)
11+
12+
Export OCP solution in JSON format
13+
"""
14+
function CTModels.export_ocp_solution(
15+
::CTModels.JSON3Tag, sol::CTModels.Solution; filename_prefix="solution"
16+
)
17+
18+
T = CTModels.time_grid(sol)
19+
20+
blob = Dict(
21+
"time_grid" => CTModels.time_grid(sol),
22+
"state" => CTModels.state_discretized(sol),
23+
"control" => CTModels.control_discretized(sol),
24+
"variable" => CTModels.variable(sol),
25+
"costate" => CTModels.costate_discretized(sol)[1:(end - 1), :],
26+
"objective" => CTModels.objective(sol),
27+
"iterations" => CTModels.iterations(sol),
28+
"constraints_violation" => CTModels.constraints_violation(sol),
29+
"message" => CTModels.message(sol),
30+
"stopping" => CTModels.stopping(sol),
31+
"success" => CTModels.success(sol),
32+
"path_constraints" => CTModels.discretize(CTModels.path_constraints(sol), T),
33+
"path_constraints_dual" => CTModels.discretize(CTModels.path_constraints_dual(sol), T),
34+
"state_constraints_lb_dual" => CTModels.discretize(CTModels.state_constraints_lb_dual(sol), T),
35+
"state_constraints_ub_dual" => CTModels.discretize(CTModels.state_constraints_ub_dual(sol), T),
36+
"control_constraints_lb_dual" => CTModels.discretize(CTModels.control_constraints_lb_dual(sol), T),
37+
"control_constraints_ub_dual" => CTModels.discretize(CTModels.control_constraints_ub_dual(sol), T),
38+
"boundary_constraints" => CTModels.boundary_constraints(sol), # ctVector or Nothing
39+
"boundary_constraints_dual" => CTModels.boundary_constraints_dual(sol), # ctVector or Nothing
40+
"variable_constraints_lb_dual" => CTModels.variable_constraints_lb_dual(sol), # ctVector or Nothing
41+
"variable_constraints_ub_dual" => CTModels.variable_constraints_ub_dual(sol), # ctVector or Nothing
42+
)
43+
44+
open(filename_prefix * ".json", "w") do io
45+
JSON3.pretty(io, blob)
46+
end
47+
48+
return nothing
49+
end
50+
51+
"""
52+
$(TYPEDSIGNATURES)
53+
54+
Read OCP solution in JSON format
55+
"""
56+
function CTModels.import_ocp_solution(
57+
::CTModels.JSON3Tag, ocp::CTModels.Model; filename_prefix="solution"
58+
)
59+
60+
json_string = read(filename_prefix * ".json", String)
61+
blob = JSON3.read(json_string)
62+
63+
# get state
64+
X = stack(blob["state"]; dims=1)
65+
if X isa Vector # if X is a Vector, convert it to a Matrix
66+
X = Matrix{Float64}(reduce(hcat, X)')
67+
end
68+
69+
# get control
70+
U = stack(blob["control"]; dims=1)
71+
if U isa Vector # if U is a Vector, convert it to a Matrix
72+
U = Matrix{Float64}(reduce(hcat, U)')
73+
end
74+
75+
# get costate
76+
P = stack(blob["costate"]; dims=1)
77+
if P isa Vector # if P is a Vector, convert it to a Matrix
78+
P = Matrix{Float64}(reduce(hcat, P)')
79+
end
80+
81+
# get path constraints (and dual): convert to matrix
82+
path_constraints = isnothing(blob["path_constraints"]) ? nothing : stack(blob["path_constraints"]; dims=1)
83+
if path_constraints isa Vector # if path_constraints is a Vector, convert it to a Matrix
84+
path_constraints = Matrix{Float64}(reduce(hcat, path_constraints)')
85+
end
86+
path_constraints_dual = isnothing(blob["path_constraints_dual"]) ? nothing : stack(blob["path_constraints_dual"]; dims=1)
87+
if path_constraints_dual isa Vector # if path_constraints_dual is a Vector, convert it to a Matrix
88+
path_constraints_dual = Matrix{Float64}(reduce(hcat, path_constraints_dual)')
89+
end
90+
91+
# get state constraints (and dual): convert to matrix
92+
state_constraints_lb_dual = isnothing(blob["state_constraints_lb_dual"]) ? nothing : stack(blob["state_constraints_lb_dual"]; dims=1)
93+
if state_constraints_lb_dual isa Vector # if state_constraints_lb_dual is a Vector, convert it to a Matrix
94+
state_constraints_lb_dual = Matrix{Float64}(reduce(hcat, state_constraints_lb_dual)')
95+
end
96+
state_constraints_ub_dual = isnothing(blob["state_constraints_ub_dual"]) ? nothing : stack(blob["state_constraints_ub_dual"]; dims=1)
97+
if state_constraints_ub_dual isa Vector # if state_constraints_ub_dual is a Vector, convert it to a Matrix
98+
state_constraints_ub_dual = Matrix{Float64}(reduce(hcat, state_constraints_ub_dual)')
99+
end
100+
101+
# get control constraints (and dual): convert to matrix
102+
control_constraints_lb_dual = isnothing(blob["control_constraints_lb_dual"]) ? nothing : stack(blob["control_constraints_lb_dual"]; dims=1)
103+
if control_constraints_lb_dual isa Vector # if control_constraints_lb_dual is a Vector, convert it to a Matrix
104+
control_constraints_lb_dual = Matrix{Float64}(reduce(hcat, control_constraints_lb_dual)')
105+
end
106+
control_constraints_ub_dual = isnothing(blob["control_constraints_ub_dual"]) ? nothing : stack(blob["control_constraints_ub_dual"]; dims=1)
107+
if control_constraints_ub_dual isa Vector # if control_constraints_ub_dual is a Vector, convert it to a Matrix
108+
control_constraints_ub_dual = Matrix{Float64}(reduce(hcat, control_constraints_ub_dual)')
109+
end
110+
111+
# get boundary constraints (and dual): no conversion needed
112+
boundary_constraints = blob["boundary_constraints"]
113+
boundary_constraints_dual = blob["boundary_constraints_dual"]
114+
115+
# get variable constraints dual: no conversion needed
116+
variable_constraints_lb_dual = blob["variable_constraints_lb_dual"]
117+
variable_constraints_ub_dual = blob["variable_constraints_ub_dual"]
118+
119+
# NB. convert vect{vect} to matrix
120+
return CTModels.build_solution(
121+
ocp,
122+
Vector{Float64}(blob.time_grid),
123+
X,
124+
U,
125+
Vector{Float64}(blob.variable),
126+
P;
127+
objective=Float64(blob.objective),
128+
iterations=blob.iterations,
129+
constraints_violation=Float64(blob.constraints_violation),
130+
message=blob.message,
131+
stopping=Symbol(blob.stopping),
132+
success=blob.success,
133+
path_constraints=path_constraints,
134+
path_constraints_dual=path_constraints_dual,
135+
state_constraints_lb_dual=state_constraints_lb_dual,
136+
state_constraints_ub_dual=state_constraints_ub_dual,
137+
control_constraints_lb_dual=control_constraints_lb_dual,
138+
control_constraints_ub_dual=control_constraints_ub_dual,
139+
boundary_constraints=boundary_constraints,
140+
boundary_constraints_dual=boundary_constraints_dual,
141+
variable_constraints_lb_dual=variable_constraints_lb_dual,
142+
variable_constraints_ub_dual=variable_constraints_ub_dual,
143+
)
144+
145+
end
146+
147+
end

src/CTModels.jl

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,49 @@ const ConstraintsDictType = Dict{
2929
const Times = AbstractVector{<:Time}
3030
const TimesDisc = Union{Times,StepRangeLen}
3131

32+
#
33+
include("default.jl")
34+
35+
# export / import
36+
abstract type AbstractTag end
37+
struct JLD2Tag <: AbstractTag end
38+
struct JSON3Tag <: AbstractTag end
39+
3240
# to be extended
33-
export_ocp_solution(args...; kwargs...) = throw(CTBase.ExtensionError(:JLD2, :JSON3))
34-
import_ocp_solution(args...; kwargs...) = throw(CTBase.ExtensionError(:JLD2, :JSON3))
41+
export_ocp_solution(::JLD2Tag, args...; kwargs...) = throw(CTBase.ExtensionError(:JLD2))
42+
import_ocp_solution(::JLD2Tag, args...; kwargs...) = throw(CTBase.ExtensionError(:JLD2))
43+
export_ocp_solution(::JSON3Tag, args...; kwargs...) = throw(CTBase.ExtensionError(:JSON3))
44+
import_ocp_solution(::JSON3Tag, args...; kwargs...) = throw(CTBase.ExtensionError(:JSON3))
45+
46+
function export_ocp_solution(args...; format=__format(), kwargs...)
47+
if format == :JLD
48+
return export_ocp_solution(JLD2Tag(), args...; kwargs...)
49+
elseif format == :JSON
50+
return export_ocp_solution(JSON3Tag(), args...; kwargs...)
51+
else
52+
throw(
53+
CTBase.IncorrectArgument(
54+
"Export_ocp_solution: unknown format (should be :JLD or :JSON): ", format
55+
),
56+
)
57+
end
58+
end
59+
60+
function import_ocp_solution(args...; format=__format(), kwargs...)
61+
if format == :JLD
62+
return import_ocp_solution(JLD2Tag(), args...; kwargs...)
63+
elseif format == :JSON
64+
return import_ocp_solution(JSON3Tag(), args...; kwargs...)
65+
else
66+
throw(
67+
CTBase.IncorrectArgument(
68+
"Import_ocp_solution: unknown format (should be :JLD or :JSON): ", format
69+
),
70+
)
71+
end
72+
end
3573

3674
#
37-
include("default.jl")
3875
include("utils.jl")
3976
include("types.jl")
4077

src/constraints.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -211,14 +211,17 @@ function constraint!(
211211
)
212212
end
213213

214-
as_vector(x::Nothing) = nothing
214+
as_vector(::Nothing) = nothing
215215
(as_vector(x::T)::Vector{T}) where {T<:ctNumber} = [x]
216216
as_vector(x::Vector{T}) where {T<:ctNumber} = x
217217

218-
as_range(r::Nothing) = nothing
218+
as_range(::Nothing) = nothing
219219
as_range(r::T) where {T<:Int} = r:r
220220
as_range(r::OrdinalRange{T}) where {T<:Int} = r
221221

222+
discretize(constraint::Function, grid::Vector{T}) where {T<:ctNumber} = constraint.(grid)
223+
discretize(::Nothing, grid::Vector{T}) where {T<:ctNumber} = nothing
224+
222225
# ------------------------------------------------------------------------------ #
223226
# GETTERS
224227
# ------------------------------------------------------------------------------ #

0 commit comments

Comments
 (0)