-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathCTModelsJSON.jl
More file actions
373 lines (320 loc) · 13 KB
/
CTModelsJSON.jl
File metadata and controls
373 lines (320 loc) · 13 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
module CTModelsJSON
using CTModels
using DocStringExtensions
using JSON3
# ============================================================================
# Private helper: broadcast with Nothing fallback
# ============================================================================
"""
Evaluate `f` at every point of `grid` by broadcasting (equivalent to `f.(grid)`).

When `f` is `nothing`, no evaluation takes place and `nothing` is returned.
"""
function _apply_over_grid(f::Function, grid)
    return broadcast(f, grid)
end
function _apply_over_grid(::Nothing, grid)
    return nothing
end
# ============================================================================
# Helper functions for serializing/deserializing infos Dict{Symbol,Any}
# ============================================================================
"""
Convert the solver `infos` dictionary (`Dict{Symbol,Any}`) into a `Dict{String,Any}`
suitable for JSON output.

Each value is passed through `_serialize_value` using its key (as a string) as the path.
Returns a tuple `(serialized_dict, symbol_keys)`. Here `symbol_keys` gathers every path at
which a `Symbol` was turned into a string, so the import side can restore it.
"""
function _serialize_infos(infos::Dict{Symbol,Any})::Tuple{Dict{String,Any},Vector{String}}
    serialized = Dict{String,Any}()
    paths = String[]
    for name in keys(infos)
        key = String(name)
        value, found = _serialize_value(infos[name], key)
        serialized[key] = value
        append!(paths, found)
    end
    return serialized, paths
end
"""
Serialize a single value to a JSON-compatible representation.

`path` locates `v` inside the `infos` dictionary. Nested dictionary keys are joined with
`.`, and vector entries get a zero-based `[i]` suffix (e.g. `"opts.list[0]"`).

Returns a tuple `(serialized_value, symbol_paths)`. Here `symbol_paths` lists every path
at which a `Symbol` was replaced by its string form.

Behavior by type of `v`:
- `Number` (including `Bool`), `String`, `nothing`: returned unchanged.
- `Symbol`: converted to a string, and its path is recorded.
- `AbstractVector`: serialized element-wise into a `Vector{Any}`.
- `AbstractDict`: serialized entry-wise into a `Dict{String,Any}` with stringified keys.
- anything else: replaced by `string(v)`.
"""
function _serialize_value(v, path::String="")
    # Fallback for values JSON cannot represent natively: keep their string form.
    return (string(v), String[])
end

_serialize_value(v::Union{Number,String,Nothing}, ::String) = (v, String[])

_serialize_value(v::Symbol, path::String) = (string(v), [path])

function _serialize_value(v::AbstractVector, path::String)
    results = [_serialize_value(x, string(path, "[", i - 1, "]")) for (i, x) in enumerate(v)]
    values = Any[first(r) for r in results]
    symbol_paths = reduce(vcat, last.(results); init=String[])
    return (values, symbol_paths)
end

function _serialize_value(v::AbstractDict, path::String)
    entries = Dict{String,Any}()
    symbol_paths = String[]
    for (key, item) in v
        name = string(key)
        child_path = isempty(path) ? name : string(path, ".", name)
        value, found = _serialize_value(item, child_path)
        entries[name] = value
        append!(symbol_paths, found)
    end
    return (entries, symbol_paths)
end
"""
Rebuild a `Dict{Symbol,Any}` from the JSON-decoded `infos` object `blob`.

Each top-level key becomes a `Symbol`, and its value is decoded by `_deserialize_value`
with the key (as a string) as the path. String values located at a path listed in
`symbol_keys` are converted back to `Symbol`. A `nothing` or empty `blob` yields an empty
dictionary.
"""
function _deserialize_infos(
    blob, symbol_keys::Vector{String}=String[]
)::Dict{Symbol,Any}
    restored = Dict{Symbol,Any}()
    (isnothing(blob) || isempty(blob)) && return restored
    for (key, raw) in blob
        restored[Symbol(key)] = _deserialize_value(raw, String(key), symbol_keys)
    end
    return restored
