Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 62 additions & 6 deletions ext/CTModelsJLD.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ $(TYPEDSIGNATURES)
Export an optimal control solution to a `.jld2` file using the JLD2 format.

This function serializes and saves a `CTModels.Solution` object to disk,
allowing it to be reloaded later.
allowing it to be reloaded later. The solution is discretized to avoid
serialization warnings for function objects.

# Arguments
- `::CTModels.JLD2Tag`: A tag used to dispatch the export method for JLD2.
Expand All @@ -25,11 +26,24 @@ julia> using JLD2
julia> export_ocp_solution(JLD2Tag(), sol; filename="mysolution")
# → creates "mysolution.jld2"
```

# Notes
- Functions are discretized on the time grid to avoid JLD2 serialization warnings
- The solution can be perfectly reconstructed via `import_ocp_solution`
- Uses the same discretization logic as JSON export for consistency
"""
function CTModels.export_ocp_solution(
::CTModels.JLD2Tag, sol::CTModels.Solution; filename::String
)
save_object(filename * ".jld2", sol)
# Get the associated OCP model from the solution
ocp = CTModels.model(sol)

# Serialize solution to discrete data
data = CTModels.OCP._serialize_solution(sol, ocp)

# Save both the serialized data and the OCP model
jldsave(filename * ".jld2"; solution_data=data, ocp=ocp)

return nothing
end

Expand All @@ -38,28 +52,70 @@ $(TYPEDSIGNATURES)

Import an optimal control solution from a `.jld2` file.

This function loads a previously saved `CTModels.Solution` from disk.
This function loads a previously saved `CTModels.Solution` from disk and
reconstructs it using `build_solution` from the discretized data.

# Arguments
- `::CTModels.JLD2Tag`: A tag used to dispatch the import method for JLD2.
- `ocp::CTModels.Model`: The associated model (used for dispatch consistency; not used internally).
- `ocp::CTModels.Model`: The associated optimal control problem model.

# Keyword Arguments
- `filename::String = "solution"`: Base name of the file. The `.jld2` extension is automatically appended.

# Returns
- `CTModels.Solution`: The loaded solution object.
- `CTModels.Solution`: The reconstructed solution object.

# Example
```julia-repl
julia> using JLD2
julia> sol = import_ocp_solution(JLD2Tag(), model; filename="mysolution")
```

# Notes
- The solution is reconstructed from discretized data via `build_solution`
- This ensures perfect round-trip consistency with the export
- The OCP model from the file is used if the provided one is not compatible
"""
function CTModels.import_ocp_solution(
::CTModels.JLD2Tag, ocp::CTModels.Model; filename::String
)
return load_object(filename * ".jld2")
# Load the saved data
file_data = load(filename * ".jld2")
data = file_data["solution_data"]
saved_ocp = file_data["ocp"]

# Extract time grid - handle both TimeGridModel and raw Vector
T = if data["time_grid"] isa CTModels.TimeGridModel
data["time_grid"].value
else
data["time_grid"]
end

# Reconstruct solution using build_solution
sol = CTModels.build_solution(
saved_ocp,
T,
data["state"],
data["control"],
data["variable"],
data["costate"];
objective = data["objective"],
iterations = data["iterations"],
constraints_violation = data["constraints_violation"],
message = data["message"],
status = data["status"],
successful = data["successful"],
path_constraints_dual = data["path_constraints_dual"],
boundary_constraints_dual = data["boundary_constraints_dual"],
state_constraints_lb_dual = data["state_constraints_lb_dual"],
state_constraints_ub_dual = data["state_constraints_ub_dual"],
control_constraints_lb_dual = data["control_constraints_lb_dual"],
control_constraints_ub_dual = data["control_constraints_ub_dual"],
variable_constraints_lb_dual = data["variable_constraints_lb_dual"],
variable_constraints_ub_dual = data["variable_constraints_ub_dual"]
)

return sol
end

end
116 changes: 60 additions & 56 deletions ext/CTModelsJSON.jl
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,58 @@ end
"""
$(TYPEDSIGNATURES)

Convert JSON3 array data to `Matrix{Float64}` for trajectory import.

# Context

When importing JSON data, `stack(blob[field]; dims=1)` returns different types
depending on the dimensionality of the original trajectory:
- **1D trajectories** (e.g., scalar control): `stack()` → `Vector{Float64}`
- **Multi-D trajectories** (e.g., 2D state): `stack()` → `Matrix{Float64}`

This function normalizes both cases to `Matrix{Float64}` as required by `build_solution`.

# Arguments
- `data`: Output from `stack(blob[field]; dims=1)`, either `Vector` or `Matrix`

# Returns
- `Matrix{Float64}`: Properly shaped matrix `(n_time_points, n_dim)` for `build_solution`

# Implementation Details

- **Vector case**: Converts `Vector{Float64}` of length `n` to `Matrix{Float64}(n, 1)`
using `reduce(hcat, data)'` to preserve time-series ordering
- **Matrix case**: Direct conversion to `Matrix{Float64}`

# Examples

```julia
# 1D control trajectory (101 time points)
control_data = [5.99, 5.93, ..., -5.99] # Vector{Float64}
control_matrix = _json_array_to_matrix(control_data)
# → Matrix{Float64}(101, 1)

# 2D state trajectory (101 time points, 2 dimensions)
state_data = [1.0 2.0; 1.1 2.1; ...] # Matrix{Float64}(101, 2)
state_matrix = _json_array_to_matrix(state_data)
# → Matrix{Float64}(101, 2)
```