end
"""
Decode a single JSON value, restoring `Symbol`s recorded at export time.

`path` follows the same convention as on export. Nested dictionary keys are joined with
`.`, and vector entries get a zero-based `[i]` suffix. A string whose path appears in
`symbol_keys` is returned as a `Symbol`. Vectors are decoded element-wise. Dictionaries
become `Dict{Symbol,Any}`, with their keys converted to `Symbol`. Any other value is
returned unchanged.
"""
function _deserialize_value(v, path::String, symbol_keys::Vector{String})
    # Non-string scalars need no conversion.
    (v isa Number || v isa Bool || isnothing(v)) && return v
    # Strings become Symbols again only at the paths recorded during export.
    v isa String && return (path in symbol_keys ? Symbol(v) : v)
    if v isa AbstractVector
        return map(enumerate(v)) do (i, item)
            _deserialize_value(item, "$(path)[$(i - 1)]", symbol_keys)
        end
    end
    if v isa AbstractDict
        restored = Dict{Symbol,Any}()
        for (key, item) in v
            name = string(key)
            child_path = isempty(path) ? name : "$(path).$(name)"
            restored[Symbol(key)] = _deserialize_value(item, child_path, symbol_keys)
        end
        return restored
    end
    return v
end
# ============================================================================
# Export function
# ============================================================================
"""
$(TYPEDSIGNATURES)

Export an optimal control solution to a `.json` file using the JSON3 format.

This function serializes a `CTModels.Solution` into a structured JSON dictionary,
including all primal and dual information, which can be read by external tools.

# Arguments

- `::CTModels.JSON3Tag`: A tag used to dispatch the export method for JSON3.
- `sol::CTModels.Solution`: The solution to be saved.

# Keyword Arguments

- `filename::String = "solution"`: Base filename. The `.json` extension is automatically appended.

# Notes

The exported JSON includes the following entries:

- the time grid;
- state, control and costate, sampled on the time grid;
- the variable and the objective;
- solver information: iterations, constraints violation, message, status and success flag;
- all constraint duals. Path, state and control constraint duals are sampled on the time
  grid. Duals that are `nothing` are written as JSON `null`.

Solver `infos` are stored under `"infos"`. The paths of entries that were `Symbol`s are
listed under `"infos_symbol_keys"`, so that `import_ocp_solution` can restore them.

# Example

```julia-repl
julia> using JSON3
julia> export_ocp_solution(JSON3Tag(), sol; filename="mysolution")
# → creates "mysolution.json"
```
"""
function CTModels.export_ocp_solution(
    ::CTModels.JSON3Tag, sol::CTModels.Solution; filename::String="solution"
)
    T = CTModels.time_grid(sol)
    # Explicit value type: the entries are heterogeneous (arrays, numbers, strings,
    # Symbols, `nothing`), and more entries are added below.
    blob = Dict{String,Any}(
        "time_grid" => T,
        "state" => _apply_over_grid(CTModels.state(sol), T),
        "control" => _apply_over_grid(CTModels.control(sol), T),
        "variable" => CTModels.variable(sol),
        "costate" => _apply_over_grid(CTModels.costate(sol), T),
        "objective" => CTModels.objective(sol),
        "iterations" => CTModels.iterations(sol),
        "constraints_violation" => CTModels.constraints_violation(sol),
        "message" => CTModels.message(sol),
        "status" => CTModels.status(sol),
        "successful" => CTModels.successful(sol),
        "path_constraints_dual" =>
            _apply_over_grid(CTModels.path_constraints_dual(sol), T),
        "state_constraints_lb_dual" =>
            _apply_over_grid(CTModels.state_constraints_lb_dual(sol), T),
        "state_constraints_ub_dual" =>
            _apply_over_grid(CTModels.state_constraints_ub_dual(sol), T),
        "control_constraints_lb_dual" =>
            _apply_over_grid(CTModels.control_constraints_lb_dual(sol), T),
        "control_constraints_ub_dual" =>
            _apply_over_grid(CTModels.control_constraints_ub_dual(sol), T),
        # The next three entries are stored as-is (a vector or `nothing`).
        "boundary_constraints_dual" => CTModels.boundary_constraints_dual(sol),
        "variable_constraints_lb_dual" => CTModels.variable_constraints_lb_dual(sol),
        "variable_constraints_ub_dual" => CTModels.variable_constraints_ub_dual(sol),
    )

    # Serialize infos and record where Symbol values were stringified.
    infos_serialized, symbol_keys = _serialize_infos(CTModels.infos(sol))
    blob["infos"] = infos_serialized
    blob["infos_symbol_keys"] = symbol_keys

    open(filename * ".json", "w") do io
        JSON3.pretty(io, blob)
    end
    return nothing
end
"""
    _json_array_to_matrix(data) -> Matrix{Float64}

Convert JSON3 array data to `Matrix{Float64}` for trajectory import.

# Context

When importing JSON data, `stack(blob[field]; dims=1)` returns different types
depending on the dimensionality of the original trajectory:

- **1D trajectories** (e.g., scalar control): `stack()` → `Vector{Float64}`
- **Multi-D trajectories** (e.g., 2D state): `stack()` → `Matrix{Float64}`

This function normalizes both cases to `Matrix{Float64}`, as required by `build_solution`.

# Arguments

- `data`: Output from `stack(blob[field]; dims=1)`, either a vector or a matrix

# Returns

- `Matrix{Float64}`: Properly shaped matrix `(n_time_points, n_dim)` for `build_solution`

# Implementation Details

- **Vector case**: a vector of length `n` becomes an `n × 1` matrix via
  `reshape(data, :, 1)`, preserving time-series ordering. This also works for a single
  time point (`n == 1`) and for an empty vector.
- **Matrix case**: Direct conversion to `Matrix{Float64}`

# Examples

```julia
# 1D control trajectory (101 time points)
control_data = [5.99, 5.93, ..., -5.99] # Vector{Float64}
control_matrix = _json_array_to_matrix(control_data)
# → Matrix{Float64}(101, 1)
# 2D state trajectory (101 time points, 2 dimensions)
state_data = [1.0 2.0; 1.1 2.1; ...] # Matrix{Float64}(101, 2)
state_matrix = _json_array_to_matrix(state_data)
# → Matrix{Float64}(101, 2)
```

# See Also

- Test coverage: `test/suite/serialization/test_export_import.jl`
  (testset "JSON stack() behavior investigation")
"""
function _json_array_to_matrix(data)::Matrix{Float64}
    if data isa AbstractVector
        # `reshape` rather than `reduce(hcat, data)'`. For a one-element vector, `reduce`
        # returns the bare element (a scalar), which cannot be converted to a Matrix.
        # For an empty vector, `reduce` throws.
        return Matrix{Float64}(reshape(data, :, 1))
    else
        return Matrix{Float64}(data)
    end
end
"""
Read the optional time-dependent field `key` from `blob` as a `Matrix{Float64}`.

Returns `nothing` when the key is absent or its value is `null`. Otherwise the stored
values are stacked row-wise with `stack(...; dims=1)` and normalized by
`_json_array_to_matrix`.
"""
function _optional_json_matrix(blob, key::String)
    raw = get(blob, key, nothing)
    return isnothing(raw) ? nothing : _json_array_to_matrix(stack(raw; dims=1))
end

"""
Read the optional vector-valued field `key` from `blob` as a `Vector{Float64}`.

Returns `nothing` when the key is absent or its value is `null`.
"""
function _optional_json_vector(blob, key::String)
    raw = get(blob, key, nothing)
    return isnothing(raw) ? nothing : Vector{Float64}(raw)
end