# See Also
- Test coverage: `test/suite/serialization/test_export_import.jl`
(testset "JSON stack() behavior investigation")
"""
function _json_array_to_matrix(data)::Matrix{Float64}
if data isa Vector
return Matrix{Float64}(reduce(hcat, data)')
else
return Matrix{Float64}(data)
end
end

"""
$(TYPEDSIGNATURES)

Import an optimal control solution from a `.json` file exported with `export_ocp_solution`.

This function reads the JSON contents and reconstructs a `CTModels.Solution` object,
Expand Down Expand Up @@ -228,91 +280,43 @@ function CTModels.import_ocp_solution(
blob = JSON3.read(json_string)

# get state
X = stack(blob["state"]; dims=1)
if X isa Vector # if X is a Vector, convert it to a Matrix
X = Matrix{Float64}(reduce(hcat, X)')
else
X = Matrix{Float64}(X)
end
X = _json_array_to_matrix(stack(blob["state"]; dims=1))

# get control
U = stack(blob["control"]; dims=1)
if U isa Vector # if U is a Vector, convert it to a Matrix
U = Matrix{Float64}(reduce(hcat, U)')
else
U = Matrix{Float64}(U)
end
U = _json_array_to_matrix(stack(blob["control"]; dims=1))

# get costate
P = stack(blob["costate"]; dims=1)
if P isa Vector # if P is a Vector, convert it to a Matrix
P = Matrix{Float64}(reduce(hcat, P)')
else
P = Matrix{Float64}(P)
end
P = _json_array_to_matrix(stack(blob["costate"]; dims=1))

# get dual path constraints: convert to matrix
path_constraints_dual = if isnothing(blob["path_constraints_dual"])
nothing
else
stack(blob["path_constraints_dual"]; dims=1)
end
if path_constraints_dual isa Vector # if path_constraints_dual is a Vector, convert it to a Matrix
path_constraints_dual = Matrix{Float64}(reduce(hcat, path_constraints_dual)')
elseif !isnothing(path_constraints_dual)
path_constraints_dual = Matrix{Float64}(path_constraints_dual)
_json_array_to_matrix(stack(blob["path_constraints_dual"]; dims=1))
end

# get state constraints (and dual): convert to matrix
state_constraints_lb_dual = if isnothing(blob["state_constraints_lb_dual"])
nothing
else
stack(blob["state_constraints_lb_dual"]; dims=1)
end
if state_constraints_lb_dual isa Vector # if state_constraints_lb_dual is a Vector, convert it to a Matrix
state_constraints_lb_dual = Matrix{Float64}(
reduce(hcat, state_constraints_lb_dual)'
)
elseif !isnothing(state_constraints_lb_dual)
state_constraints_lb_dual = Matrix{Float64}(state_constraints_lb_dual)
_json_array_to_matrix(stack(blob["state_constraints_lb_dual"]; dims=1))
end
state_constraints_ub_dual = if isnothing(blob["state_constraints_ub_dual"])
nothing
else
stack(blob["state_constraints_ub_dual"]; dims=1)
end
if state_constraints_ub_dual isa Vector # if state_constraints_ub_dual is a Vector, convert it to a Matrix
state_constraints_ub_dual = Matrix{Float64}(
reduce(hcat, state_constraints_ub_dual)'
)
elseif !isnothing(state_constraints_ub_dual)
state_constraints_ub_dual = Matrix{Float64}(state_constraints_ub_dual)
_json_array_to_matrix(stack(blob["state_constraints_ub_dual"]; dims=1))
end

# get control constraints (and dual): convert to matrix
control_constraints_lb_dual = if isnothing(blob["control_constraints_lb_dual"])
nothing
else
stack(blob["control_constraints_lb_dual"]; dims=1)
end
if control_constraints_lb_dual isa Vector # if control_constraints_lb_dual is a Vector, convert it to a Matrix
control_constraints_lb_dual = Matrix{Float64}(
reduce(hcat, control_constraints_lb_dual)'
)
elseif !isnothing(control_constraints_lb_dual)
control_constraints_lb_dual = Matrix{Float64}(control_constraints_lb_dual)
_json_array_to_matrix(stack(blob["control_constraints_lb_dual"]; dims=1))
end
control_constraints_ub_dual = if isnothing(blob["control_constraints_ub_dual"])
nothing
else
stack(blob["control_constraints_ub_dual"]; dims=1)
end
if control_constraints_ub_dual isa Vector # if control_constraints_ub_dual is a Vector, convert it to a Matrix
control_constraints_ub_dual = Matrix{Float64}(
reduce(hcat, control_constraints_ub_dual)'
)
elseif !isnothing(control_constraints_ub_dual)
control_constraints_ub_dual = Matrix{Float64}(control_constraints_ub_dual)
_json_array_to_matrix(stack(blob["control_constraints_ub_dual"]; dims=1))
end

# get dual of boundary constraints: no conversion needed
Expand Down
1 change: 1 addition & 0 deletions reports/2026-01-29_Idempotence/PR_DESCRIPTION.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ CTModels tests | 1721 1721 14.4s
The analysis identified areas for future investigation:
- Bidirectional `ctinterpolate`/`ctdeinterpolate` for lossless function serialization
- Review of `deepcopy` usage in `build_solution` (rationale unclear)
- Investigation of `isa Vector` checks in JSON deserialization (see [`reports/2026-01-29_Idempotence/analysis/02_vector_conversion_investigation.md`](file:///Users/ocots/Research/logiciels/dev/control-toolbox/CTModels.jl/reports/2026-01-29_Idempotence/analysis/02_vector_conversion_investigation.md))
- Improved JLD2 handling of anonymous functions

See analysis document for details.
Expand Down
Loading
Loading