"""
Convert the stored `"variable"` entry to a `Vector{Float64}`.

Accepts either a JSON array or a single number; a number is wrapped into a one-element
vector. NOTE(review): a bare number is presumably what gets written when
`CTModels.variable(sol)` returns a scalar (e.g. for a one-dimensional variable) — confirm.
"""
_json_variable_to_vector(raw::Number) = [Float64(raw)]
_json_variable_to_vector(raw) = Vector{Float64}(raw)

"""
$(TYPEDSIGNATURES)

Import an optimal control solution from a `.json` file exported with `export_ocp_solution`.

This function reads the JSON contents and reconstructs a `CTModels.Solution` object,
including the discretized primal and dual trajectories.

# Arguments

- `::CTModels.JSON3Tag`: A tag used to dispatch the import method for JSON3.
- `ocp::CTModels.Model`: The model associated with the optimal control problem. Used to rebuild the full solution.

# Keyword Arguments

- `filename::String = "solution"`: Base filename. The `.json` extension is automatically appended.

# Returns

- `CTModels.Solution`: A reconstructed solution instance.

# Notes

Handles both vector and matrix encodings of signals. If dual fields are missing or `null`,
the corresponding attributes are set to `nothing`. A missing or `null` `"infos"` entry
yields an empty `infos` dictionary. A missing `"infos_symbol_keys"` entry means no
`Symbol` values are restored.

# Example

```julia-repl
julia> using JSON3
julia> sol = import_ocp_solution(JSON3Tag(), model; filename="mysolution")
```
"""
function CTModels.import_ocp_solution(
    ::CTModels.JSON3Tag, ocp::CTModels.Model; filename::String="solution"
)
    blob = JSON3.read(read(filename * ".json", String))

    # Primal trajectories and costate: one row per time point.
    X = _json_array_to_matrix(stack(blob["state"]; dims=1))
    U = _json_array_to_matrix(stack(blob["control"]; dims=1))
    P = _json_array_to_matrix(stack(blob["costate"]; dims=1))

    # Optional duals. An absent key and a `null` value both become `nothing`,
    # as promised in the docstring.
    path_constraints_dual = _optional_json_matrix(blob, "path_constraints_dual")
    state_constraints_lb_dual = _optional_json_matrix(blob, "state_constraints_lb_dual")
    state_constraints_ub_dual = _optional_json_matrix(blob, "state_constraints_ub_dual")
    control_constraints_lb_dual =
        _optional_json_matrix(blob, "control_constraints_lb_dual")
    control_constraints_ub_dual =
        _optional_json_matrix(blob, "control_constraints_ub_dual")
    boundary_constraints_dual = _optional_json_vector(blob, "boundary_constraints_dual")
    variable_constraints_lb_dual =
        _optional_json_vector(blob, "variable_constraints_lb_dual")
    variable_constraints_ub_dual =
        _optional_json_vector(blob, "variable_constraints_ub_dual")

    # Solver infos, with Symbol values restored at the paths recorded on export.
    # `collect(String, ...)` turns a JSON3 array (possibly empty) into a Vector{String}.
    symbol_keys = collect(String, get(blob, "infos_symbol_keys", String[]))
    infos = _deserialize_infos(get(blob, "infos", nothing), symbol_keys)

    return CTModels.build_solution(
        ocp,
        Vector{Float64}(blob["time_grid"]),
        X,
        U,
        _json_variable_to_vector(blob["variable"]),
        P;
        objective=Float64(blob["objective"]),
        iterations=blob["iterations"],
        constraints_violation=Float64(blob["constraints_violation"]),
        message=blob["message"],
        status=Symbol(blob["status"]),
        successful=blob["successful"],
        path_constraints_dual=path_constraints_dual,
        state_constraints_lb_dual=state_constraints_lb_dual,
        state_constraints_ub_dual=state_constraints_ub_dual,
        control_constraints_lb_dual=control_constraints_lb_dual,
        control_constraints_ub_dual=control_constraints_ub_dual,
        boundary_constraints_dual=boundary_constraints_dual,
        variable_constraints_lb_dual=variable_constraints_lb_dual,
        variable_constraints_ub_dual=variable_constraints_ub_dual,
        infos=infos,
    )
end
end