diff --git a/.github/workflows/Breakage.yml b/.github/workflows/Breakage.yml index 9a4020cf..8e1316bb 100644 --- a/.github/workflows/Breakage.yml +++ b/.github/workflows/Breakage.yml @@ -11,7 +11,7 @@ jobs: strategy: fail-fast: false matrix: - pkgname: [CTDirect, CTFlows, CTParser, OptimalControl] + pkgname: [CTDirect, CTFlows, OptimalControl] pkgversion: [latest, stable] include: - pkgpath: control-toolbox diff --git a/.gitignore b/.gitignore index 242c9c4e..df92c37e 100644 --- a/.gitignore +++ b/.gitignore @@ -25,3 +25,7 @@ docs/site/ # committed for packages, but should be committed for applications that require a static # environment. Manifest.toml + +# +reports/ +profiling/ \ No newline at end of file diff --git a/Project.toml b/Project.toml index 76aefc81..403ff0ee 100644 --- a/Project.toml +++ b/Project.toml @@ -1,18 +1,23 @@ name = "CTModels" uuid = "34c4fa32-2049-4079-8329-de33c2a22e2d" +version = "0.7.0" authors = ["Olivier Cots "] -version = "0.6.9" [deps] +ADNLPModels = "54578032-b7ea-4c30-94aa-7cbd1cce6c9a" CTBase = "54762871-cc72-4466-b8e8-f6c8b58076cd" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" +ExaModels = "1037b233-b668-4ce9-9b63-f9f681f55dd2" Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59" +KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MLStyle = "d8e11817-5142-5d16-987a-aa16d5891078" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" +NLPModels = "a4795742-8479-5a88-8948-cc11e1c8c1a6" OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a" RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" +SolverCore = "ff4d7338-4cf1-434d-91df-b86cb86fb843" [weakdeps] JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819" @@ -25,16 +30,21 @@ CTModelsJSON = "JSON3" CTModelsPlots = "Plots" [compat] -CTBase = "0.16" +ADNLPModels = "0.8" +CTBase = "0.17" DocStringExtensions = "0.9" +ExaModels = "0.9" Interpolations = "0.16" JLD2 = "0.6" JSON3 = "1" +KernelAbstractions = "0.9" LinearAlgebra = "1" MLStyle = "0.4" MacroTools = "0.5" +NLPModels = "0.21" OrderedCollections = "1" Parameters = "0.12" Plots = "1" RecipesBase = "1" +SolverCore = "0.3" julia = "1.10" diff --git a/docs/Project.toml b/docs/Project.toml index 8116dabb..79320223 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,12 +1,13 @@ [deps] -CTParser = "32681960-a1b1-40db-9bff-a1ca817385d1" +CTBase = "54762871-cc72-4466-b8e8-f6c8b58076cd" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819" JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1" +MarkdownAST = "d0879d2d-cac2-40c8-9cee-1863dc0c7391" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" [compat] -CTParser = "0.7" +CTBase = "0.17" Documenter = "1" JLD2 = "0.6" JSON3 = "1" diff --git a/docs/docutils/DocumenterReference.jl b/docs/docutils/DocumenterReference.jl new file mode 100644 index 00000000..2597ba20 --- /dev/null +++ b/docs/docutils/DocumenterReference.jl @@ -0,0 +1,976 @@ +# Copyright 2023, Oscar Dowson and contributors +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at https://mozilla.org/MPL/2.0/. +# +# Modified November 2025 for CTBenchmarks.jl: +# - Separated public and private API documentation into distinct pages +# - Added robust handling for missing docstrings (warnings instead of errors) +# - Included non-exported symbols in API reference +# - Filtered internal compiler-generated symbols (starting with '#') +# +# Refactored December 2025: +# - Extracted helper functions to reduce code duplication +# - Improved documentation and code organization +# - Added Dict-based DocType to string conversion + +module DocumenterReference + +using CTBase: CTBase +using Documenter: Documenter +using Markdown: Markdown +using MarkdownAST: MarkdownAST + +# ═══════════════════════════════════════════════════════════════════════════════ +# Types and Constants +# ═══════════════════════════════════════════════════════════════════════════════ + +""" + DocType + +Enumeration of documentation element types recognized by the API reference generator. + +# Values + +- `DOCTYPE_ABSTRACT_TYPE`: An abstract type declaration +- `DOCTYPE_CONSTANT`: A constant binding (including non-function, non-type values) +- `DOCTYPE_FUNCTION`: A function or callable +- `DOCTYPE_MACRO`: A macro (name starts with `@`) +- `DOCTYPE_MODULE`: A submodule +- `DOCTYPE_STRUCT`: A concrete struct type +""" +@enum( + DocType, + DOCTYPE_ABSTRACT_TYPE, + DOCTYPE_CONSTANT, + DOCTYPE_FUNCTION, + DOCTYPE_MACRO, + DOCTYPE_MODULE, + DOCTYPE_STRUCT, +) + +""" + DOCTYPE_NAMES::Dict{DocType, String} + +Mapping from DocType enum values to their human-readable string representations. +""" +const DOCTYPE_NAMES = Dict{DocType, String}( + DOCTYPE_ABSTRACT_TYPE => "abstract type", + DOCTYPE_CONSTANT => "constant", + DOCTYPE_FUNCTION => "function", + DOCTYPE_MACRO => "macro", + DOCTYPE_MODULE => "module", + DOCTYPE_STRUCT => "struct", +) + +""" + DOCTYPE_ORDER::Dict{DocType, Int} + +Ordering for DocType values used when sorting symbols for display. +Lower values appear first. +""" +const DOCTYPE_ORDER = Dict{DocType, Int}( + DOCTYPE_MODULE => 0, + DOCTYPE_MACRO => 1, + DOCTYPE_FUNCTION => 2, + DOCTYPE_ABSTRACT_TYPE => 3, + DOCTYPE_STRUCT => 4, + DOCTYPE_CONSTANT => 5, +) + +""" + _Config + +Internal configuration for API reference generation. + +# Fields + +- `current_module::Module`: The module being documented. +- `subdirectory::String`: Output directory for generated API pages. +- `modules::Dict{Module,Vector{String}}`: Mapping of modules to their source files. + When a module is specified as `Module => files`, the files are stored here. +- `sort_by::Function`: Custom sort function for symbols. +- `exclude::Set{Symbol}`: Symbol names to exclude from documentation. +- `public::Bool`: Flag to generate public API page. +- `private::Bool`: Flag to generate private API page. +- `title::String`: Title displayed at the top of the generated page. +- `title_in_menu::String`: Title displayed in the navigation menu. +- `source_files::Vector{String}`: Global source file paths (fallback if no module-specific files). +- `filename::String`: Base filename (without extension) for the markdown file. +- `include_without_source::Bool`: If `true`, include symbols whose source file cannot be determined. +- `external_modules_to_document::Vector{Module}`: Additional modules to search for docstrings. +""" +struct _Config + current_module::Module + subdirectory::String + modules::Dict{Module,Vector{String}} + sort_by::Function + exclude::Set{Symbol} + public::Bool + private::Bool + title::String + title_in_menu::String + source_files::Vector{String} + filename::String + include_without_source::Bool + external_modules_to_document::Vector{Module} +end + +""" + CONFIG::Vector{_Config} + +Global configuration storage for API reference generation. + +Each call to [`automatic_reference_documentation`](@ref) appends a new `_Config` +entry to this vector. Use [`reset_config!`](@ref) to clear it between builds. +""" +const CONFIG = _Config[] + +""" + PAGE_CONTENT_ACCUMULATOR::Dict{String, Vector{Tuple{Module, Vector{String}, Vector{String}}}} + +Global accumulator for multi-module combined pages. +Maps output filename to a list of (module, public_docstrings, private_docstrings) tuples. +""" +const PAGE_CONTENT_ACCUMULATOR = Dict{String, Vector{Tuple{Module, Vector{String}, Vector{String}}}}() + +# ═══════════════════════════════════════════════════════════════════════════════ +# Public API +# ═══════════════════════════════════════════════════════════════════════════════ + +""" + reset_config!() + +Clear the global `CONFIG` vector and `PAGE_CONTENT_ACCUMULATOR`. +Useful between documentation builds or for testing. +""" +function reset_config!() + empty!(CONFIG) + empty!(PAGE_CONTENT_ACCUMULATOR) + return nothing +end + +""" + automatic_reference_documentation(; + subdirectory::String, + primary_modules, + sort_by::Function = identity, + exclude::Vector{Symbol} = Symbol[], + public::Bool = true, + private::Bool = true, + title::String = "API Reference", + title_in_menu::String = "", + filename::String = "", + source_files::Vector{String} = String[], + include_without_source::Bool = false, + external_modules_to_document::Vector{Module} = Module[], + ) + +Automatically creates the API reference documentation for one or more modules and +returns a structure which can be used in the `pages` argument of `Documenter.makedocs`. + +## Arguments + + * `subdirectory`: the directory relative to the documentation root in which to + write the API files. + * `primary_modules`: a vector of modules or `Module => source_files` pairs to document. + When source files are provided, only symbols defined in those files are documented. + * `sort_by`: a custom sort function applied to symbol lists. + * `exclude`: vector of symbol names to skip from the generated API. + * `public`: flag to generate public API page (default: `true`). + * `private`: flag to generate private API page (default: `true`). + * `title`: title displayed at the top of the generated page. + * `title_in_menu`: title displayed in the navigation menu (default: same as `title`). + * `filename`: base filename (without extension) for the markdown file. + * `source_files`: global source file paths (fallback if no module-specific files). + **Deprecated**: prefer using `primary_modules=[Module => files]` instead. + * `include_without_source`: if `true`, include symbols whose source file cannot + be determined. Default: `false`. + * `external_modules_to_document`: additional modules to search for docstrings + (e.g., `[Plots]` to include `Plots.plot` methods defined in your source files). + +## Multiple instances + +Each time you call this function, a new object is added to the global variable +`DocumenterReference.CONFIG`. Use `reset_config!()` to clear it between builds. +""" +function automatic_reference_documentation(; + subdirectory::String, + primary_modules::Vector, + sort_by::Function=identity, + exclude::Vector{Symbol}=Symbol[], + public::Bool=true, + private::Bool=true, + title::String="API Reference", + title_in_menu::String="", + filename::String="", + source_files::Vector{String}=String[], + include_without_source::Bool=false, + external_modules_to_document::Vector{Module}=Module[], +) + # Validate arguments + if !public && !private + error("automatic_reference_documentation: both `public` and `private` cannot be false.") + end + + # Parse primary_modules into a Dict{Module, Vector{String}} + modules_dict = _parse_primary_modules(primary_modules) + exclude_set = Set(exclude) + normalized_source_files = _normalize_paths(source_files) + effective_title_in_menu = isempty(title_in_menu) ? title : title_in_menu + effective_filename = _default_basename(filename, public, private) + + # Single-module case + if length(primary_modules) == 1 + current_module = first(keys(modules_dict)) + _register_config( + current_module, subdirectory, modules_dict, sort_by, exclude_set, + public, private, title, effective_title_in_menu, normalized_source_files, + effective_filename, include_without_source, external_modules_to_document + ) + return _build_page_return_structure(effective_title_in_menu, subdirectory, effective_filename, public, private) + end + + # Multi-module case with combined page (filename provided) + if !isempty(filename) + for mod in keys(modules_dict) + _register_config( + mod, subdirectory, modules_dict, sort_by, exclude_set, + public, private, title, effective_title_in_menu, normalized_source_files, + effective_filename, include_without_source, external_modules_to_document + ) + end + return _build_page_return_structure(effective_title_in_menu, subdirectory, effective_filename, public, private) + end + + # Multi-module case with per-module subdirectories + list_of_pages = Any[] + for mod in keys(modules_dict) + module_subdir = joinpath(subdirectory, string(mod)) + module_filename = _default_basename("", public, private) + default_title = _default_title(public, private) + + _register_config( + mod, module_subdir, modules_dict, sort_by, exclude_set, + public, private, default_title, default_title, normalized_source_files, + module_filename, include_without_source, external_modules_to_document + ) + + pages = _build_page_return_structure(default_title, module_subdir, module_filename, public, private) + push!(list_of_pages, string(mod) => last(pages)) + end + return effective_title_in_menu => list_of_pages +end + +# ═══════════════════════════════════════════════════════════════════════════════ +# Documenter Pipeline Integration +# ═══════════════════════════════════════════════════════════════════════════════ + +""" + APIBuilder <: Documenter.Builder.DocumentPipeline + +Custom Documenter pipeline stage for automatic API reference generation. + +This builder is inserted into the Documenter pipeline at order `0.0` (before +most other stages) to generate API reference pages from the configurations +stored in [`CONFIG`](@ref). +""" +abstract type APIBuilder <: Documenter.Builder.DocumentPipeline end + +""" + Documenter.Selectors.order(::Type{APIBuilder}) -> Float64 + +Return the pipeline order for [`APIBuilder`](@ref). +Returns `0.0`, placing this stage early in the Documenter pipeline. +""" +Documenter.Selectors.order(::Type{APIBuilder}) = 0.0 + +""" + Documenter.Selectors.runner(::Type{APIBuilder}, document) + +Documenter pipeline runner for API reference generation. +Processes all registered module configurations and generates their API reference pages. +""" +function Documenter.Selectors.runner(::Type{APIBuilder}, document::Documenter.Document) + @info "APIBuilder: creating API reference" + for config in CONFIG + _build_api_page(document, config) + end + _finalize_api_pages(document) + return nothing +end + +# ═══════════════════════════════════════════════════════════════════════════════ +# Helper Functions: Configuration +# ═══════════════════════════════════════════════════════════════════════════════ + +""" + _parse_primary_modules(primary_modules::Vector) -> Dict{Module, Vector{String}} + +Parse the `primary_modules` argument into a dictionary mapping modules to their source files. +Handles both plain modules and `Module => files` pairs. +""" +function _parse_primary_modules(primary_modules::Vector) + result = Dict{Module, Vector{String}}() + for m in primary_modules + if m isa Module + result[m] = String[] + elseif m isa Pair + mod = first(m) + files = last(m) + result[mod] = _normalize_paths(files isa Vector ? files : [files]) + else + error("Invalid element in primary_modules: expected Module or Module => files pair") + end + end + return result +end + +""" + _normalize_paths(paths) -> Vector{String} + +Normalize a collection of paths to absolute paths. +""" +function _normalize_paths(paths) + isempty(paths) ? String[] : [abspath(p) for p in paths] +end + +""" + _register_config(current_module, subdirectory, modules, sort_by, exclude, public, private, + title, title_in_menu, source_files, filename, include_without_source, + external_modules_to_document) + +Create and register a `_Config` in the global `CONFIG` vector. +""" +function _register_config( + current_module::Module, + subdirectory::String, + modules::Dict{Module, Vector{String}}, + sort_by::Function, + exclude::Set{Symbol}, + public::Bool, + private::Bool, + title::String, + title_in_menu::String, + source_files::Vector{String}, + filename::String, + include_without_source::Bool, + external_modules_to_document::Vector{Module}, +) + push!(CONFIG, _Config( + current_module, subdirectory, modules, sort_by, exclude, + public, private, title, title_in_menu, source_files, filename, + include_without_source, external_modules_to_document + )) + return nothing +end + +""" + _default_basename(filename::String, public::Bool, private::Bool) -> String + +Compute the default base filename for the generated markdown file. +""" +function _default_basename(filename::String, public::Bool, private::Bool) + !isempty(filename) && return filename + public && private && return "api" + public && return "public" + return "private" +end + +""" + _default_title(public::Bool, private::Bool) -> String + +Compute the default title based on public/private flags. +""" +function _default_title(public::Bool, private::Bool) + public && !private && return "Public API" + !public && private && return "Private API" + return "API Reference" +end + +""" + _build_page_path(subdirectory::String, filename::String) -> String + +Build the page path by joining subdirectory and filename. +Handles special cases where `subdirectory` is `"."` or empty. +""" +function _build_page_path(subdirectory::String, filename::String) + (subdirectory == "." || isempty(subdirectory)) && return filename + return "$subdirectory/$filename" +end + +""" + _build_page_return_structure(title_in_menu, subdirectory, filename, public, private) -> Pair + +Build the return structure for `automatic_reference_documentation`. +""" +function _build_page_return_structure(title_in_menu::String, subdirectory::String, filename::String, public::Bool, private::Bool) + if public && private + return title_in_menu => [ + "Public" => _build_page_path(subdirectory, "public.md"), + "Private" => _build_page_path(subdirectory, "private.md"), + ] + else + return title_in_menu => _build_page_path(subdirectory, "$filename.md") + end +end + +""" + _get_effective_source_files(config::_Config) -> Vector{String} + +Determine the effective source files for filtering symbols. +Priority: module-specific files > global source_files > empty (no filtering). +""" +function _get_effective_source_files(config::_Config) + module_files = get(config.modules, config.current_module, String[]) + !isempty(module_files) && return module_files + !isempty(config.source_files) && return config.source_files + return String[] +end + +# ═══════════════════════════════════════════════════════════════════════════════ +# Helper Functions: Symbol Classification +# ═══════════════════════════════════════════════════════════════════════════════ + +""" + _to_string(x::DocType) -> String + +Convert a DocType enumeration value to its string representation. +""" +_to_string(x::DocType) = DOCTYPE_NAMES[x] + +""" + _classify_symbol(obj, name_str::String) -> DocType + +Classify a symbol by its type (function, macro, struct, constant, module, abstract type). +""" +function _classify_symbol(obj, name_str::String) + startswith(name_str, "@") && return DOCTYPE_MACRO + obj isa Module && return DOCTYPE_MODULE + obj isa Type && isabstracttype(obj) && return DOCTYPE_ABSTRACT_TYPE + obj isa Type && return DOCTYPE_STRUCT + obj isa Function && return DOCTYPE_FUNCTION + return DOCTYPE_CONSTANT +end + +""" + _exported_symbols(mod::Module) -> NamedTuple + +Classify all symbols in a module into exported and private categories. +Returns a NamedTuple with `exported` and `private` fields, each containing +sorted lists of `(Symbol, DocType)` pairs. +""" +function _exported_symbols(mod::Module) + exported = Pair{Symbol,DocType}[] + private = Pair{Symbol,DocType}[] + exported_names = Set(names(mod; all=false)) + + for n in names(mod; all=true, imported=false) + name_str = String(n) + # Skip compiler-generated symbols and the module itself + startswith(name_str, "#") && continue + n == nameof(mod) && continue + + obj = try + getfield(mod, n) + catch + continue + end + + doc_type = _classify_symbol(obj, name_str) + target = n in exported_names ? exported : private + push!(target, n => doc_type) + end + + sort_fn = x -> (DOCTYPE_ORDER[x[2]], string(x[1])) + return (exported=sort(exported; by=sort_fn), private=sort(private; by=sort_fn)) +end + +# ═══════════════════════════════════════════════════════════════════════════════ +# Helper Functions: Source File Detection +# ═══════════════════════════════════════════════════════════════════════════════ + +""" + _get_source_file(mod::Module, key::Symbol, type::DocType) -> Union{String, Nothing} + +Determine the source file path where a symbol is defined. +Returns `nothing` if the source file cannot be determined. +""" +function _get_source_file(mod::Module, key::Symbol, type::DocType) + try + # Strategy 1: Try docstring metadata + path = _get_source_from_docstring(mod, key) + path !== nothing && return path + + obj = getfield(mod, key) + + # Strategy 2: For functions/macros, use methods() + if obj isa Function + path = _get_source_from_methods(obj) + path !== nothing && return path + end + + # Strategy 3: For concrete types, try constructor methods + if obj isa Type && !isabstracttype(obj) + path = _get_source_from_methods(obj) + path !== nothing && return path + end + + return nothing + catch e + @debug "Could not determine source file for $key in $mod" exception=e + return nothing + end +end + +""" + _get_source_from_docstring(mod::Module, key::Symbol) -> Union{String, Nothing} + +Try to get source file path from docstring metadata. +""" +function _get_source_from_docstring(mod::Module, key::Symbol) + binding = Base.Docs.Binding(mod, key) + meta = Base.Docs.meta(mod) + haskey(meta, binding) || return nothing + + docs = meta[binding] + if isa(docs, Base.Docs.MultiDoc) && !isempty(docs.docs) + for (_, docstr) in docs.docs + if isa(docstr, Base.Docs.DocStr) && haskey(docstr.data, :path) + path = docstr.data[:path] + if path !== nothing && !isempty(path) + return abspath(String(path)) + end + end + end + end + return nothing +end + +""" + _get_source_from_methods(obj) -> Union{String, Nothing} + +Try to get source file path from method definitions. +""" +function _get_source_from_methods(obj) + for m in methods(obj) + file = String(m.file) + if file != "" && file != "none" && !startswith(file, ".") + return abspath(file) + end + end + return nothing +end + +# ═══════════════════════════════════════════════════════════════════════════════ +# Helper Functions: Symbol Iteration +# ═══════════════════════════════════════════════════════════════════════════════ + +""" + _iterate_over_symbols(f, config, symbol_list) + +Iterate over symbols, applying a function to each documented symbol. +Filters symbols based on exclusion list, documentation presence, and source files. +""" +function _iterate_over_symbols(f, config::_Config, symbol_list) + current_module = config.current_module + effective_source_files = _get_effective_source_files(config) + + for (key, type) in sort!(copy(symbol_list); by=config.sort_by) + key isa Symbol || continue + + # Check exclusion + key in config.exclude && continue + + # Check documentation + if !_has_documentation(current_module, key, type, config.modules) + continue + end + + # Check source file filtering + if !_passes_source_filter(current_module, key, type, effective_source_files, config.include_without_source) + continue + end + + f(key, type) + end + return nothing +end + +""" + _has_documentation(mod::Module, key::Symbol, type::DocType, modules::Dict) -> Bool + +Check if a symbol has documentation. Logs a warning if not. +""" +function _has_documentation(mod::Module, key::Symbol, type::DocType, modules::Dict) + binding = Base.Docs.Binding(mod, key) + + has_doc = if isdefined(Base.Docs, :hasdoc) + Base.Docs.hasdoc(binding) + else + doc = Base.Docs.doc(binding) + doc !== nothing && !occursin("No documentation found.", string(doc)) + end + + if !has_doc + if type == DOCTYPE_MODULE + submod = getfield(mod, key) + if submod != mod && haskey(modules, submod) + return true # Module is documented elsewhere + end + end + @warn "No documentation found for $key in $mod. Skipping from API reference." + return false + end + return true +end + +""" + _passes_source_filter(mod, key, type, source_files, include_without_source) -> Bool + +Check if a symbol passes the source file filter. +""" +function _passes_source_filter(mod::Module, key::Symbol, type::DocType, source_files::Vector{String}, include_without_source::Bool) + isempty(source_files) && return true + + source_path = _get_source_file(mod, key, type) + if source_path === nothing + if !include_without_source + @debug "Cannot determine source file for $key ($type), skipping." + return false + end + return true + end + + return source_path in source_files +end + +# ═══════════════════════════════════════════════════════════════════════════════ +# Helper Functions: Type Formatting for @docs Blocks +# ═══════════════════════════════════════════════════════════════════════════════ + +""" + _method_signature_string(m::Method, mod::Module, key::Symbol) -> String + +Generate a Documenter-compatible signature string for a method. +Returns a string like `Module.func(::Type1, ::Type2)` for use in `@docs` blocks. +""" +function _method_signature_string(m::Method, mod::Module, key::Symbol) + sig = m.sig + while sig isa UnionAll + sig = sig.body + end + + if !(sig <: Tuple) + return "$(mod).$(key)" + end + + params = sig.parameters + arg_types = length(params) > 1 ? params[2:end] : Any[] + + if isempty(arg_types) + return "$(mod).$(key)()" + end + + type_strs = [_format_type_for_docs(T) for T in arg_types] + return "$(mod).$(key)($(join(type_strs, ", ")))" +end + +""" + _format_type_for_docs(T) -> String + +Format a type for use in Documenter's `@docs` block. +Always fully qualifies types to avoid UndefVarError when Documenter evaluates in Main. +""" +function _format_type_for_docs(T) + # Vararg + if T isa Core.TypeofVararg + inner = _format_type_for_docs(T.T) + inner_clean = startswith(inner, "::") ? inner[3:end] : inner + return "::Vararg{$(inner_clean)}" + end + + # TypeVar + T isa TypeVar && return "::$(T.name)" + + # UnionAll - unwrap and format + T isa UnionAll && return _format_type_for_docs(Base.unwrap_unionall(T)) + + # DataType + if T isa DataType + return _format_datatype_for_docs(T) + end + + # Union + if T isa Union + union_types = Base.uniontypes(T) + formatted = [_format_type_for_docs(ut) for ut in union_types] + cleaned = [startswith(s, "::") ? s[3:end] : s for s in formatted] + return "::Union{$(join(cleaned, ", "))}" + end + + return "::$(T)" +end + +""" + _format_datatype_for_docs(T::DataType) -> String + +Format a DataType for use in @docs blocks. +""" +function _format_datatype_for_docs(T::DataType) + type_mod = parentmodule(T) + type_name = T.name.name + is_core_or_base = type_mod === Core || type_mod === Base + + # Handle parametric types + if !isempty(T.parameters) + has_typevar_params = any(p -> p isa TypeVar, T.parameters) + + if has_typevar_params + # Strip type parameters to avoid UndefVarError + return is_core_or_base ? "::$(type_name)" : "::$(type_mod).$(type_name)" + else + # Keep concrete type parameters + params = [_format_type_param(p) for p in T.parameters] + params_str = join(params, ", ") + return is_core_or_base ? "::$(type_name){$(params_str)}" : "::$(type_mod).$(type_name){$(params_str)}" + end + end + + # Simple type + return is_core_or_base ? "::$(type_name)" : "::$(type_mod).$(type_name)" +end + +""" + _format_type_param(p) -> String + +Format a type parameter (can be a type or a value like an integer). +""" +function _format_type_param(p) + if p isa Type + s = _format_type_for_docs(p) + return startswith(s, "::") ? s[3:end] : s + elseif p isa TypeVar + return string(p.name) + else + return string(p) + end +end + +# ═══════════════════════════════════════════════════════════════════════════════ +# Page Building Functions +# ═══════════════════════════════════════════════════════════════════════════════ + +""" + _build_api_page(document::Documenter.Document, config::_Config) + +Generate public and/or private API reference pages for a module. +Accumulates content in `PAGE_CONTENT_ACCUMULATOR` for later finalization. +""" +function _build_api_page(document::Documenter.Document, config::_Config) + current_module = config.current_module + symbols = _exported_symbols(current_module) + + # Determine output filenames + public_basename = config.public && config.private ? "public" : config.filename + private_basename = config.public && config.private ? "private" : config.filename + private_filename = _build_page_path(config.subdirectory, "$private_basename.md") + + # Collect docstrings + public_docstrings = config.public ? _collect_module_docstrings(config, symbols.exported) : String[] + private_docstrings = config.private ? _collect_private_docstrings(config, symbols.private) : String[] + + # Accumulate content + if !haskey(PAGE_CONTENT_ACCUMULATOR, private_filename) + PAGE_CONTENT_ACCUMULATOR[private_filename] = Tuple{Module, Vector{String}, Vector{String}}[] + end + push!(PAGE_CONTENT_ACCUMULATOR[private_filename], (current_module, public_docstrings, private_docstrings)) + + return nothing +end + +""" + _collect_module_docstrings(config::_Config, symbol_list) -> Vector{String} + +Collect docstring blocks for symbols from the current module. +""" +function _collect_module_docstrings(config::_Config, symbol_list) + docstrings = String[] + current_module = config.current_module + + _iterate_over_symbols(config, symbol_list) do key, type + type == DOCTYPE_MODULE && return nothing + push!(docstrings, "## `$key`\n\n```@docs\n$(current_module).$key\n```\n\n") + return nothing + end + + return docstrings +end + +""" + _collect_private_docstrings(config::_Config, symbol_list) -> Vector{String} + +Collect docstring blocks for private symbols, including external module methods. +""" +function _collect_private_docstrings(config::_Config, symbol_list) + docstrings = _collect_module_docstrings(config, symbol_list) + + # Add docstrings from external modules + if !isempty(config.external_modules_to_document) + external_docs = _collect_external_module_docstrings(config) + append!(docstrings, external_docs) + end + + return docstrings +end + +""" + _collect_external_module_docstrings(config::_Config) -> Vector{String} + +Collect docstrings for methods from external modules defined in source files. +""" +function _collect_external_module_docstrings(config::_Config) + docstrings = String[] + added_signatures = Set{String}() + filtered_source_files = _get_effective_source_files(config) + + for extra_mod in config.external_modules_to_document + methods_by_func = _collect_methods_from_source_files(extra_mod, filtered_source_files) + + for (key, method_list) in sort(collect(methods_by_func); by=first) + for m in method_list + sig_str = _method_signature_string(m, extra_mod, key) + sig_str in added_signatures && continue + + push!(added_signatures, sig_str) + push!(docstrings, "## `$(extra_mod).$key`\n\n```@docs\n$(sig_str)\n```\n\n") + end + end + end + + return docstrings +end + +""" + _collect_methods_from_source_files(mod::Module, source_files::Vector{String}) -> Dict{Symbol, Vector{Method}} + +Collect all methods from a module that are defined in the given source files. +""" +function _collect_methods_from_source_files(mod::Module, source_files::Vector{String}) + methods_by_func = Dict{Symbol, Vector{Method}}() + + for key in names(mod; all=true) + obj = try + getfield(mod, key) + catch + continue + end + + obj isa Function || continue + + for m in methods(obj) + file = String(m.file) + (file == "" || file == "none") && continue + + abs_file = abspath(file) + should_include = isempty(source_files) || (abs_file in source_files) + + if should_include + if !haskey(methods_by_func, key) + methods_by_func[key] = Method[] + end + push!(methods_by_func[key], m) + end + end + end + + return methods_by_func +end + +""" + _finalize_api_pages(document::Documenter.Document) + +Finalize all accumulated API pages by combining content from multiple modules. +""" +function _finalize_api_pages(document::Documenter.Document) + for (filename, module_contents) in PAGE_CONTENT_ACCUMULATOR + is_private = occursin("private", filename) || !occursin("public", filename) + + all_modules = [mc[1] for mc in module_contents] + modules_str = join([string(m) for m in all_modules], "`, `") + + overview, all_docstrings = if is_private + _build_private_page_content(modules_str, module_contents) + else + _build_public_page_content(modules_str, module_contents) + end + + combined_md = Markdown.parse(overview * join(all_docstrings, "\n")) + + document.blueprint.pages[filename] = Documenter.Page( + joinpath(document.user.source, filename), + joinpath(document.user.build, filename), + document.user.build, + combined_md.content, + Documenter.Globals(), + convert(MarkdownAST.Node, combined_md), + ) + end + + empty!(PAGE_CONTENT_ACCUMULATOR) + return nothing +end + +""" + _build_private_page_content(modules_str, module_contents) -> Tuple{String, Vector{String}} + +Build the overview and docstrings for a private API page. +""" +function _build_private_page_content(modules_str::String, module_contents) + overview = """ + ```@meta + EditURL = nothing + ``` + + # Private API + + This page lists **non-exported** (internal) symbols of `$(modules_str)`. + + """ + + all_docstrings = String[] + for (mod, _, private_docs) in module_contents + if !isempty(private_docs) + push!(all_docstrings, "\n---\n\n### From `$(mod)`\n\n") + append!(all_docstrings, private_docs) + end + end + + return overview, all_docstrings +end + +""" + _build_public_page_content(modules_str, module_contents) -> Tuple{String, Vector{String}} + +Build the overview and docstrings for a public API page. +""" +function _build_public_page_content(modules_str::String, module_contents) + overview = """ + # Public API + + This page lists **exported** symbols of `$(modules_str)`. + + """ + + all_docstrings = String[] + for (mod, public_docs, _) in module_contents + if !isempty(public_docs) + push!(all_docstrings, "\n---\n\n### From `$(mod)`\n\n") + append!(all_docstrings, public_docs) + end + end + + return overview, all_docstrings +end + +end # module diff --git a/docs/make.jl b/docs/make.jl index 5c70cffe..efc934b6 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -1,9 +1,20 @@ using Documenter using CTModels +using CTBase # For automatic_reference_documentation using Plots using JSON3 using JLD2 +using Markdown +using MarkdownAST: MarkdownAST +# ═══════════════════════════════════════════════════════════════════════════════ +# Configuration +# ═══════════════════════════════════════════════════════════════════════════════ +draft = false # Draft mode: if true, @example blocks in markdown are not executed + +# ═══════════════════════════════════════════════════════════════════════════════ +# Load extensions +# ═══════════════════════════════════════════════════════════════════════════════ const CTModelsPlots = Base.get_extension(CTModels, :CTModelsPlots) const CTModelsJSON = Base.get_extension(CTModels, :CTModelsJSON) const CTModelsJLD = Base.get_extension(CTModels, :CTModelsJLD) @@ -15,47 +26,267 @@ for Module in Modules DocMeta.setdocmeta!(Module, :DocTestSetup, :(using $Module); recursive=true) end +# ═══════════════════════════════════════════════════════════════════════════════ +# Paths +# ═══════════════════════════════════════════════════════════════════════════════ repo_url = "github.com/control-toolbox/CTModels.jl" +src_dir = abspath(joinpath(@__DIR__, "..", "src")) +ext_dir = abspath(joinpath(@__DIR__, "..", "ext")) + +# Helper to build absolute paths +src(files...) = [abspath(joinpath(src_dir, f)) for f in files] +ext(files...) = [abspath(joinpath(ext_dir, f)) for f in files] -API_PAGES = [ - "constraints.md", - "control.md", - "ctmodels.md", - "default.md", - "definition.md", - "dual_model.md", - "dynamics.md", - "init.md", - "jld.md", - "json.md", - "model.md", - "objective.md", - "plot.md", - "print.md", - "solution.md", - "state.md", - "time_dependence.md", - "times.md", - "types.md", - "utils.md", - "variable.md", +# Symbols to exclude from documentation (auto-generated by @with_kw, etc.) +const EXCLUDE_SYMBOLS = Symbol[ + :include, + :eval, + Symbol("@pack_PreModel"), + Symbol("@pack_PreModel!"), + Symbol("@unpack_PreModel"), + :is_empty, ] +# ═══════════════════════════════════════════════════════════════════════════════ +# Build documentation +# ═══════════════════════════════════════════════════════════════════════════════ makedocs(; - remotes=nothing, - warnonly=[:cross_references, :autodocs_block], + draft=draft, + remotes=nothing, # Disable remote links. Needed for DocumenterReference + warnonly=true, sitename="CTModels.jl", format=Documenter.HTML(; repolink="https://" * repo_url, prettyurls=false, - size_threshold_ignore=["api.md", "dev.md"], + #size_threshold_ignore=["api.md", "dev.md"], + #size_threshold=300_000, # 300 KiB threshold assets=[ asset("https://control-toolbox.org/assets/css/documentation.css"), asset("https://control-toolbox.org/assets/js/documentation.js"), ], ), - pages=["Introduction" => "index.md", "API" => API_PAGES], checkdocs=:none, + pages=[ + "Introduction" => "index.md", + "Interfaces" => [ + "OCP Tools" => "interfaces/ocp_tools.md", + "Optimization Problems" => "interfaces/optimization_problems.md", + "Optimization Modelers" => "interfaces/optimization_modelers.md", + "Solution Builders" => "interfaces/ocp_solution_builders.md", + ], + "API Reference" => [ + # ─────────────────────────────────────────────────────────────────── + # Main module + # ─────────────────────────────────────────────────────────────────── + CTBase.automatic_reference_documentation(; + subdirectory=".", + primary_modules=[CTModels => src("CTModels.jl")], + exclude=EXCLUDE_SYMBOLS, + public=false, + private=true, + title="CTModels", + title_in_menu="CTModels", + filename="ctmodels", + ), + # ─────────────────────────────────────────────────────────────────── + # Core: Types + # ─────────────────────────────────────────────────────────────────── + CTBase.automatic_reference_documentation(; + subdirectory=".", + primary_modules=[CTModels => src( + "core/types.jl", + "core/types/ocp_model.jl", + "core/types/ocp_components.jl", + "core/types/ocp_solution.jl", + "core/types/initial_guess.jl", + "core/types/nlp.jl", + )], + exclude=EXCLUDE_SYMBOLS, + public=false, + private=true, + title="Types", + title_in_menu="Types", + filename="types", + ), + # ─────────────────────────────────────────────────────────────────── + # Core: Default & Utils + # ─────────────────────────────────────────────────────────────────── + CTBase.automatic_reference_documentation(; + subdirectory=".", + primary_modules=[CTModels => src("core/default.jl", "core/utils.jl")], + exclude=EXCLUDE_SYMBOLS, + public=false, + private=true, + title="Default & Utils", + title_in_menu="Default & Utils", + filename="default_utils", + ), + # ─────────────────────────────────────────────────────────────────── + # OCP: Model (model, definition, time_dependence) + # ─────────────────────────────────────────────────────────────────── + CTBase.automatic_reference_documentation(; + subdirectory=".", + primary_modules=[CTModels => src( + "ocp/model.jl", + "ocp/definition.jl", + "ocp/time_dependence.jl", + )], + exclude=EXCLUDE_SYMBOLS, + public=false, + private=true, + title="Model", + title_in_menu="Model", + filename="model", + ), + # ─────────────────────────────────────────────────────────────────── + # OCP: Times + # ─────────────────────────────────────────────────────────────────── + CTBase.automatic_reference_documentation(; + subdirectory=".", + primary_modules=[CTModels => src("ocp/times.jl")], + exclude=EXCLUDE_SYMBOLS, + public=false, + private=true, + title="Times", + title_in_menu="Times", + filename="times", + ), + # ─────────────────────────────────────────────────────────────────── + # OCP: State, Control, Variable + # ─────────────────────────────────────────────────────────────────── + CTBase.automatic_reference_documentation(; + subdirectory=".", + primary_modules=[CTModels => src("ocp/state.jl", "ocp/control.jl", "ocp/variable.jl")], + exclude=EXCLUDE_SYMBOLS, + public=false, + private=true, + title="State, Control & Variable", + title_in_menu="State, Control & Variable", + filename="state_control_variable", + ), + # ─────────────────────────────────────────────────────────────────── + # OCP: Dynamics & Objective + # ─────────────────────────────────────────────────────────────────── + CTBase.automatic_reference_documentation(; + subdirectory=".", + primary_modules=[CTModels => src("ocp/dynamics.jl", "ocp/objective.jl")], + exclude=EXCLUDE_SYMBOLS, + public=false, + private=true, + title="Dynamics & Objective", + title_in_menu="Dynamics & Objective", + filename="dynamics_objective", + ), + # ─────────────────────────────────────────────────────────────────── + # OCP: Constraints + # ─────────────────────────────────────────────────────────────────── + CTBase.automatic_reference_documentation(; + subdirectory=".", + primary_modules=[CTModels => src("ocp/constraints.jl")], + exclude=EXCLUDE_SYMBOLS, + public=false, + private=true, + title="Constraints", + title_in_menu="Constraints", + filename="constraints", + ), + # ─────────────────────────────────────────────────────────────────── + # OCP: Solution & Dual + # ─────────────────────────────────────────────────────────────────── + CTBase.automatic_reference_documentation(; + subdirectory=".", + primary_modules=[CTModels => src("ocp/solution.jl", "ocp/dual_model.jl")], + exclude=EXCLUDE_SYMBOLS, + public=false, + private=true, + title="Solution & Dual", + title_in_menu="Solution & Dual", + filename="solution_dual", + ), + # ─────────────────────────────────────────────────────────────────── + # OCP: Print + # ─────────────────────────────────────────────────────────────────── + CTBase.automatic_reference_documentation(; + subdirectory=".", + primary_modules=[CTModels => src("ocp/print.jl")], + exclude=EXCLUDE_SYMBOLS, + public=false, + private=true, + title="Print", + title_in_menu="Print", + filename="print", + ), + # ─────────────────────────────────────────────────────────────────── + # Initial Guess + # ─────────────────────────────────────────────────────────────────── + CTBase.automatic_reference_documentation(; + subdirectory=".", + primary_modules=[CTModels => src("init/initial_guess.jl")], + exclude=EXCLUDE_SYMBOLS, + public=false, + private=true, + title="Initial Guess", + title_in_menu="Initial Guess", + filename="initial_guess", + ), + # ─────────────────────────────────────────────────────────────────── + # NLP Backends + # ─────────────────────────────────────────────────────────────────── + CTBase.automatic_reference_documentation(; + subdirectory=".", + primary_modules=[CTModels => src( + "nlp/nlp_backends.jl", + "nlp/options_schema.jl", + "nlp/problem_core.jl", + "nlp/discretized_ocp.jl", + "nlp/model_api.jl", + )], + exclude=EXCLUDE_SYMBOLS, + public=false, + private=true, + title="NLP Backends", + title_in_menu="NLP Backends", + filename="nlp", + ), + # ─────────────────────────────────────────────────────────────────── + # Extension: Plot + # ─────────────────────────────────────────────────────────────────── + CTBase.automatic_reference_documentation(; + subdirectory=".", + primary_modules=[CTModelsPlots => ext( + "CTModelsPlots.jl", + "plot.jl", + "plot_default.jl", + "plot_utils.jl", + )], + external_modules_to_document=[Plots], + exclude=EXCLUDE_SYMBOLS, + public=false, + private=true, + title="Plot Extension", + title_in_menu="Plot", + filename="plot", + ), + # ─────────────────────────────────────────────────────────────────── + # Extension: JLD & JSON (combined) + # ─────────────────────────────────────────────────────────────────── + CTBase.automatic_reference_documentation(; + subdirectory=".", + primary_modules=[ + CTModelsJSON => ext("CTModelsJSON.jl"), + CTModelsJLD => ext("CTModelsJLD.jl"), + ], + external_modules_to_document=[CTModels], + exclude=EXCLUDE_SYMBOLS, + public=false, + private=true, + title="JLD & JSON Extension", + title_in_menu="JLD & JSON", + filename="import_export", + ), + ], + ], ) +# ═══════════════════════════════════════════════════════════════════════════════ deploydocs(; repo=repo_url * ".git", devbranch="main") diff --git a/docs/src/constraints.md b/docs/save/constraints.md similarity index 100% rename from docs/src/constraints.md rename to docs/save/constraints.md diff --git a/docs/src/control.md b/docs/save/control.md similarity index 100% rename from docs/src/control.md rename to docs/save/control.md diff --git a/docs/src/ctmodels.md b/docs/save/ctmodels.md similarity index 100% rename from docs/src/ctmodels.md rename to docs/save/ctmodels.md diff --git a/docs/src/default.md b/docs/save/default.md similarity index 100% rename from docs/src/default.md rename to docs/save/default.md diff --git a/docs/src/definition.md b/docs/save/definition.md similarity index 100% rename from docs/src/definition.md rename to docs/save/definition.md diff --git a/docs/src/dual_model.md b/docs/save/dual_model.md similarity index 100% rename from docs/src/dual_model.md rename to docs/save/dual_model.md diff --git a/docs/src/dynamics.md b/docs/save/dynamics.md similarity index 100% rename from docs/src/dynamics.md rename to docs/save/dynamics.md diff --git a/docs/src/init.md b/docs/save/init.md similarity index 100% rename from docs/src/init.md rename to docs/save/init.md diff --git a/docs/src/jld.md b/docs/save/jld.md similarity index 100% rename from docs/src/jld.md rename to docs/save/jld.md diff --git a/docs/src/json.md b/docs/save/json.md similarity index 100% rename from docs/src/json.md rename to docs/save/json.md diff --git a/docs/src/model.md b/docs/save/model.md similarity index 100% rename from docs/src/model.md rename to docs/save/model.md diff --git a/docs/src/objective.md b/docs/save/objective.md similarity index 100% rename from docs/src/objective.md rename to docs/save/objective.md diff --git a/docs/src/plot.md b/docs/save/plot.md similarity index 100% rename from docs/src/plot.md rename to docs/save/plot.md diff --git a/docs/src/print.md b/docs/save/print.md similarity index 100% rename from docs/src/print.md rename to docs/save/print.md diff --git a/docs/src/solution.md b/docs/save/solution.md similarity index 100% rename from docs/src/solution.md rename to docs/save/solution.md diff --git a/docs/src/state.md b/docs/save/state.md similarity index 100% rename from docs/src/state.md rename to docs/save/state.md diff --git a/docs/src/time_dependence.md b/docs/save/time_dependence.md similarity index 100% rename from docs/src/time_dependence.md rename to docs/save/time_dependence.md diff --git a/docs/src/times.md b/docs/save/times.md similarity index 100% rename from docs/src/times.md rename to docs/save/times.md diff --git a/docs/src/types.md b/docs/save/types.md similarity index 100% rename from docs/src/types.md rename to docs/save/types.md diff --git a/docs/src/utils.md b/docs/save/utils.md similarity index 100% rename from docs/src/utils.md rename to docs/save/utils.md diff --git a/docs/src/variable.md b/docs/save/variable.md similarity index 100% rename from docs/src/variable.md rename to docs/save/variable.md diff --git a/docs/src/index.md b/docs/src/index.md index 3a042fa3..e5f2eb4d 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -1,12 +1,218 @@ # CTModels.jl +```@meta +CurrentModule = CTModels +``` + The `CTModels.jl` package is part of the [control-toolbox ecosystem](https://github.com/control-toolbox). +It provides the **mathematical model layer** for optimal control problems: -The root package is [OptimalControl.jl](https://github.com/control-toolbox/OptimalControl.jl) which aims to provide tools to model and solve optimal control problems with ordinary differential equations by direct and indirect methods, both on CPU and GPU. +- **types and building blocks** for states, controls, variables, time grids, and constraints; +- an `AbstractModel`/`Model` and `AbstractSolution`/`Solution` hierarchy for optimal control problems; +- tools to build **initial guesses**, connect to **NLP backends**, and interpret their solutions; +- optional extensions for **exporting solutions** (JSON/JLD) and **plotting**. -**API Documentation** +!!! note -```@contents -Pages = Main.API_PAGES -Depth = 1 -``` + The root package is [OptimalControl.jl](https://github.com/control-toolbox/OptimalControl.jl) which aims + to provide tools to model and solve optimal control problems with ordinary differential equations + by direct and indirect methods, both on CPU and GPU. + +!!! warning + + In some examples in the documentation, private methods are shown without the module prefix. + This is done for the sake of clarity and readability. + + ```julia-repl + julia> using CTModels + julia> x = 1 + julia> private_fun(x) # throws an error + ``` + + This should instead be written as: + + ```julia-repl + julia> using CTModels + julia> x = 1 + julia> CTModels.private_fun(x) + ``` + + If the method is re-exported by another package, + + ```julia + module OptimalControl + import CTModels: private_fun + export private_fun + end + ``` + + then there is no need to prefix it with the original module name: + + ```julia-repl + julia> using OptimalControl + julia> x = 1 + julia> private_fun(x) + ``` + +## What CTModels provides + +At a high level, CTModels is responsible for: + +- **Defining optimal control problems**: + `AbstractModel` / `Model` store dynamics, objective, constraints, time structure, and metadata. +- **Representing numerical solutions**: + `AbstractSolution` / `Solution` store state, control, dual variables, and solver information. +- **Managing time grids and dimensions** through convenient type aliases. +- **Structuring constraints** (path, boundary, box constraints on state, control, and variables). +- **Connecting to NLP backends** (ADNLPModels, ExaModels, etc.) via modelers and builders. +- **Providing utilities** for initial guesses, export/import, and plotting of solutions. + +Most of the public API is organized in a way that closely mirrors the mathematical +objects you manipulate when formulating an optimal control problem. + +## Time grids and basic aliases + +CTModels defines a few central type aliases that appear throughout the API: + +- `Dimension`: integer dimensions used for state, control, and variables. +- `ctNumber` and `ctVector`: real numbers and vectors of reals. +- `Time`, `Times`, `TimesDisc`: continuous time, time vectors, and discrete time grids. + +These aliases make type signatures more readable while remaining flexible enough +to accept a variety of numeric types. + +## Models, solutions, and constraints + +The core **optimal control model** is expressed via: + +- `AbstractModel` / `Model`: store the structure of the OCP + (dynamics, objective, constraints, time dependence, etc.). +- `ConstraintsModel`: a structured representation of all constraints + (path constraints, boundary constraints, and box constraints on state, control, and variables). + +In practice you typically: + +1. Specify **time dependence** and **time models** (fixed or free final time, etc.). +2. Describe **state, control, and variable spaces**. +3. Provide **dynamics** and **objective** functions. +4. Add **constraints**, either programmatically or via a `ConstraintsDictType` dictionary. + +The numerical **solution** of an OCP is represented by: + +- `AbstractSolution` / `Solution`: contain time grids, state and control trajectories, + path and boundary dual variables, solver status, and diagnostics. +- `DualModel` and related types: organize dual variables associated with constraints. + +These objects are the main bridge between the mathematical problem and the NLP backends. + +## Initial guesses + +Good initial guesses are crucial for challenging optimal control problems. +CTModels provides a small layer to organize them: + +- `pre_initial_guess` builds an `OptimalControlPreInit` object from raw user data + (functions, vectors, or constants for state, control, and variables). +- `initial_guess` turns this into an `OptimalControlInitialGuess`, checking consistency + with the chosen `AbstractOptimalControlProblem`. + +The corresponding API is implemented in `src/init/initial_guess.jl` and is documented +in the *Initial Guess* section of the API reference. + +## NLP backends and modelers + +CTModels does **not** solve the NLP itself. Instead, it connects to external NLP +backends via modelers and builders defined in `src/nlp/`: + +- `ADNLPModeler` (based on `ADNLPModels.jl`), +- `ExaModeler` (based on `ExaModels.jl`), +- additional builder types and helper functions. + +These modelers: + +- expose options through the generic `AbstractOCPTool` interface from CTBase + (see the *Interfaces → OCP Tools* page), +- build backend-specific NLP models from an `AbstractOptimizationProblem`, +- optionally map NLP solutions back to `CTModels.Solution` objects. + +The *Interfaces* section of the documentation contains detailed guides for: + +- implementing new **optimization problems**, +- implementing new **optimization modelers**, and +- implementing new **OCP solution builders**. + +## Extensions: JSON, JLD, and plotting + +Several optional extensions live in the `ext/` directory and are loaded on demand +by the corresponding packages: + +- **CTModelsJSON.jl** (requires `JSON3.jl`): + helpers to serialize/deserialize the `infos::Dict{Symbol,Any}` carried by solutions, + and methods for + `export_ocp_solution(CTModels.JSON3Tag(), ::Solution)` / + `import_ocp_solution(CTModels.JSON3Tag(), ::Model)`. + +- **CTModelsJLD.jl** (requires `JLD2.jl`): + methods to export and import a `Solution` as a `.jld2` file using + `export_ocp_solution(CTModels.JLD2Tag(), ::Solution)` and + `import_ocp_solution(CTModels.JLD2Tag(), ::Model)`. + +- **CTModelsPlots.jl** (requires `Plots.jl`): + plot recipes and helpers that make + `Plots.plot(sol::CTModels.Solution, ...)` + and + `Plots.plot!(sol::CTModels.Solution, ...)` + display the trajectories of state, control, costate, constraints, and dual + variables in a consistent, configurable way. + +If the corresponding extension package is not loaded, the public wrappers +`export_ocp_solution`, `import_ocp_solution`, and the generic `RecipesBase.plot` +throw a descriptive `CTBase.ExtensionError`. + +## How this documentation is organized + +The documentation is split into two main parts: + +- **Interfaces** + - *OCP Tools*: how to implement new configurable tools (backends, discretizers, solvers). + - *Optimization Problems*: how to define `AbstractOptimizationProblem` types. + - *Optimization Modelers*: how to map optimization problems to specific NLP backends. + - *Solution Builders*: how to turn NLP execution statistics into `CTModels.Solution` objects. + +- **API Reference** + - *Types*: core types for models, solutions, and internal structures. + - *Model / Times / Dynamics / Objective / Constraints*: detailed API for building OCP models. + - *Solution & Dual*: how solutions and dual variables are represented. + - *Initial Guess*: utilities to build and validate initial guesses. + - *NLP Backends*: ADNLPModels/ExaModels-based backends and related options. + - *Extensions*: Plot, JSON, and JLD extensions. + +You can start by reading the **Interfaces** pages to understand the high-level +design, then use the **API Reference** to look up the details of particular +functions and types. + +## I am X, I want to do Y → read… + +- **I use OptimalControl.jl and I just want to understand what CTModels does in the background** + Read this introduction page, then skim through the **Interfaces** section to see how + problems, modelers, and builders fit together. + +- **I want to formulate a new optimal control / optimization problem** + Read **Interfaces → Optimization Problems**, then **API Reference → Model / Times / Dynamics / Objective / Constraints** + for details about fields and conventions. + +- **I want to connect a new NLP backend or tweak an existing backend** + Read **Interfaces → Optimization Modelers** and the **API Reference → NLP Backends** section. + +- **I want to build good initial guesses for my problems** + Read **Interfaces → Solution Builders** for the overall philosophy, then **API Reference → Initial Guess** + for the `pre_initial_guess` and `initial_guess` functions. + +- **I want to save / reload solutions (for example for numerical experiments)** + Read **API Reference → Extensions (JSON & JLD)** and the pages associated with the `CTModelsJSON` and `CTModelsJLD` modules. + +- **I want to plot solution trajectories nicely** + Read **API Reference → Extensions (Plot Extension)**, and look at the examples using `Plots.plot(sol)` and `Plots.plot!(sol)`. + +- **I want to contribute to the core of CTModels (types, constraints, dual variables, etc.)** + Start with **API Reference → Types**, then **Solution & Dual** and **Constraints** to understand the internal structures + before modifying or adding new fields. diff --git a/docs/src/interfaces/ocp_solution_builders.md b/docs/src/interfaces/ocp_solution_builders.md new file mode 100644 index 00000000..9c918ae3 --- /dev/null +++ b/docs/src/interfaces/ocp_solution_builders.md @@ -0,0 +1,144 @@ +# Implementing OCP solution builders + +This page explains how to implement builders that turn NLP back-end +execution statistics into objects associated with discretized optimal +control problems. + +These builders implement the +[`AbstractOCPSolutionBuilder`](@ref CTModels.AbstractOCPSolutionBuilder) +interface, which refines the more general +[`AbstractSolutionBuilder`](@ref CTModels.AbstractSolutionBuilder). + +## Overview of the contract + +A concrete OCP solution builder type `B` is expected to: + +- subtype `AbstractOCPSolutionBuilder`: + + ```julia + struct MySolutionBuilder{F} <: CTModels.AbstractOCPSolutionBuilder + f::F # function or callable used internally + end + ``` + +- be callable on an NLP back-end solution, represented as + `SolverCore.AbstractExecutionStats`: + + ```julia + (builder::MySolutionBuilder)( + nlp_solution::SolverCore.AbstractExecutionStats; + kwargs..., + ) = ... + ``` + +A generic fallback for this call is defined on +`AbstractOCPSolutionBuilder` and throws `CTBase.NotImplemented` if it is not +specialized. + +## Relationship with optimization problems + +OCP solution builders are typically stored inside +[`OCPBackendBuilders`](@ref CTModels.OCPBackendBuilders), which itself is +used by [`DiscretizedOptimalControlProblem`](@ref +CTModels.DiscretizedOptimalControlProblem). Each back-end (e.g. ADNLPModels, +ExaModels) has a pair of builders: + +- a model builder `TM <: AbstractModelBuilder`; +- a solution builder `TS <: AbstractOCPSolutionBuilder`. + +The optimization problem exposes these builders via the `get_*_builder` +interface: + +- [`get_adnlp_solution_builder`](@ref CTModels.get_adnlp_solution_builder) +- [`get_exa_solution_builder`](@ref CTModels.get_exa_solution_builder) + +Modelers (see the `optimization_modelers.md` page) retrieve the appropriate +solution builder and apply it to the NLP back-end solution when they want to +produce an OCP-related representation. + +## Example: ADNLPSolutionBuilder and ExaSolutionBuilder + +CTModels defines two concrete OCP solution builders in `core/types/nlp.jl`: + +```julia +struct ADNLPSolutionBuilder{T<:Function} <: CTModels.AbstractOCPSolutionBuilder + f::T +end + +struct ExaSolutionBuilder{T<:Function} <: CTModels.AbstractOCPSolutionBuilder + f::T +end +``` + +The corresponding call methods are implemented in `nlp/discretized_ocp.jl`: + +```julia +function (builder::CTModels.ADNLPSolutionBuilder)( + nlp_solution::SolverCore.AbstractExecutionStats, +) + return builder.f(nlp_solution) +end + +function (builder::CTModels.ExaSolutionBuilder)( + nlp_solution::SolverCore.AbstractExecutionStats, +) + return builder.f(nlp_solution) +end +``` + +This pattern allows the internal implementation (carried by `f`) to vary +while the external interface remains stable. + +## Example: minimal builders in tests + +The test helper in `test/problems/problems_definition.jl` shows a minimal +implementation where the solution builders simply return the NLP solution +unchanged: + +```julia +abstract type AbstractNLPSolutionBuilder <: CTModels.AbstractSolutionBuilder end + +struct ADNLPSolutionBuilder <: AbstractNLPSolutionBuilder end +struct ExaSolutionBuilder <: AbstractNLPSolutionBuilder end + +function (builder::ADNLPSolutionBuilder)( + nlp_solution::SolverCore.AbstractExecutionStats, +) + return nlp_solution +end + +function (builder::ExaSolutionBuilder)( + nlp_solution::SolverCore.AbstractExecutionStats, +) + return nlp_solution +end +``` + +This illustrates that the only strict requirement at the interface level is +being callable on `AbstractExecutionStats`. The actual transformation (if +any) is left to the concrete implementation. + +## Designing your own OCP solution builder + +When designing a new solution builder, consider: + +- **Input**: a back-end solution object, typically + `SolverCore.AbstractExecutionStats` from the NLP solver. +- **Output**: an OCP-related representation (e.g. an + `AbstractSolution`, a struct containing trajectories, or an intermediate + diagnostic object). +- **Configuration**: solution builders do not usually follow the + `AbstractOCPTool` options interface, but they may still store internal + functions and parameters as fields. + +A typical pattern is to: + +1. define a struct that stores whatever is needed to interpret the NLP + solution; +2. implement the call method described above; +3. plug the builder into your + `AbstractOptimizationProblem` implementation via the + `get_*_solution_builder` interface. + +See also the documentation pages on optimization problems and modelers for +how these components fit together. diff --git a/docs/src/interfaces/ocp_tools.md b/docs/src/interfaces/ocp_tools.md new file mode 100644 index 00000000..c3307a86 --- /dev/null +++ b/docs/src/interfaces/ocp_tools.md @@ -0,0 +1,176 @@ +# Implementing new OCP tools + +This page explains how to implement new *tools* in CTModels that follow the +`AbstractOCPTool` interface. Tools are configurable components such as +backends, modelers, discretizers, or solvers that expose a common options +API. + +The interface is defined by the abstract type +[`AbstractOCPTool`](@ref CTModels.AbstractOCPTool) and the helper functions in +`nlp/options_schema.jl`. + +## Overview + +All concrete tools `T <: AbstractOCPTool` are expected to: + +- store their configuration in two fields + - `options_values::NamedTuple` — effective option values. + - `options_sources::NamedTuple` — provenance for each option (`:ct_default` + or `:user`). +- optionally describe their options via + [`_option_specs(::Type{T})`](@ref CTModels._option_specs), returning a + `NamedTuple` of [`OptionSpec`](@ref CTModels.OptionSpec) values. +- provide a keyword-only constructor `T(; kwargs...)` that uses + [`_build_ocp_tool_options`](@ref CTModels._build_ocp_tool_options) to + validate and merge user-supplied keyword arguments with tool defaults. + +High-level helpers such as +[`get_option_value`](@ref CTModels.get_option_value), +[`get_option_source`](@ref CTModels.get_option_source), +[`get_option_default`](@ref CTModels.get_option_default) and +[`show_options`](@ref CTModels.show_options) then work uniformly on any +`AbstractOCPTool` subtype. + +## Defining a new tool type + +1. **Choose an abstract specialization** + + Depending on the role of your tool, you will typically subtype one of the + following interfaces, all of which inherit from + [`AbstractOCPTool`](@ref CTModels.AbstractOCPTool): + + - [`AbstractOptimizationModeler`](@ref CTModels.AbstractOptimizationModeler) + for OCP→NLP modelers (e.g. `ADNLPModeler`, `ExaModeler`). + - `AbstractOptimizationSolver` (from CTSolvers) for NLP solvers + (e.g. `IpoptSolver`). + - `AbstractOptimalControlDiscretizer` (from CTSolvers) for OCP discretizers + (e.g. `Collocation`). + +2. **Define the concrete struct** + + A minimal tool definition looks like: + + ```julia + struct MyTool{Vals,Srcs} <: AbstractOptimizationModeler + options_values::Vals + options_sources::Srcs + end + ``` + + The field names `options_values` and `options_sources` are required by the + generic helpers [`_options_values`](@ref CTModels._options_values) and + [`_option_sources`](@ref CTModels._option_sources). + +## Describing options with `OptionSpec` + +To expose metadata for your tool's options, specialize +[`_option_specs(::Type{T})`](@ref CTModels._option_specs) on your concrete +type. The function should return a `NamedTuple` whose fields are option names +and whose values are [`OptionSpec`](@ref CTModels.OptionSpec) instances. + +```julia +function CTModels._option_specs(::Type{<:MyTool}) + return ( + tol = CTModels.OptionSpec(; + type = Real, + default = 1e-6, + description = "Optimality tolerance.", + ), + max_iter = CTModels.OptionSpec(; + type = Integer, + default = 1000, + description = "Maximum number of iterations.", + ), + ) +end +``` + +If `_option_specs` returns `missing` for a tool type, then functions like +[`options_keys`](@ref CTModels.options_keys) and +[`default_options`](@ref CTModels.default_options) will report that no +metadata is available. + +## Implementing the constructor with `_build_ocp_tool_options` + +The recommended pattern for constructing tools is to delegate keyword +processing to [`_build_ocp_tool_options`](@ref CTModels._build_ocp_tool_options): + +```julia +function MyTool(; kwargs...) + values, sources = CTModels._build_ocp_tool_options( + MyTool; kwargs..., strict_keys = true, + ) + return MyTool{typeof(values),typeof(sources)}(values, sources) +end +``` + +This helper: + +- normalizes `kwargs` to a `NamedTuple`; +- validates keys and types against `_option_specs` (when available); +- merges defaults from [`default_options`](@ref CTModels.default_options) with + user overrides (user wins); +- builds the parallel `options_sources` NamedTuple, marking each entry as + `:ct_default` or `:user`. + +Once defined, your tool automatically works with +[`get_option_value`](@ref CTModels.get_option_value), +[`get_option_source`](@ref CTModels.get_option_source), +[`get_option_default`](@ref CTModels.get_option_default) and +[`show_options`](@ref CTModels.show_options). + +## Registering tools and assigning symbols + +For some categories of tools, CTModels or CTSolvers maintain registries that +map symbolic identifiers to concrete types. For example, modelers are +registered in `REGISTERED_MODELERS` in `nlp_backends.jl`, and solvers and +\discretizers are registered similarly in CTSolvers. + +To integrate a new tool into such a registry, you typically: + +1. Specialize [`get_symbol`](@ref CTModels.get_symbol) on the tool type: + + ```julia + CTModels.get_symbol(::Type{<:MyTool}) = :mytool + ``` + +2. Optionally specialize [`tool_package_name`](@ref CTModels.tool_package_name) + to indicate which external package provides the implementation: + + ```julia + CTModels.tool_package_name(::Type{<:MyTool}) = "MyBackendPackage" + ``` + +3. Add the tool type to the appropriate `REGISTERED_*` constant and use the + helper that builds a tool from a symbol (e.g. + `build_modeler_from_symbol(:mytool; kwargs...)`). + +## Examples + +### ADNLPModeler (CTModels) + +`ADNLPModeler` is a concrete +[`AbstractOptimizationModeler`](@ref CTModels.AbstractOptimizationModeler) +that wraps `ADNLPModels.jl`: + +- it subtypes `AbstractOptimizationModeler <: AbstractOCPTool`; +- it defines `options_values` and `options_sources` fields; +- it specializes `_option_specs(::Type{<:ADNLPModeler})` to describe its + options (`show_time`, `backend`, etc.); +- it has a keyword-only constructor implemented via + `_build_ocp_tool_options(ADNLPModeler; kwargs...)`. + +### Collocation (CTSolvers) + +In CTSolvers, `Collocation` is a concrete discretizer implementing +`AbstractOptimalControlDiscretizer <: AbstractOCPTool`: + +- it stores `options_values` and `options_sources`; +- it defines `_option_specs(::Type{<:Collocation})` with options such as + `grid`, `lagrange_to_mayer` and `scheme`; +- its constructor + `Collocation(; kwargs...) = Collocation{typeof(values.scheme)}(values, sources)` + is built on `_build_ocp_tool_options(Collocation; ...)`. + +These examples can be used as templates when adding new tools that follow the +`AbstractOCPTool` interface. diff --git a/docs/src/interfaces/optimization_modelers.md b/docs/src/interfaces/optimization_modelers.md new file mode 100644 index 00000000..a70ada59 --- /dev/null +++ b/docs/src/interfaces/optimization_modelers.md @@ -0,0 +1,152 @@ +# Implementing optimization modelers + +This page explains how to implement new optimization modelers in CTModels, +that is, components that take an +[`AbstractOptimizationProblem`](@ref CTModels.AbstractOptimizationProblem) and +produce an NLP back-end model (and optionally map NLP solutions back to +OCP-related objects). + +Modelers implement the +[`AbstractOptimizationModeler`](@ref CTModels.AbstractOptimizationModeler) +interface and are also +[`AbstractOCPTool`](@ref CTModels.AbstractOCPTool)s. This means they follow +both the options interface (see [OCP Tools](ocp_tools.md)) and a calling +interface specific to optimization problems. + +## Overview of the contract + +A concrete modeler type `M` is expected to: + +- subtype `AbstractOptimizationModeler`: + + ```julia + struct MyModeler{Vals,Srcs} <: CTModels.AbstractOptimizationModeler + options_values::Vals + options_sources::Srcs + end + ``` + +- follow the `AbstractOCPTool` options contract (fields, `_option_specs`, + constructor via `_build_ocp_tool_options`); +- implement at least the model-building call: + + ```julia + (modeler::MyModeler)(prob::CTModels.AbstractOptimizationProblem, + initial_guess; kwargs...) = ... + ``` + + which produces the NLP model for the chosen back-end. + +Optionally, the modeler can also implement a second call that maps a back-end +solution back to an OCP-related representation: + +```julia +(modeler::MyModeler)(prob::CTModels.AbstractOptimizationProblem, + nlp_solution::SolverCore.AbstractExecutionStats; + kwargs...) = ... +``` + +Generic fallbacks for both calls are defined on +`AbstractOptimizationModeler` and throw `CTBase.NotImplemented` if they are +not specialized. + +## Implementing the options interface + +Because `AbstractOptimizationModeler <: AbstractOCPTool`, modelers follow the +same options pattern as other tools. See +[OCP Tools](ocp_tools.md) for a detailed discussion. + +In short, a typical modeler definition looks like: + +```julia +struct MyModeler{Vals,Srcs} <: CTModels.AbstractOptimizationModeler + options_values::Vals + options_sources::Srcs +end + +function CTModels._option_specs(::Type{<:MyModeler}) + return ( + show_time = CTModels.OptionSpec(; + type = Bool, + default = false, + description = "Whether to print timing information while building the model.", + ), + # additional options... + ) +end + +function MyModeler(; kwargs...) + values, sources = CTModels._build_ocp_tool_options( + MyModeler; kwargs..., strict_keys = true, + ) + return MyModeler{typeof(values),typeof(sources)}(values, sources) +end +``` + +## Implementing the model-building call + +The functional part of the interface is provided by the call overloads on the +modeler. A minimal pattern, inspired by +[`ADNLPModeler`](@ref CTModels.ADNLPModeler), is: + +```julia +function (modeler::MyModeler)( + prob::CTModels.AbstractOptimizationProblem, + initial_guess; + kwargs..., +) + # Use the generic interface on `AbstractOptimizationProblem` to obtain + # the appropriate builder for this back-end. + builder = CTModels.get_adnlp_model_builder(prob) # or a similar function + + # Merge modeler options with any additional keyword arguments + vals = CTModels._options_values(modeler) + return builder(initial_guess; vals..., kwargs...) +end +``` + +Concrete modelers in CTModels follow this pattern: + +- `ADNLPModeler` dispatches on `get_adnlp_model_builder(prob)` and returns an + `ADNLPModels.ADNLPModel`. +- `ExaModeler` dispatches on `get_exa_model_builder(prob)` and returns an + `ExaModels.ExaModel{BaseType}`. + +## Mapping NLP solutions back to OCP solutions + +Modelers may also provide a second call that converts a back-end NLP solution +into an OCP-related representation, using the solution builders provided by +`AbstractOptimizationProblem`: + +```julia +function (modeler::MyModeler)( + prob::CTModels.AbstractOptimizationProblem, + nlp_solution::SolverCore.AbstractExecutionStats; + kwargs..., +) + builder = CTModels.get_adnlp_solution_builder(prob) + return builder(nlp_solution) +end +``` + +The generic fallback on `AbstractOptimizationModeler` throws +`CTBase.NotImplemented`, so if your modeler does not implement this mapping, +any attempt to call it will result in a clear error. + +## Registration and symbols + +Modelers are often registered in a back-end registry so that they can be +constructed from a symbolic identifier. CTModels, for instance, defines: + +- `REGISTERED_MODELERS` in `nlp_backends.jl`; +- helpers such as `build_modeler_from_symbol(:adnlp; kwargs...)`. + +To integrate a new modeler into such a registry, you typically: + +1. Specialize [`get_symbol`](@ref CTModels.get_symbol) on the modeler type. +2. Optionally specialize + [`tool_package_name`](@ref CTModels.tool_package_name). +3. Add the modeler type to the appropriate `REGISTERED_*` constant. + +See also the [OCP Tools](ocp_tools.md) page for the generic `AbstractOCPTool` interface +and examples such as `ADNLPModeler` and `ExaModeler`. diff --git a/docs/src/interfaces/optimization_problems.md b/docs/src/interfaces/optimization_problems.md new file mode 100644 index 00000000..29e463b3 --- /dev/null +++ b/docs/src/interfaces/optimization_problems.md @@ -0,0 +1,125 @@ +# Implementing new optimization problems + +This page explains how to implement new optimization problem types in +CTModels that follow the +[`AbstractOptimizationProblem`](@ref CTModels.AbstractOptimizationProblem) +interface. + +Optimization problems form the bridge between high-level optimal control +models and low-level NLP back-ends. They expose back-end specific builders +for models and solutions. + +The core of the interface is provided by: + +- the abstract type + [`AbstractOptimizationProblem`](@ref CTModels.AbstractOptimizationProblem); +- a set of generic methods defined in `nlp/problem_core.jl` that dispatch on + `AbstractOptimizationProblem`: + - [`get_adnlp_model_builder`](@ref CTModels.get_adnlp_model_builder) + - [`get_exa_model_builder`](@ref CTModels.get_exa_model_builder) + - [`get_adnlp_solution_builder`](@ref CTModels.get_adnlp_solution_builder) + - [`get_exa_solution_builder`](@ref CTModels.get_exa_solution_builder) + +Each generic function has a default implementation that throws +`CTBase.NotImplemented`. Concrete problem types are expected to specialize +these functions for the back-ends they want to support. + +## Overview of the contract + +A concrete optimization problem type `P` is expected to: + +- subtype `AbstractOptimizationProblem`: + + ```julia + struct MyProblem <: CTModels.AbstractOptimizationProblem + # fields describing the OCP, discretization, etc. + end + ``` + +- store whatever information is needed to (re)build an NLP back-end model + and interpret its solution; +- implement one or more of the `get_*_builder` functions listed above. + +You only need to implement the methods for the back-ends that your problem +supports. For unsupported back-ends, the default `CTBase.NotImplemented` +methods will raise a clear error if they are called. + +## Example: providing builders explicitly + +A simple example (similar to the test helper in +`test/problems/problems_definition.jl`) is to store the builders as fields of +the problem type and just return them from the interface methods: + +```julia +struct OptimizationProblem <: CTModels.AbstractOptimizationProblem + build_adnlp_model::CTModels.ADNLPModelBuilder + build_exa_model::CTModels.ExaModelBuilder + adnlp_solution_builder::CTModels.ADNLPSolutionBuilder + exa_solution_builder::CTModels.ExaSolutionBuilder +end + +function CTModels.get_adnlp_model_builder(prob::OptimizationProblem) + return prob.build_adnlp_model +end + +function CTModels.get_exa_model_builder(prob::OptimizationProblem) + return prob.build_exa_model +end + +function CTModels.get_adnlp_solution_builder(prob::OptimizationProblem) + return prob.adnlp_solution_builder +end + +function CTModels.get_exa_solution_builder(prob::OptimizationProblem) + return prob.exa_solution_builder +end +``` + +In this pattern, the optimization problem is essentially a container for the +four builders. The modelers and other components only interact with the +problem via the `get_*_builder` interface. + +## Example: discretized optimal control problems + +The type +[`DiscretizedOptimalControlProblem`](@ref CTModels.DiscretizedOptimalControlProblem) +provides a more structured example. It stores a high-level OCP model and a +mapping from symbols (e.g. `:adnlp`, `:exa`) to +[`OCPBackendBuilders`](@ref CTModels.OCPBackendBuilders) records: + +```julia +struct DiscretizedOptimalControlProblem{TO<:CTModels.AbstractModel,TB<:NamedTuple} <: + CTModels.AbstractOptimizationProblem + optimal_control_problem::TO + backend_builders::TB +end +``` + +Each `OCPBackendBuilders` value stores a model builder +(`TM <: AbstractModelBuilder`) and a solution builder +(`TS <: AbstractOCPSolutionBuilder`). The `get_*_builder` methods then +retrieve the appropriate entry from the `backend_builders` NamedTuple. + +This design allows the same discretized problem to support multiple NLP +back-ends at once. + +## Relationship with modelers and tools + +Optimization problems do not directly know how to build an NLP model. That +logic lives in modelers, which are subtypes of +[`AbstractOptimizationModeler`](@ref CTModels.AbstractOptimizationModeler) +and also implement the [`AbstractOCPTool`](@ref CTModels.AbstractOCPTool) +interface. + +A typical workflow is: + +1. Construct a `MyProblem <: AbstractOptimizationProblem` that describes the + OCP and its discretization. +2. Construct a modeler tool (e.g. `ADNLPModeler`, `ExaModeler`). +3. The modeler calls `get_*_model_builder(prob)` to obtain the builder for its + back-end, then applies it to the initial guess to obtain an NLP model. +4. After solving the NLP, the modeler may call `get_*_solution_builder(prob)` + to turn the back-end solution into an OCP-related representation. + +For the implementation of modelers and tools, see also +[OCP Tools](ocp_tools.md) and the separate page on optimization modelers. diff --git a/ext/CTModelsJSON.jl b/ext/CTModelsJSON.jl index 11654bcd..72a71f6e 100644 --- a/ext/CTModelsJSON.jl +++ b/ext/CTModelsJSON.jl @@ -5,6 +5,81 @@ using DocStringExtensions using JSON3 +# ============================================================================ +# Helper functions for serializing/deserializing infos Dict{Symbol,Any} +# ============================================================================ + +""" +Convert Dict{Symbol,Any} to Dict{String,Any} for JSON serialization. +Only serializes JSON-compatible types (numbers, strings, bools, arrays, dicts). +""" +function _serialize_infos(infos::Dict{Symbol,Any})::Dict{String,Any} + result = Dict{String,Any}() + for (k, v) in infos + result[string(k)] = _serialize_value(v) + end + return result +end + +""" +Serialize a single value to JSON-compatible format. +""" +function _serialize_value(v) + if v isa Number || v isa String || v isa Bool || isnothing(v) + return v + elseif v isa Symbol + return string(v) + elseif v isa AbstractVector + return [_serialize_value(x) for x in v] + elseif v isa AbstractDict + result = Dict{String,Any}() + for (dk, dv) in v + result[string(dk)] = _serialize_value(dv) + end + return result + else + # For non-serializable types, convert to string representation + return string(v) + end +end + +""" +Convert Dict{String,Any} back to Dict{Symbol,Any} after JSON deserialization. +""" +function _deserialize_infos(blob)::Dict{Symbol,Any} + if isnothing(blob) || isempty(blob) + return Dict{Symbol,Any}() + end + result = Dict{Symbol,Any}() + for (k, v) in blob + result[Symbol(k)] = _deserialize_value(v) + end + return result +end + +""" +Deserialize a single value from JSON format. +""" +function _deserialize_value(v) + if v isa Number || v isa String || v isa Bool || isnothing(v) + return v + elseif v isa AbstractVector + return [_deserialize_value(x) for x in v] + elseif v isa AbstractDict + result = Dict{Symbol,Any}() + for (dk, dv) in v + result[Symbol(dk)] = _deserialize_value(dv) + end + return result + else + return v + end +end + +# ============================================================================ +# Export function +# ============================================================================ + """ $(TYPEDSIGNATURES) @@ -60,6 +135,8 @@ function CTModels.export_ocp_solution( "boundary_constraints_dual" => CTModels.boundary_constraints_dual(sol), # ctVector or Nothing "variable_constraints_lb_dual" => CTModels.variable_constraints_lb_dual(sol), # ctVector or Nothing "variable_constraints_ub_dual" => CTModels.variable_constraints_ub_dual(sol), # ctVector or Nothing + # Additional solver infos (Dict{Symbol,Any} → Dict{String,Any} for JSON) + "infos" => _serialize_infos(CTModels.infos(sol)), ) open(filename * ".json", "w") do io @@ -206,6 +283,13 @@ function CTModels.import_ocp_solution( variable_constraints_ub_dual = Vector{Float64}(blob["variable_constraints_ub_dual"]) end + # get additional solver infos + infos = if haskey(blob, "infos") + _deserialize_infos(blob["infos"]) + else + Dict{Symbol,Any}() + end + # NB. convert vect{vect} to matrix return CTModels.build_solution( ocp, @@ -228,6 +312,7 @@ function CTModels.import_ocp_solution( boundary_constraints_dual=boundary_constraints_dual, variable_constraints_lb_dual=variable_constraints_lb_dual, variable_constraints_ub_dual=variable_constraints_ub_dual, + infos=infos, ) end diff --git a/ext/plot_default.jl b/ext/plot_default.jl index dd55183c..be0f49ec 100644 --- a/ext/plot_default.jl +++ b/ext/plot_default.jl @@ -91,8 +91,8 @@ julia> size = __size_plot(sol, model, :components, :split; ...) ``` """ function __size_plot( - sol::CTModels.Solution, - model::Union{CTModels.Model,Nothing}, + sol::CTModels.AbstractSolution, + model::Union{CTModels.AbstractModel,Nothing}, control::Symbol, layout::Symbol, description::Symbol...; diff --git a/ext/plot_utils.jl b/ext/plot_utils.jl index 18b4474f..49ec789d 100644 --- a/ext/plot_utils.jl +++ b/ext/plot_utils.jl @@ -63,7 +63,7 @@ julia> do_plot(sol, :state, :control, :path; state_style=NamedTuple(), control_s ``` """ function do_plot( - sol::CTModels.Solution, + sol::CTModels.AbstractSolution, description::Symbol...; state_style::Union{NamedTuple,Symbol}, control_style::Union{NamedTuple,Symbol}, @@ -74,7 +74,11 @@ function do_plot( do_plot_state = :state ∈ description && state_style != :none do_plot_costate = :costate ∈ description && costate_style != :none do_plot_control = :control ∈ description && control_style != :none - do_plot_path = :path ∈ description && path_style != :none + ocp = CTModels.model(sol) + do_plot_path = + :path ∈ description && + path_style != :none && + CTModels.dim_path_constraints_nl(ocp) > 0 do_plot_dual = :dual ∈ description && dual_style != :none && diff --git a/profiling/Project.toml b/profiling/Project.toml deleted file mode 100644 index a4e28c1f..00000000 --- a/profiling/Project.toml +++ /dev/null @@ -1,7 +0,0 @@ -[deps] -BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" -CTModels = "34c4fa32-2049-4079-8329-de33c2a22e2d" -JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" -PProf = "e4faabce-9ead-11e9-39d9-4379958e3056" -Profile = "9abbd945-dff8-562f-b5e8-e1ebf5ef1b79" -Revise = "295af30f-e4ad-537b-8983-00126c2a3abe" diff --git a/profiling/boundary.jl b/profiling/boundary.jl deleted file mode 100644 index 0525f771..00000000 --- a/profiling/boundary.jl +++ /dev/null @@ -1,104 +0,0 @@ -begin - using Revise - using CTModels - - using JET - using BenchmarkTools - using Profile - - # define problem with new model: simple integrator - function simple_integrator_model() - pre_ocp = CTModels.PreModel() - CTModels.state!(pre_ocp, 1) - CTModels.control!(pre_ocp, 2) - CTModels.time!(pre_ocp; t0=0.0, tf=1.0) - f!(r, t, x, u, v) = r .= .-x[1] .- u[1] .+ u[2] - CTModels.dynamics!(pre_ocp, f!) - l(t, x, u, v) = (u[1] .+ u[2]) .^ 2 - CTModels.objective!(pre_ocp, :min; lagrange=l) - function bc!(r, x0, xf, v) - r[1] = x0[1] - r[2] = xf[1] - return nothing - end - function bc2!(r, x0, xf, v) - r[1] = x0[1] - r[2] = xf[1] - return nothing - end - CTModels.constraint!( - pre_ocp, :boundary; f=(bc!), lb=[-1, 0], ub=[-1, 0], label=:boundary1 - ) - N = 2 - CTModels.constraint!( - pre_ocp, :boundary; f=(bc!), lb=[-1, 0], ub=[-1, 0], label=:boundary2 - ) - N += 2 - CTModels.constraint!( - pre_ocp, :boundary; f=(bc2!), lb=[-1, 0], ub=[-1, 0], label=:boundary3 - ) - N += 2 - CTModels.constraint!( - pre_ocp, :control; rg=1:2, lb=[0, 0], ub=[Inf, Inf], label=:control_rg - ) - CTModels.definition!(pre_ocp, Expr(:simple_integrator_min_energy)) - ocp = CTModels.build(pre_ocp) - return ocp, N - end - - ocp, N = simple_integrator_model() - - x0 = [1.0] - xf = [0.0] - v = Float64[] - r = zeros(Float64, N) - bc_constraint = CTModels.boundary_constraints_nl(ocp) - boundary! = bc_constraint[2] - boundary!(r, x0, xf, v) - r - - function bc!(r, x0, xf, v) - r[1] = x0[1] - r[2] = xf[1] - return nothing - end -end - -let - println("--------------------------------") - println("Boundary constraint") - @code_warntype bc!(r, x0, xf, v) - println("\n") - println("--------------------------------") - println("Boundary constraint from model") - @code_warntype boundary!(r, x0, xf, v) -end - -let - println("--------------------------------") - println("Boundary constraint") - println(@report_opt bc!(r, x0, xf, v)) - println("--------------------------------") - println("Boundary constraint from model") - println(@report_opt boundary!(r, x0, xf, v)) -end - -let - println("--------------------------------") - println("Boundary constraint") - display(@benchmark bc!(r, x0, xf, v)) - println("\n") - println("--------------------------------") - println("Boundary constraint from model") - display(@benchmark boundary!(r, x0, xf, v)) -end - -let - println("--------------------------------") - println("Boundary constraint") - @code_native debuginfo = :none dump_module = false bc!(r, x0, xf, v) - println("\n") - println("--------------------------------") - println("Boundary constraint from model") - @code_native debuginfo = :none dump_module = false boundary!(r, x0, xf, v) -end diff --git a/profiling/interpolate.jl b/profiling/interpolate.jl deleted file mode 100644 index 4423b3fa..00000000 --- a/profiling/interpolate.jl +++ /dev/null @@ -1,52 +0,0 @@ -begin - using Revise - using CTModels - - using JET - using BenchmarkTools - using Profile - - function make_interpolation() - T = [0, 1] - - A = [ - 0 1 - 2 3 - ] - - V = CTModels.matrix2vec(A, 1) - - f = CTModels.ctinterpolate(T, V) - - for x in LinRange(0, 1, 100) - f(x) - end - - return f(0.5) - end - - let - println("--------------------------------") - println("Make interpolation") - @code_warntype make_interpolation() - end - - # let - # println("--------------------------------") - # println("Make interpolation") - # println(@report_opt make_interpolation()) - # end - - let - println("--------------------------------") - println("Make interpolation") - display(@benchmark make_interpolation()) - end - - # let - # println("--------------------------------") - # println("Make interpolation") - # @code_native debuginfo = :none dump_module = false make_interpolation() - # end - -end diff --git a/src/CTModels.jl b/src/CTModels.jl index 52675531..ae95101c 100644 --- a/src/CTModels.jl +++ b/src/CTModels.jl @@ -21,6 +21,11 @@ using Parameters # @with_kw: to have default values in struct using MacroTools: striplines using RecipesBase: plot, plot!, RecipesBase using OrderedCollections: OrderedDict +using SolverCore +using ADNLPModels +using ExaModels +using KernelAbstractions +using NLPModels # aliases @@ -50,7 +55,7 @@ Type alias for a time. julia> const Time = ctNumber ``` -See also: [`ctNumber`](@ref), [`Times`](@ref), [`TimesDisc`](@ref). +See also: [`ctNumber`](@ref), [`Times`](@ref CTModels.Times), [`TimesDisc`](@ref). """ const Time = ctNumber @@ -83,7 +88,7 @@ Type alias for a grid of times. This is used to define a discretization of time julia> const TimesDisc = Union{Times, StepRangeLen} ``` -See also: [`Time`](@ref), [`Times`](@ref). +See also: [`Time`](@ref), [`Times`](@ref CTModels.Times). """ const TimesDisc = Union{Times,StepRangeLen} @@ -94,18 +99,18 @@ Type alias for a dictionary of constraints. This is used to store constraints be julia> const TimesDisc = Union{Times, StepRangeLen} ``` -See also: [`ConstraintsModel`](@ref), [`PreModel`](@ref) and [`Model`](@ref). +See also: [`ConstraintsModel`](@ref), [`PreModel`](@ref) and [`Model`](@ref CTModels.Model). """ const ConstraintsDictType = OrderedDict{ Symbol,Tuple{Symbol,Union{Function,OrdinalRange{<:Int}},ctVector,ctVector} } # -include("default.jl") +include(joinpath(@__DIR__, "core", "default.jl")) # -include("utils.jl") -include("types.jl") +include(joinpath(@__DIR__, "core", "utils.jl")) +include(joinpath(@__DIR__, "core", "types.jl")) # export / import """ @@ -130,7 +135,7 @@ JSON tag for export/import functions. struct JSON3Tag <: AbstractTag end # ----------------------------- -# to be extended: no docstrings +# to be extended function RecipesBase.plot(sol::AbstractSolution, description::Symbol...; kwargs...) throw(CTBase.ExtensionError(:Plots)) end @@ -152,21 +157,21 @@ function import_ocp_solution(::JSON3Tag, ::AbstractModel; filename::String) end """ -$(TYPEDSIGNATURES) + export_ocp_solution(sol; format=:JLD, filename="solution") -Export a solution in JLD or JSON formats. Redirect to one of the methods: +Export an optimal control solution to a file. -- [`export_ocp_solution(JLD2Tag(), sol, filename=filename)`](@ref export_ocp_solution(::CTModels.JLD2Tag, ::CTModels.Solution)) -- [`export_ocp_solution(JSON3Tag(), sol, filename=filename)`](@ref export_ocp_solution(::CTModels.JSON3Tag, ::CTModels.Solution)) +# Arguments +- `sol::AbstractSolution`: The solution to export. -# Examples +# Keyword Arguments +- `format::Symbol=:JLD`: Export format, either `:JLD` or `:JSON`. +- `filename::String="solution"`: Base filename (extension added automatically). -```julia-repl -julia> using JSON3 -julia> export_ocp_solution(sol; filename="solution", format=:JSON) -julia> using JLD2 -julia> export_ocp_solution(sol; filename="solution", format=:JLD) # JLD is the default -``` +# Notes +Requires loading the appropriate package (`JLD2` or `JSON3`) before use. + +See also: [`import_ocp_solution`](@ref) """ function export_ocp_solution( sol::AbstractSolution; @@ -187,21 +192,24 @@ function export_ocp_solution( end """ -$(TYPEDSIGNATURES) + import_ocp_solution(ocp; format=:JLD, filename="solution") -Import a solution from a JLD or JSON file. Redirect to one of the methods: +Import an optimal control solution from a file. -- [`import_ocp_solution(JLD2Tag(), ocp, filename=filename)`](@ref import_ocp_solution(::CTModels.JLD2Tag, ::CTModels.Model)) -- [`import_ocp_solution(JSON3Tag(), ocp, filename=filename)`](@ref import_ocp_solution(::CTModels.JSON3Tag, ::CTModels.Model)) +# Arguments +- `ocp::AbstractModel`: The model associated with the solution. -# Examples +# Keyword Arguments +- `format::Symbol=:JLD`: Import format, either `:JLD` or `:JSON`. +- `filename::String="solution"`: Base filename (extension added automatically). -```julia-repl -julia> using JSON3 -julia> sol = import_ocp_solution(ocp; filename="solution", format=:JSON) -julia> using JLD2 -julia> sol = import_ocp_solution(ocp; filename="solution", format=:JLD) # JLD is the default -``` +# Returns +- `Solution`: The imported solution. + +# Notes +Requires loading the appropriate package (`JLD2` or `JSON3`) before use. + +See also: [`export_ocp_solution`](@ref) """ function import_ocp_solution( ocp::AbstractModel; @@ -222,22 +230,40 @@ function import_ocp_solution( end # -include("init.jl") -include("dual_model.jl") -include("state.jl") -include("control.jl") -include("variable.jl") -include("times.jl") -include("dynamics.jl") -include("objective.jl") -include("constraints.jl") -include("time_dependence.jl") -include("definition.jl") -include("print.jl") -include("model.jl") -include("solution.jl") +#include("init.jl") +include(joinpath(@__DIR__, "ocp", "dual_model.jl")) +include(joinpath(@__DIR__, "ocp", "state.jl")) +include(joinpath(@__DIR__, "ocp", "control.jl")) +include(joinpath(@__DIR__, "ocp", "variable.jl")) +include(joinpath(@__DIR__, "ocp", "times.jl")) +include(joinpath(@__DIR__, "ocp", "dynamics.jl")) +include(joinpath(@__DIR__, "ocp", "objective.jl")) +include(joinpath(@__DIR__, "ocp", "constraints.jl")) +include(joinpath(@__DIR__, "ocp", "time_dependence.jl")) +include(joinpath(@__DIR__, "ocp", "definition.jl")) +include(joinpath(@__DIR__, "ocp", "print.jl")) +include(joinpath(@__DIR__, "ocp", "model.jl")) +include(joinpath(@__DIR__, "ocp", "solution.jl")) + +# new from CTSolvers +""" +Type alias for [`AbstractModel`](@ref). -# -export plot, plot! +Provides compatibility with CTSolvers naming conventions. +""" +const AbstractOptimalControlProblem = CTModels.AbstractModel -end +""" +Type alias for [`AbstractSolution`](@ref). + +Provides compatibility with CTSolvers naming conventions. +""" +const AbstractOptimalControlSolution = CTModels.AbstractSolution +include(joinpath(@__DIR__, "nlp", "options_schema.jl")) +include(joinpath(@__DIR__, "nlp", "problem_core.jl")) +include(joinpath(@__DIR__, "nlp", "nlp_backends.jl")) +include(joinpath(@__DIR__, "nlp", "discretized_ocp.jl")) +include(joinpath(@__DIR__, "nlp", "model_api.jl")) +include(joinpath(@__DIR__, "init", "initial_guess.jl")) + +end \ No newline at end of file diff --git a/src/default.jl b/src/core/default.jl similarity index 95% rename from src/default.jl rename to src/core/default.jl index a98510af..ffb4c7e3 100644 --- a/src/default.jl +++ b/src/core/default.jl @@ -106,5 +106,8 @@ __matrix_dimension_storage() = 1 """ $(TYPEDSIGNATURES) +Return the default filename (without extension) for exporting and importing solutions. + +The default value is `"solution"`. """ __filename_export_import() = "solution" diff --git a/src/core/types.jl b/src/core/types.jl new file mode 100644 index 00000000..b00cab4f --- /dev/null +++ b/src/core/types.jl @@ -0,0 +1,5 @@ +include(joinpath(@__DIR__, "types", "ocp_components.jl")) +include(joinpath(@__DIR__, "types", "ocp_model.jl")) +include(joinpath(@__DIR__, "types", "ocp_solution.jl")) +include(joinpath(@__DIR__, "types", "nlp.jl")) +include(joinpath(@__DIR__, "types", "initial_guess.jl")) diff --git a/src/core/types/initial_guess.jl b/src/core/types/initial_guess.jl new file mode 100644 index 00000000..ce4facf3 --- /dev/null +++ b/src/core/types/initial_guess.jl @@ -0,0 +1,83 @@ +# ------------------------------------------------------------------------------ # +# Initial guess types for continuous-time OCPs +# ------------------------------------------------------------------------------ # +""" +$(TYPEDEF) + +Abstract base type for initial guesses used in optimal control problem solvers. + +Subtypes provide initial trajectories for state, control, and optimisation variables +to warm-start numerical solvers. + +See also: [`OptimalControlInitialGuess`](@ref). +""" +abstract type AbstractOptimalControlInitialGuess end + +""" +$(TYPEDEF) + +Concrete initial guess for an optimal control problem, storing callable +trajectories for state and control, and a value for the optimisation variable. + +# Fields + +- `state::X`: A function `t -> x(t)` returning the state guess at time `t`. +- `control::U`: A function `t -> u(t)` returning the control guess at time `t`. +- `variable::V`: The initial guess for the optimisation variable (scalar or vector). + +# Example + +```julia-repl +julia> using CTModels + +julia> x_guess = t -> [cos(t), sin(t)] +julia> u_guess = t -> [0.5] +julia> v_guess = [1.0, 2.0] +julia> ig = CTModels.OptimalControlInitialGuess(x_guess, u_guess, v_guess) +``` +""" +struct OptimalControlInitialGuess{X<:Function,U<:Function,V} <: + AbstractOptimalControlInitialGuess + state::X + control::U + variable::V +end + +""" +$(TYPEDEF) + +Abstract base type for pre-initialisation data used before constructing a full +initial guess. + +Subtypes store raw or partial information that will be processed into an +[`OptimalControlInitialGuess`](@ref). + +See also: [`OptimalControlPreInit`](@ref). +""" +abstract type AbstractOptimalControlPreInit end + +""" +$(TYPEDEF) + +Pre-initialisation container for initial guess data before validation and +interpolation. + +# Fields + +- `state::SX`: Raw state data (e.g., matrix, vector of vectors, or function). +- `control::SU`: Raw control data (e.g., matrix, vector of vectors, or function). +- `variable::SV`: Raw optimisation variable data (scalar, vector, or `nothing`). + +# Example + +```julia-repl +julia> using CTModels + +julia> pre = CTModels.OptimalControlPreInit([1.0 2.0; 3.0 4.0], [0.5, 0.6], [1.0]) +``` +""" +struct OptimalControlPreInit{SX,SU,SV} <: AbstractOptimalControlPreInit + state::SX + control::SU + variable::SV +end diff --git a/src/core/types/nlp.jl b/src/core/types/nlp.jl new file mode 100644 index 00000000..0b3e81dc --- /dev/null +++ b/src/core/types/nlp.jl @@ -0,0 +1,396 @@ +# ------------------------------------------------------------------------------ # +# NLP backends and optimization problem types +# (tools, builders, modelers, discretized optimal control problem) +# ------------------------------------------------------------------------------ # +""" +$(TYPEDEF) + +Abstract base type for configurable tools in CTModels (backends, discretizers, +solvers, etc.). + +Subtypes of `AbstractOCPTool` are expected to follow a common options +interface so they can be configured and introspected in a uniform way. + +# Interface contract + +Concrete subtypes `T <: AbstractOCPTool` are expected to: + +- store two fields + - `options_values::NamedTuple` — current option values. + - `options_sources::NamedTuple` — provenance for each option + (`:ct_default` or `:user`). +- optionally provide option metadata by specializing + [`_option_specs`](@ref CTModels._option_specs), returning a `NamedTuple` of + [`OptionSpec`](@ref) values. +- typically define a keyword-only constructor + `T(; kwargs...)` implemented using [`_build_ocp_tool_options`](@ref), so + that user-supplied keywords are validated and merged with tool defaults. + +Most helper functions in the options schema (see `nlp/options_schema.jl`) +operate generically on any subtype that satisfies this contract. +""" +abstract type AbstractOCPTool end + +""" +$(TYPEDEF) + +Metadata for a single named option of an [`AbstractOCPTool`](@ref). + +Each field describes one aspect of the option: + +- `type` — expected Julia type for the option value, or `missing` if + no static type information is available. +- `default` — default value when the option is not supplied by the user, + or `missing` if there is no default. +- `description` — short human-readable description of the option, or + `missing` if it is not yet documented. + +Instances of `OptionSpec` are typically returned from `_option_specs(::Type)` +in a `NamedTuple`, one field per option name. +""" +struct OptionSpec + type::Any # Expected Julia type for the option value, or `missing` if unknown. + default::Any + description::Any # Short English description (String) or `missing` if not documented yet. +end + +""" +$(TYPEDEF) + +Common supertype for builder objects used in the NLP back-end +infrastructure. + +`AbstractBuilder` itself does not impose a concrete calling interface; +specialized subtypes such as [`AbstractModelBuilder`](@ref) and +[`AbstractOCPSolutionBuilder`](@ref) define looser contracts that are +documented on their own abstract types and concrete implementations. +""" +abstract type AbstractBuilder end + +""" +$(TYPEDEF) + +Abstract base type for builders that construct NLP back-end models from +an [`AbstractOptimizationProblem`](@ref). + +Concrete subtypes (for example [`ADNLPModelBuilder`](@ref) and +[`ExaModelBuilder`](@ref)) are expected to be callable objects that +encapsulate the logic for building a model for a specific NLP back-end. +The exact call signature is back-end dependent and therefore not fixed at +the level of `AbstractModelBuilder`. +""" +abstract type AbstractModelBuilder <: AbstractBuilder end + +""" +$(TYPEDEF) + +Builder for constructing ADNLPModels-based NLP models from an +[`AbstractOptimizationProblem`](@ref). + +# Fields + +- `f::T`: A callable that builds the ADNLPModel when invoked. + +Concrete implementations are typically returned by high-level +optimisation modelling interfaces and are not created directly by users. + +See also: [`ExaModelBuilder`](@ref), [`AbstractModelBuilder`](@ref). +""" +struct ADNLPModelBuilder{T<:Function} <: AbstractModelBuilder + f::T +end + +""" +$(TYPEDEF) + +Builder for constructing ExaModels-based NLP models from an +[`AbstractOptimizationProblem`](@ref). + +# Fields + +- `f::T`: A callable that builds the ExaModel when invoked. + +See also: [`ADNLPModelBuilder`](@ref), [`AbstractModelBuilder`](@ref). +""" +struct ExaModelBuilder{T<:Function} <: AbstractModelBuilder + f::T +end + +""" +$(TYPEDEF) + +Abstract base type for builders that transform NLP solutions into other +representations (for example, solutions of an optimal control problem). + +Subtypes are expected to be callable, but the abstract type does not fix +the argument types. More specific contracts are documented on +[`AbstractOCPSolutionBuilder`](@ref) and related concrete types. +""" +abstract type AbstractSolutionBuilder <: AbstractBuilder end + +""" +$(TYPEDEF) + +Abstract base type for optimization problems built from optimal control +problems. + +Subtypes of `AbstractOptimizationProblem` are typically paired with +[`AbstractModelBuilder`](@ref) and [`AbstractSolutionBuilder`](@ref) +implementations that know how to construct and interpret NLP back-end +models and solutions. +""" +abstract type AbstractOptimizationProblem end + +""" +$(TYPEDEF) + +Abstract base type for NLP modelers built on top of +[`AbstractOptimizationProblem`](@ref). + +Subtypes of `AbstractOptimizationModeler` are also `AbstractOCPTool`s +and therefore follow the generic options interface: they store +`options_values` and `options_sources` fields and are typically +constructed using [`_build_ocp_tool_options`](@ref). + +Concrete modelers such as [`ADNLPModeler`](@ref) and +[`ExaModeler`](@ref) dispatch on an `AbstractOptimizationProblem` to +build NLP models and map NLP solutions back to OCP solutions. +""" +abstract type AbstractOptimizationModeler <: AbstractOCPTool end + +""" +$(TYPEDSIGNATURES) + +Interface method for [`AbstractOptimizationModeler`](@ref). + +Concrete modelers are expected to specialize this call to build an NLP +model from an [`AbstractOptimizationProblem`](@ref) and an initial +guess. The default implementation throws a +`CTBase.NotImplemented` error. +""" +function (modeler::AbstractOptimizationModeler)( + prob::AbstractOptimizationProblem, + initial_guess; + kwargs..., +) + throw( + CTBase.NotImplemented( + "model-building call not implemented for $(typeof(modeler))", + ), + ) +end + +""" +$(TYPEDSIGNATURES) + +Interface method for [`AbstractOptimizationModeler`](@ref). + +Concrete modelers may specialize this call to map an NLP back-end +solution (for example `SolverCore.AbstractExecutionStats`) back to a +solution associated with the original +[`AbstractOptimizationProblem`](@ref). The default implementation throws +`CTBase.NotImplemented`. +""" +function (modeler::AbstractOptimizationModeler)( + prob::AbstractOptimizationProblem, + nlp_solution::SolverCore.AbstractExecutionStats; + kwargs..., +) + throw( + CTBase.NotImplemented( + "solution-building call not implemented for $(typeof(modeler))", + ), + ) +end + +""" +$(TYPEDEF) + +Concrete [`AbstractOptimizationModeler`](@ref) based on `ADNLPModels.jl`. + +`ADNLPModeler` implements the [`AbstractOCPTool`](@ref) options +interface: it stores `options_values` and `options_sources`, defines an +`_option_specs` specialisation describing its options, and is +constructed via [`_build_ocp_tool_options`](@ref). + +# Fields + +- `options_values::Vals`: Named tuple of current option values. +- `options_sources::Srcs`: Named tuple indicating source of each option (`:ct_default` or `:user`). + +See also: [`ExaModeler`](@ref), [`AbstractOptimizationModeler`](@ref). +""" +struct ADNLPModeler{Vals,Srcs} <: AbstractOptimizationModeler + options_values::Vals + options_sources::Srcs +end + +""" +$(TYPEDEF) + +Concrete [`AbstractOptimizationModeler`](@ref) based on `ExaModels.jl`. + +Like [`ADNLPModeler`](@ref), this type follows the +[`AbstractOCPTool`](@ref) options interface and is configured via +[`_build_ocp_tool_options`](@ref). It additionally fixes a +`BaseType<:AbstractFloat` parameter that controls the floating-point +type of the underlying ExaModel. + +# Fields + +- `options_values::Vals`: Named tuple of current option values. +- `options_sources::Srcs`: Named tuple indicating source of each option (`:ct_default` or `:user`). + +# Type Parameters + +- `BaseType<:AbstractFloat`: Floating-point type for the ExaModel (e.g., `Float64`). + +See also: [`ADNLPModeler`](@ref), [`AbstractOptimizationModeler`](@ref). +""" +struct ExaModeler{BaseType<:AbstractFloat,Vals,Srcs} <: AbstractOptimizationModeler + options_values::Vals + options_sources::Srcs +end + +""" +$(TYPEDEF) + +Abstract base type for builders that turn NLP back-end execution +statistics into objects associated with a discretized optimal control +problem (for example, an OCP solution or intermediate representation). + +Concrete subtypes are expected to be callable on a +`SolverCore.AbstractExecutionStats` value. A generic fallback method is +provided (see below) that throws `CTBase.NotImplemented` if a concrete +builder does not implement the call. + +See also: [`ADNLPSolutionBuilder`](@ref), [`ExaSolutionBuilder`](@ref). +""" +abstract type AbstractOCPSolutionBuilder <: AbstractSolutionBuilder end + +""" +$(TYPEDSIGNATURES) + +Interface method for [`AbstractOCPSolutionBuilder`](@ref). + +Concrete OCP solution builders are expected to specialize this method to +convert NLP execution statistics into an appropriate representation. The +default implementation throws a `CTBase.NotImplemented` error. +""" +function (builder::AbstractOCPSolutionBuilder)( + nlp_solution::SolverCore.AbstractExecutionStats; + kwargs..., +) + throw( + CTBase.NotImplemented( + "OCP solution builder not implemented for $(typeof(builder))", + ), + ) +end + +""" +$(TYPEDEF) + +Solution builder for ADNLPModels-based solvers. + +Converts NLP execution statistics into an optimal control solution. + +# Fields + +- `f::T`: A callable that builds the OCP solution from NLP stats. + +See also: [`ExaSolutionBuilder`](@ref), [`AbstractOCPSolutionBuilder`](@ref). +""" +struct ADNLPSolutionBuilder{T<:Function} <: AbstractOCPSolutionBuilder + f::T +end + +""" +$(TYPEDEF) + +Solution builder for ExaModels-based solvers. + +Converts NLP execution statistics into an optimal control solution. + +# Fields + +- `f::T`: A callable that builds the OCP solution from NLP stats. + +See also: [`ADNLPSolutionBuilder`](@ref), [`AbstractOCPSolutionBuilder`](@ref). +""" +struct ExaSolutionBuilder{T<:Function} <: AbstractOCPSolutionBuilder + f::T +end + +""" +$(TYPEDEF) + +Container pairing a model builder with its corresponding solution builder. + +# Fields + +- `model::TM`: The model builder (e.g., [`ADNLPModelBuilder`](@ref)). +- `solution::TS`: The solution builder (e.g., [`ADNLPSolutionBuilder`](@ref)). + +See also: [`DiscretizedOptimalControlProblem`](@ref). +""" +struct OCPBackendBuilders{TM<:AbstractModelBuilder,TS<:AbstractOCPSolutionBuilder} + model::TM + solution::TS +end + +""" +$(TYPEDEF) + +Discretised optimal control problem ready for NLP solving. + +Wraps an optimal control problem together with backend builders for +multiple NLP backends (e.g., ADNLPModels and ExaModels). + +# Fields + +- `optimal_control_problem::TO`: The original optimal control problem model. +- `backend_builders::TB`: Named tuple mapping backend symbols to [`OCPBackendBuilders`](@ref). + +# Example + +```julia-repl +julia> using CTModels + +julia> # Typically constructed internally by discretisation routines +julia> docp = CTModels.DiscretizedOptimalControlProblem(ocp, backend_builders) +``` +""" +struct DiscretizedOptimalControlProblem{TO<:AbstractModel,TB<:NamedTuple} <: + AbstractOptimizationProblem + optimal_control_problem::TO + backend_builders::TB + function DiscretizedOptimalControlProblem( + optimal_control_problem::TO, backend_builders::TB + ) where {TO<:AbstractModel,TB<:NamedTuple} + return new{TO,TB}(optimal_control_problem, backend_builders) + end + function DiscretizedOptimalControlProblem( + optimal_control_problem::AbstractModel, + backend_builders::Tuple{Vararg{Pair{Symbol,<:OCPBackendBuilders}}}, + ) + return DiscretizedOptimalControlProblem( + optimal_control_problem, (; backend_builders...) + ) + end + function DiscretizedOptimalControlProblem( + optimal_control_problem::AbstractModel, + adnlp_model_builder::ADNLPModelBuilder, + exa_model_builder::ExaModelBuilder, + adnlp_solution_builder::ADNLPSolutionBuilder, + exa_solution_builder::ExaSolutionBuilder, + ) + return DiscretizedOptimalControlProblem( + optimal_control_problem, + ( + :adnlp => OCPBackendBuilders(adnlp_model_builder, adnlp_solution_builder), + :exa => OCPBackendBuilders(exa_model_builder, exa_solution_builder), + ), + ) + end +end diff --git a/src/core/types/ocp_components.jl b/src/core/types/ocp_components.jl new file mode 100644 index 00000000..2492e97e --- /dev/null +++ b/src/core/types/ocp_components.jl @@ -0,0 +1,491 @@ +# ------------------------------------------------------------------------------ # +# Continuous-time OCP component types +# (time dependence, state/control/variable models, time models, objectives, constraints) +# ------------------------------------------------------------------------------ # +""" +$(TYPEDEF) + +Abstract base type representing time dependence of an optimal control problem. + +Used as a type parameter to distinguish between autonomous and non-autonomous +systems at the type level, enabling dispatch and compile-time optimisations. + +See also: [`Autonomous`](@ref), [`NonAutonomous`](@ref). +""" +abstract type TimeDependence end + +""" +$(TYPEDEF) + +Type tag indicating that the dynamics and other functions of an optimal control +problem do not explicitly depend on time. + +For autonomous systems, the dynamics have the form `ẋ = f(x, u)` rather than +`ẋ = f(t, x, u)`. + +See also: [`TimeDependence`](@ref), [`NonAutonomous`](@ref). +""" +abstract type Autonomous<:TimeDependence end + +""" +$(TYPEDEF) + +Type tag indicating that the dynamics and other functions of an optimal control +problem explicitly depend on time. + +For non-autonomous systems, the dynamics have the form `ẋ = f(t, x, u)`. + +See also: [`TimeDependence`](@ref), [`Autonomous`](@ref). +""" +abstract type NonAutonomous<:TimeDependence end + +# ------------------------------------------------------------------------------ # +""" +$(TYPEDEF) + +Abstract base type for state variable models in optimal control problems. + +Subtypes describe the state space structure including dimension, naming, and +optionally the state trajectory itself. + +See also: [`StateModel`](@ref), [`StateModelSolution`](@ref). +""" +abstract type AbstractStateModel end + +""" +$(TYPEDEF) + +State model describing the structure of the state variable in an optimal control +problem definition. + +# Fields + +- `name::String`: Display name for the state variable (e.g., `"x"`). +- `components::Vector{String}`: Names of individual state components (e.g., `["x₁", "x₂"]`). + +# Example + +```julia-repl +julia> using CTModels + +julia> sm = CTModels.StateModel("x", ["position", "velocity"]) +``` +""" +struct StateModel <: AbstractStateModel + name::String + components::Vector{String} +end + +""" +$(TYPEDEF) + +State model for a solved optimal control problem, including the state trajectory. + +# Fields + +- `name::String`: Display name for the state variable. +- `components::Vector{String}`: Names of individual state components. +- `value::TS`: A function `t -> x(t)` returning the state vector at time `t`. + +# Example + +```julia-repl +julia> using CTModels + +julia> x_traj = t -> [cos(t), sin(t)] +julia> sms = CTModels.StateModelSolution("x", ["x₁", "x₂"], x_traj) +julia> sms.value(0.0) +2-element Vector{Float64}: + 1.0 + 0.0 +``` +""" +struct StateModelSolution{TS<:Function} <: AbstractStateModel + name::String + components::Vector{String} + value::TS +end + +# ------------------------------------------------------------------------------ # +""" +$(TYPEDEF) + +Abstract base type for control variable models in optimal control problems. + +Subtypes describe the control space structure including dimension, naming, and +optionally the control trajectory itself. + +See also: [`ControlModel`](@ref), [`ControlModelSolution`](@ref). +""" +abstract type AbstractControlModel end + +""" +$(TYPEDEF) + +Control model describing the structure of the control variable in an optimal +control problem definition. + +# Fields + +- `name::String`: Display name for the control variable (e.g., `"u"`). +- `components::Vector{String}`: Names of individual control components (e.g., `["u₁", "u₂"]`). + +# Example + +```julia-repl +julia> using CTModels + +julia> cm = CTModels.ControlModel("u", ["thrust", "steering"]) +``` +""" +struct ControlModel <: AbstractControlModel + name::String + components::Vector{String} +end + +""" +$(TYPEDEF) + +Control model for a solved optimal control problem, including the control trajectory. + +# Fields + +- `name::String`: Display name for the control variable. +- `components::Vector{String}`: Names of individual control components. +- `value::TS`: A function `t -> u(t)` returning the control vector at time `t`. + +# Example + +```julia-repl +julia> using CTModels + +julia> u_traj = t -> [sin(t)] +julia> cms = CTModels.ControlModelSolution("u", ["u₁"], u_traj) +julia> cms.value(π/2) +1-element Vector{Float64}: + 1.0 +``` +""" +struct ControlModelSolution{TS<:Function} <: AbstractControlModel + name::String + components::Vector{String} + value::TS +end + +# ------------------------------------------------------------------------------ # +""" +$(TYPEDEF) + +Abstract base type for optimisation variable models in optimal control problems. + +Optimisation variables are decision variables that do not depend on time, such as +free final time or unknown parameters. + +See also: [`VariableModel`](@ref), [`EmptyVariableModel`](@ref), [`VariableModelSolution`](@ref). +""" +abstract type AbstractVariableModel end + +""" +$(TYPEDEF) + +Variable model describing the structure of the optimisation variable in an optimal +control problem definition. + +# Fields + +- `name::String`: Display name for the variable (e.g., `"v"`). +- `components::Vector{String}`: Names of individual variable components (e.g., `["tf", "λ"]`). + +# Example + +```julia-repl +julia> using CTModels + +julia> vm = CTModels.VariableModel("v", ["final_time", "parameter"]) +``` +""" +struct VariableModel <: AbstractVariableModel + name::String + components::Vector{String} +end + +""" +$(TYPEDEF) + +Sentinel type representing the absence of optimisation variables in an optimal +control problem. + +Used when the problem has no free parameters or free final time. + +# Example + +```julia-repl +julia> using CTModels + +julia> evm = CTModels.EmptyVariableModel() +``` +""" +struct EmptyVariableModel <: AbstractVariableModel end + +""" +$(TYPEDEF) + +Variable model for a solved optimal control problem, including the variable value. + +# Fields + +- `name::String`: Display name for the variable. +- `components::Vector{String}`: Names of individual variable components. +- `value::TS`: The optimisation variable value (scalar or vector). + +# Example + +```julia-repl +julia> using CTModels + +julia> vms = CTModels.VariableModelSolution("v", ["tf"], 2.5) +julia> vms.value +2.5 +``` +""" +struct VariableModelSolution{TS<:Union{ctNumber,ctVector}} <: AbstractVariableModel + name::String + components::Vector{String} + value::TS +end + +# ------------------------------------------------------------------------------ # +""" +$(TYPEDEF) + +Abstract base type for time boundary models (initial or final time). + +Subtypes represent either fixed or free time boundaries in an optimal control +problem. + +See also: [`FixedTimeModel`](@ref), [`FreeTimeModel`](@ref). +""" +abstract type AbstractTimeModel end + +""" +$(TYPEDEF) + +Time model representing a fixed (known) time boundary. + +# Fields + +- `time::T`: The fixed time value. +- `name::String`: Display name for this time (e.g., `"t₀"` or `"tf"`). + +# Example + +```julia-repl +julia> using CTModels + +julia> t0 = CTModels.FixedTimeModel(0.0, "t₀") +julia> t0.time +0.0 +``` +""" +struct FixedTimeModel{T<:Time} <: AbstractTimeModel + time::T + name::String +end + +""" +$(TYPEDEF) + +Time model representing a free (optimised) time boundary. + +The actual time value is stored in the optimisation variable at the given index. + +# Fields + +- `index::Int`: Index into the optimisation variable where this time is stored. +- `name::String`: Display name for this time (e.g., `"tf"`). + +# Example + +```julia-repl +julia> using CTModels + +julia> tf = CTModels.FreeTimeModel(1, "tf") +julia> tf.index +1 +``` +""" +struct FreeTimeModel <: AbstractTimeModel + index::Int + name::String +end + +""" +$(TYPEDEF) + +Abstract base type for combined initial and final time models. + +See also: [`TimesModel`](@ref). +""" +abstract type AbstractTimesModel end + +""" +$(TYPEDEF) + +Combined model for initial and final times in an optimal control problem. + +# Fields + +- `initial::TI`: The initial time model (fixed or free). +- `final::TF`: The final time model (fixed or free). +- `time_name::String`: Display name for the time variable (e.g., `"t"`). + +# Example + +```julia-repl +julia> using CTModels + +julia> t0 = CTModels.FixedTimeModel(0.0, "t₀") +julia> tf = CTModels.FixedTimeModel(1.0, "tf") +julia> times = CTModels.TimesModel(t0, tf, "t") +``` +""" +struct TimesModel{TI<:AbstractTimeModel,TF<:AbstractTimeModel} <: AbstractTimesModel + initial::TI + final::TF + time_name::String +end + +# ------------------------------------------------------------------------------ # +""" +$(TYPEDEF) + +Abstract base type for objective function models in optimal control problems. + +Subtypes represent different forms of the cost functional: Mayer (terminal cost), +Lagrange (integral cost), or Bolza (both). + +See also: [`MayerObjectiveModel`](@ref), [`LagrangeObjectiveModel`](@ref), [`BolzaObjectiveModel`](@ref). +""" +abstract type AbstractObjectiveModel end + +""" +$(TYPEDEF) + +Objective model with only a Mayer (terminal) cost: `g(x(t₀), x(tf), v)`. + +# Fields + +- `mayer::TM`: The Mayer cost function `(x0, xf, v) -> g(x0, xf, v)`. +- `criterion::Symbol`: Optimisation direction, either `:min` or `:max`. + +# Example + +```julia-repl +julia> using CTModels + +julia> g = (x0, xf, v) -> xf[1]^2 +julia> obj = CTModels.MayerObjectiveModel(g, :min) +``` +""" +struct MayerObjectiveModel{TM<:Function} <: AbstractObjectiveModel + mayer::TM + criterion::Symbol +end + +""" +$(TYPEDEF) + +Objective model with only a Lagrange (integral) cost: `∫ f⁰(t, x, u, v) dt`. + +# Fields + +- `lagrange::TL`: The Lagrange integrand `(t, x, u, v) -> f⁰(t, x, u, v)`. +- `criterion::Symbol`: Optimisation direction, either `:min` or `:max`. + +# Example + +```julia-repl +julia> using CTModels + +julia> f0 = (t, x, u, v) -> u[1]^2 +julia> obj = CTModels.LagrangeObjectiveModel(f0, :min) +``` +""" +struct LagrangeObjectiveModel{TL<:Function} <: AbstractObjectiveModel + lagrange::TL + criterion::Symbol +end + +""" +$(TYPEDEF) + +Objective model with both Mayer and Lagrange costs (Bolza form): +`g(x(t₀), x(tf), v) + ∫ f⁰(t, x, u, v) dt`. + +# Fields + +- `mayer::TM`: The Mayer cost function `(x0, xf, v) -> g(x0, xf, v)`. +- `lagrange::TL`: The Lagrange integrand `(t, x, u, v) -> f⁰(t, x, u, v)`. +- `criterion::Symbol`: Optimisation direction, either `:min` or `:max`. + +# Example + +```julia-repl +julia> using CTModels + +julia> g = (x0, xf, v) -> xf[1]^2 +julia> f0 = (t, x, u, v) -> u[1]^2 +julia> obj = CTModels.BolzaObjectiveModel(g, f0, :min) +``` +""" +struct BolzaObjectiveModel{TM<:Function,TL<:Function} <: AbstractObjectiveModel + mayer::TM + lagrange::TL + criterion::Symbol +end + +# ------------------------------------------------------------------------------ # +# Constraints +# ------------------------------------------------------------------------------ # +""" +$(TYPEDEF) + +Abstract base type for constraint models in optimal control problems. + +Subtypes store all constraint information including path constraints, boundary +constraints, and box constraints on state, control, and variables. + +See also: [`ConstraintsModel`](@ref). +""" +abstract type AbstractConstraintsModel end + +""" +$(TYPEDEF) + +Container for all constraints in an optimal control problem. + +# Fields + +- `path_nl::TP`: Tuple of nonlinear path constraints `(t, x, u, v) -> c(t, x, u, v)`. +- `boundary_nl::TB`: Tuple of nonlinear boundary constraints `(x0, xf, v) -> b(x0, xf, v)`. +- `state_box::TS`: Tuple of box constraints on state variables (lower/upper bounds). +- `control_box::TC`: Tuple of box constraints on control variables (lower/upper bounds). +- `variable_box::TV`: Tuple of box constraints on optimisation variables (lower/upper bounds). + +# Example + +```julia-repl +julia> using CTModels + +julia> # Typically constructed internally by the model builder +julia> cm = CTModels.ConstraintsModel((), (), (), (), ()) +``` +""" +struct ConstraintsModel{TP<:Tuple,TB<:Tuple,TS<:Tuple,TC<:Tuple,TV<:Tuple} <: + AbstractConstraintsModel + path_nl::TP + boundary_nl::TB + state_box::TS + control_box::TC + variable_box::TV +end diff --git a/src/core/types/ocp_model.jl b/src/core/types/ocp_model.jl new file mode 100644 index 00000000..2af26fb2 --- /dev/null +++ b/src/core/types/ocp_model.jl @@ -0,0 +1,353 @@ +# ------------------------------------------------------------------------------ # +# Continuous-time OCP model types (Model, PreModel and consistency helpers) +# ------------------------------------------------------------------------------ # +""" +$(TYPEDEF) + +Abstract base type for optimal control problem models. + +Subtypes represent either a fully built immutable model ([`Model`](@ref CTModels.Model)) or a +mutable model under construction ([`PreModel`](@ref)). + +See also: [`Model`](@ref CTModels.Model), [`PreModel`](@ref). +""" +abstract type AbstractModel end + +""" +$(TYPEDEF) + +Immutable optimal control problem model containing all problem components. + +A `Model` is created from a [`PreModel`](@ref) once all required fields have been +set. It is parameterised by the time dependence type (`Autonomous` or `NonAutonomous`) +and the types of all its components. + +# Fields + +- `times::TimesModelType`: Initial and final time specification. +- `state::StateModelType`: State variable structure (name, components). +- `control::ControlModelType`: Control variable structure (name, components). +- `variable::VariableModelType`: Optimisation variable structure (may be empty). +- `dynamics::DynamicsModelType`: System dynamics function `(t, x, u, v) -> ẋ`. +- `objective::ObjectiveModelType`: Cost functional (Mayer, Lagrange, or Bolza). +- `constraints::ConstraintsModelType`: All problem constraints. +- `definition::Expr`: Original symbolic definition of the problem. +- `build_examodel::BuildExaModelType`: Optional ExaModels builder function. + +# Example + +```julia-repl +julia> using CTModels + +julia> # Models are typically created via the @def macro or PreModel +julia> ocp = CTModels.Model # Type reference +``` +""" +struct Model{ + TD<:TimeDependence, + TimesModelType<:AbstractTimesModel, + StateModelType<:AbstractStateModel, + ControlModelType<:AbstractControlModel, + VariableModelType<:AbstractVariableModel, + DynamicsModelType<:Function, + ObjectiveModelType<:AbstractObjectiveModel, + ConstraintsModelType<:AbstractConstraintsModel, + BuildExaModelType<:Union{Function,Nothing}, +} <: AbstractModel + times::TimesModelType + state::StateModelType + control::ControlModelType + variable::VariableModelType + dynamics::DynamicsModelType + objective::ObjectiveModelType + constraints::ConstraintsModelType + definition::Expr + build_examodel::BuildExaModelType + + function Model{TD}( # TD must be specified explicitly + times::AbstractTimesModel, + state::AbstractStateModel, + control::AbstractControlModel, + variable::AbstractVariableModel, + dynamics::Function, + objective::AbstractObjectiveModel, + constraints::AbstractConstraintsModel, + definition::Expr, + build_examodel::Union{Function,Nothing}, + ) where {TD<:TimeDependence} + return new{ + TD, + typeof(times), + typeof(state), + typeof(control), + typeof(variable), + typeof(dynamics), + typeof(objective), + typeof(constraints), + typeof(build_examodel), + }( + times, + state, + control, + variable, + dynamics, + objective, + constraints, + definition, + build_examodel, + ) + end +end + +""" +$(TYPEDSIGNATURES) + +Return `true` since times are always set in a built [`Model`](@ref CTModels.Model). +""" +__is_times_set(ocp::Model)::Bool = true + +""" +$(TYPEDSIGNATURES) + +Return `true` since state is always set in a built [`Model`](@ref CTModels.Model). +""" +__is_state_set(ocp::Model)::Bool = true + +""" +$(TYPEDSIGNATURES) + +Return `true` since control is always set in a built [`Model`](@ref CTModels.Model). +""" +__is_control_set(ocp::Model)::Bool = true + +""" +$(TYPEDSIGNATURES) + +Return `true` since variable is always set in a built [`Model`](@ref CTModels.Model). +""" +__is_variable_set(ocp::Model)::Bool = true + +""" +$(TYPEDSIGNATURES) + +Return `true` since dynamics is always set in a built [`Model`](@ref CTModels.Model). +""" +__is_dynamics_set(ocp::Model)::Bool = true + +""" +$(TYPEDSIGNATURES) + +Return `true` since objective is always set in a built [`Model`](@ref CTModels.Model). +""" +__is_objective_set(ocp::Model)::Bool = true + +""" +$(TYPEDSIGNATURES) + +Return `true` since definition is always set in a built [`Model`](@ref CTModels.Model). +""" +__is_definition_set(ocp::Model)::Bool = true + +""" +$(TYPEDEF) + +Mutable optimal control problem model under construction. + +A `PreModel` is used to incrementally define an optimal control problem before +building it into an immutable [`Model`](@ref CTModels.Model). Fields can be set in any order +and the model is validated before building. + +# Fields + +- `times::Union{AbstractTimesModel,Nothing}`: Initial and final time specification. +- `state::Union{AbstractStateModel,Nothing}`: State variable structure. +- `control::Union{AbstractControlModel,Nothing}`: Control variable structure. +- `variable::AbstractVariableModel`: Optimisation variable (defaults to empty). +- `dynamics::Union{Function,Vector,Nothing}`: System dynamics (function or component-wise). +- `objective::Union{AbstractObjectiveModel,Nothing}`: Cost functional. +- `constraints::ConstraintsDictType`: Dictionary of constraints being built. +- `definition::Union{Expr,Nothing}`: Symbolic definition expression. +- `autonomous::Union{Bool,Nothing}`: Whether the system is autonomous. + +# Example + +```julia-repl +julia> using CTModels + +julia> pre = CTModels.PreModel() +julia> # Set fields incrementally... +``` +""" +@with_kw mutable struct PreModel <: AbstractModel + times::Union{AbstractTimesModel,Nothing} = nothing + state::Union{AbstractStateModel,Nothing} = nothing + control::Union{AbstractControlModel,Nothing} = nothing + variable::AbstractVariableModel = EmptyVariableModel() + dynamics::Union{Function,Vector{<:Tuple{<:AbstractRange{<:Int},<:Function}},Nothing} = + nothing + objective::Union{AbstractObjectiveModel,Nothing} = nothing + constraints::ConstraintsDictType = ConstraintsDictType() + definition::Union{Expr,Nothing} = nothing + autonomous::Union{Bool,Nothing} = nothing +end + +""" +$(TYPEDSIGNATURES) + +Return `true` if `x` is not `nothing`. +""" +__is_set(x) = !isnothing(x) + +""" +$(TYPEDSIGNATURES) + +Return `true` if the autonomous flag has been set in the [`PreModel`](@ref). +""" +__is_autonomous_set(ocp::PreModel)::Bool = __is_set(ocp.autonomous) + +""" +$(TYPEDSIGNATURES) + +Return `true` if times have been set in the [`PreModel`](@ref). +""" +__is_times_set(ocp::PreModel)::Bool = __is_set(ocp.times) + +""" +$(TYPEDSIGNATURES) + +Return `true` if state has been set in the [`PreModel`](@ref). +""" +__is_state_set(ocp::PreModel)::Bool = __is_set(ocp.state) + +""" +$(TYPEDSIGNATURES) + +Return `true` if control has been set in the [`PreModel`](@ref). +""" +__is_control_set(ocp::PreModel)::Bool = __is_set(ocp.control) + +""" +$(TYPEDSIGNATURES) + +Return `true` if `v` is an [`EmptyVariableModel`](@ref). +""" +__is_variable_empty(v) = v isa EmptyVariableModel + +""" +$(TYPEDSIGNATURES) + +Return `true` if a non-empty variable has been set in the [`PreModel`](@ref). +""" +__is_variable_set(ocp::PreModel)::Bool = !__is_variable_empty(ocp.variable) + +""" +$(TYPEDSIGNATURES) + +Return `true` if dynamics have been set in the [`PreModel`](@ref). +""" +__is_dynamics_set(ocp::PreModel)::Bool = __is_set(ocp.dynamics) + +""" +$(TYPEDSIGNATURES) + +Return `true` if objective has been set in the [`PreModel`](@ref). +""" +__is_objective_set(ocp::PreModel)::Bool = __is_set(ocp.objective) + +""" +$(TYPEDSIGNATURES) + +Return `true` if definition has been set in the [`PreModel`](@ref). +""" +__is_definition_set(ocp::PreModel)::Bool = __is_set(ocp.definition) + +""" +$(TYPEDSIGNATURES) + +Return the state dimension of the [`PreModel`](@ref). + +Throws `CTBase.UnauthorizedCall` if state has not been set. +""" +function state_dimension(ocp::PreModel)::Dimension + @ensure(__is_state_set(ocp), CTBase.UnauthorizedCall("the state must be set.")) + return length(ocp.state.components) +end + +""" +$(TYPEDSIGNATURES) + +Return `true` if dynamics cover all state components in the [`PreModel`](@ref). + +For component-wise dynamics, checks that all state indices are covered. +""" +function __is_dynamics_complete(ocp::PreModel)::Bool + if isnothing(ocp.dynamics) + return false + elseif ocp.dynamics isa Function + return true + else # ocp.dynamics isa Vector{<:Tuple{<:AbstractRange{<:Int},<:Function}} + @ensure(__is_state_set(ocp), CTBase.UnauthorizedCall("the state must be set.")) + n = state_dimension(ocp) + covered = falses(n) + for (range, _) in ocp.dynamics + for i in range + if 1 <= i <= n + covered[i] = true + else + throw( + CTBase.UnauthorizedCall( + "Dynamics index $i out of bounds for state of size $n." + ), + ) + end + end + end + return all(covered) + end +end + +""" +$(TYPEDSIGNATURES) + +Return true if all the required fields are set in the PreModel. +""" +function __is_consistent(ocp::PreModel)::Bool + return __is_times_set(ocp) && + __is_state_set(ocp) && + __is_control_set(ocp) && + __is_dynamics_complete(ocp) && + __is_objective_set(ocp) && + __is_autonomous_set(ocp) +end + +""" +$(TYPEDSIGNATURES) + +Return true if the PreModel can be built into a Model. +""" +function __is_complete(ocp::PreModel)::Bool + return __is_times_set(ocp) && + __is_state_set(ocp) && + __is_control_set(ocp) && + __is_dynamics_complete(ocp) && + __is_objective_set(ocp) && + __is_definition_set(ocp) && + __is_autonomous_set(ocp) +end + +""" +$(TYPEDSIGNATURES) + +Return true if nothing has been set. +""" +function __is_empty(ocp::PreModel)::Bool + return !__is_times_set(ocp) && + !__is_state_set(ocp) && + !__is_control_set(ocp) && + !__is_dynamics_set(ocp) && + !__is_objective_set(ocp) && + !__is_definition_set(ocp) && + !__is_variable_set(ocp) && + !__is_autonomous_set(ocp) && + Base.isempty(ocp.constraints) +end diff --git a/src/core/types/ocp_solution.jl b/src/core/types/ocp_solution.jl new file mode 100644 index 00000000..68d381bf --- /dev/null +++ b/src/core/types/ocp_solution.jl @@ -0,0 +1,239 @@ +# ------------------------------------------------------------------------------ # +# Continuous-time OCP solution-related types +# (time grids, solver infos, dual variables, Solution) +# ------------------------------------------------------------------------------ # +""" +$(TYPEDEF) + +Abstract base type for time grid models used in optimal control solutions. + +Subtypes store the discretised time points at which the solution is evaluated. + +See also: [`TimeGridModel`](@ref), [`EmptyTimeGridModel`](@ref). +""" +abstract type AbstractTimeGridModel end + +""" +$(TYPEDEF) + +Time grid model storing the discretised time points of a solution. + +# Fields + +- `value::T`: Vector or range of time points (e.g., `LinRange(0, 1, 100)`). + +# Example + +```julia-repl +julia> using CTModels + +julia> tg = CTModels.TimeGridModel(LinRange(0, 1, 101)) +julia> length(tg.value) +101 +``` +""" +struct TimeGridModel{T<:TimesDisc} <: AbstractTimeGridModel + value::T +end + +""" +$(TYPEDEF) + +Sentinel type representing an empty or uninitialised time grid. + +Used when a solution does not yet have an associated time discretisation. + +# Example + +```julia-repl +julia> using CTModels + +julia> etg = CTModels.EmptyTimeGridModel() +``` +""" +struct EmptyTimeGridModel <: AbstractTimeGridModel end + +is_empty(model::EmptyTimeGridModel)::Bool = true +is_empty(model::TimeGridModel)::Bool = false + +# ------------------------------------------------------------------------------ # +# Solver infos +""" +$(TYPEDEF) + +Abstract base type for solver information associated with an optimal control solution. + +Subtypes store metadata about the numerical solution process. + +See also: [`SolverInfos`](@ref). +""" +abstract type AbstractSolverInfos end + +""" +$(TYPEDEF) + +Solver information and statistics from the numerical solution process. + +# Fields + +- `iterations::Int`: Number of iterations performed by the solver. +- `status::Symbol`: Termination status (e.g., `:first_order`, `:max_iter`). +- `message::String`: Human-readable message describing the termination status. +- `successful::Bool`: Whether the solver converged successfully. +- `constraints_violation::Float64`: Maximum constraint violation at the solution. +- `infos::TI`: Dictionary of additional solver-specific information. + +# Example + +```julia-repl +julia> using CTModels + +julia> si = CTModels.SolverInfos(100, :first_order, "Converged", true, 1e-8, Dict{Symbol,Any}()) +julia> si.successful +true +``` +""" +struct SolverInfos{V,TI<:Dict{Symbol,V}} <: AbstractSolverInfos + iterations::Int + status::Symbol + message::String + successful::Bool + constraints_violation::Float64 + infos::TI +end + +# ------------------------------------------------------------------------------ # +# Constraints and dual variables for the solutions +""" +$(TYPEDEF) + +Abstract base type for dual variable models in optimal control solutions. + +Subtypes store Lagrange multipliers (dual variables) associated with constraints. + +See also: [`DualModel`](@ref). +""" +abstract type AbstractDualModel end + +""" +$(TYPEDEF) + +Dual variables (Lagrange multipliers) for all constraints in an optimal control solution. + +# Fields + +- `path_constraints_dual::PC_Dual`: Multipliers for path constraints `t -> μ(t)`, or `nothing`. +- `boundary_constraints_dual::BC_Dual`: Multipliers for boundary constraints (vector), or `nothing`. +- `state_constraints_lb_dual::SC_LB_Dual`: Multipliers for state lower bounds `t -> ν⁻(t)`, or `nothing`. +- `state_constraints_ub_dual::SC_UB_Dual`: Multipliers for state upper bounds `t -> ν⁺(t)`, or `nothing`. +- `control_constraints_lb_dual::CC_LB_Dual`: Multipliers for control lower bounds `t -> ω⁻(t)`, or `nothing`. +- `control_constraints_ub_dual::CC_UB_Dual`: Multipliers for control upper bounds `t -> ω⁺(t)`, or `nothing`. +- `variable_constraints_lb_dual::VC_LB_Dual`: Multipliers for variable lower bounds (vector), or `nothing`. +- `variable_constraints_ub_dual::VC_UB_Dual`: Multipliers for variable upper bounds (vector), or `nothing`. + +# Example + +```julia-repl +julia> using CTModels + +julia> # Typically constructed internally by the solver +julia> dm = CTModels.DualModel(nothing, nothing, nothing, nothing, nothing, nothing, nothing, nothing) +``` +""" +struct DualModel{ + PC_Dual<:Union{Function,Nothing}, + BC_Dual<:Union{ctVector,Nothing}, + SC_LB_Dual<:Union{Function,Nothing}, + SC_UB_Dual<:Union{Function,Nothing}, + CC_LB_Dual<:Union{Function,Nothing}, + CC_UB_Dual<:Union{Function,Nothing}, + VC_LB_Dual<:Union{ctVector,Nothing}, + VC_UB_Dual<:Union{ctVector,Nothing}, +} <: AbstractDualModel + path_constraints_dual::PC_Dual + boundary_constraints_dual::BC_Dual + state_constraints_lb_dual::SC_LB_Dual + state_constraints_ub_dual::SC_UB_Dual + control_constraints_lb_dual::CC_LB_Dual + control_constraints_ub_dual::CC_UB_Dual + variable_constraints_lb_dual::VC_LB_Dual + variable_constraints_ub_dual::VC_UB_Dual +end + +# ------------------------------------------------------------------------------ # +# Solution +# ------------------------------------------------------------------------------ # +""" +$(TYPEDEF) + +Abstract base type for optimal control problem solutions. + +Subtypes store the complete solution including primal trajectories, dual variables, +and solver information. + +See also: [`Solution`](@ref). +""" +abstract type AbstractSolution end + +""" +$(TYPEDEF) + +Complete solution of an optimal control problem. + +Stores the optimal state, control, and costate trajectories, the optimisation +variable value, objective value, dual variables, solver information, and a +reference to the original model. + +# Fields + +- `time_grid::TimeGridModelType`: Discretised time points. +- `times::TimesModelType`: Initial and final time specification. +- `state::StateModelType`: State trajectory `t -> x(t)` with metadata. +- `control::ControlModelType`: Control trajectory `t -> u(t)` with metadata. +- `variable::VariableModelType`: Optimisation variable value with metadata. +- `costate::CostateModelType`: Costate (adjoint) trajectory `t -> p(t)`. +- `objective::ObjectiveValueType`: Optimal objective value. +- `dual::DualModelType`: Dual variables for all constraints. +- `solver_infos::SolverInfosType`: Solver statistics and status. +- `model::ModelType`: Reference to the original optimal control problem. + +# Example + +```julia-repl +julia> using CTModels + +julia> # Solutions are typically returned by solvers +julia> sol = solve(ocp, ...) # Returns a Solution +julia> CTModels.objective(sol) +``` +""" +struct Solution{ + TimeGridModelType<:AbstractTimeGridModel, + TimesModelType<:AbstractTimesModel, + StateModelType<:AbstractStateModel, + ControlModelType<:AbstractControlModel, + VariableModelType<:AbstractVariableModel, + CostateModelType<:Function, + ObjectiveValueType<:ctNumber, + DualModelType<:AbstractDualModel, + SolverInfosType<:AbstractSolverInfos, + ModelType<:AbstractModel, +} <: AbstractSolution + time_grid::TimeGridModelType + times::TimesModelType + state::StateModelType + control::ControlModelType + variable::VariableModelType + costate::CostateModelType + objective::ObjectiveValueType + dual::DualModelType + solver_infos::SolverInfosType + model::ModelType +end + +""" +$(TYPEDSIGNATURES) + +Check if the time grid is empty from the solution. +""" +is_empty_time_grid(sol::Solution)::Bool = is_empty(sol.time_grid) diff --git a/src/utils.jl b/src/core/utils.jl similarity index 100% rename from src/utils.jl rename to src/core/utils.jl diff --git a/src/init.jl b/src/init.jl deleted file mode 100644 index b7e911b0..00000000 --- a/src/init.jl +++ /dev/null @@ -1,243 +0,0 @@ -""" -$(TYPEDSIGNATURES) - -Check if actual dimension is equal to target dimension, error otherwise -""" -function checkDim(actual_dim, target_dim) - if !isnothing(target_dim) && actual_dim != target_dim - error("Init dimension mismatch: got ", actual_dim, " instead of ", target_dim) - end - return nothing -end - -""" -$(TYPEDSIGNATURES) - -Return true if argument is a vector of vectors -""" -function isaVectVect(data) - return (data isa AbstractVector) && (data[1] isa AbstractVector) -end - -""" -$(TYPEDSIGNATURES) - -Convert matrix to vector of vectors (could be expanded) -""" -function formatData(data) - if data isa Matrix - return matrix2vec(data, 1) - else - return data - end -end - -""" -$(TYPEDSIGNATURES) - -Convert matrix time-grid to vector -""" -function formatTimeGrid(time) - if isnothing(time) - return nothing - elseif time isa AbstractVector - return time - else - return vec(time) - end -end - -""" -$(TYPEDSIGNATURES) - -Build functional initialization: default case -""" -function buildFunctionalInit(data::Nothing, time, dim) - # fallback to method-dependent default initialization - return t -> nothing -end - -""" -$(TYPEDSIGNATURES) - -Build functional initialization: function case -""" -function buildFunctionalInit(data::Function, time, dim) - # functional initialization - checkDim(length(data(0)), dim) - return t -> data(t) -end - -""" -$(TYPEDSIGNATURES) - -Build functional initialization: constant / 1D interpolation -""" -function buildFunctionalInit(data::Union{ctNumber,ctVector}, time, dim) - if !isnothing(time) && (length(data) == length(time)) - # interpolation vs time, dim 1 case - itp = ctinterpolate(time, data) - return t -> itp(t) - else - # constant initialization - checkDim(length(data), dim) - return t -> data - end -end - -""" -$(TYPEDSIGNATURES) - -Build functional initialization: general interpolation case -""" -function buildFunctionalInit(data, time, dim) - if isaVectVect(data) - # interpolation vs time, general case - itp = ctinterpolate(time, data) - checkDim(length(itp(0)), dim) - return t -> itp(t) - else - error("Unrecognized initialization argument: ", typeof(data)) - end -end - -""" -$(TYPEDSIGNATURES) - -Build vector initialization: default / vector case -""" -function buildVectorInit(data, dim) - if isnothing(data) - return data - else - checkDim(length(data), dim) - return data - end -end - -""" -$(TYPEDSIGNATURES) - -Initial guess for OCP, contains -- functions of time for the state and control variables -- vector for optimization variables -Initialization data for each field can be left to default or: -- vector for optimization variables -- constant / vector / function for state and control -- existing solution ('warm start') for all fields - -# Constructors: - -- `Init()`: default initialization -- `Init(state, control, variable, time)`: constant vector, function handles and / or matrices / vectors interpolated along given time grid -- `Init(sol)`: from existing solution - -# Examples - -```julia-repl -julia> init = Init() -julia> init = Init(state=[0.1, 0.2], control=0.3) -julia> init = Init(state=[0.1, 0.2], control=0.3, variable=0.5) -julia> init = Init(state=[0.1, 0.2], controlt=t->sin(t), variable=0.5) -julia> init = Init(state=[[0, 0], [1, 2], [5, -1]], time=[0, .3, 1.], controlt=t->sin(t)) -julia> init = Init(sol) -``` - -""" -mutable struct Init - state_init::Function - control_init::Function - variable_init::Union{Nothing,ctNumber,ctVector} - #costate_init::Function - #multipliers_init::Union{Nothing, ctVector} - - """ - $(TYPEDSIGNATURES) - - Init base constructor with separate explicit arguments - """ - function Init(; - state=nothing, - control=nothing, - variable=nothing, - time=nothing, - state_dim=nothing, - control_dim=nothing, - variable_dim=nothing, - ) - init = new() - - # some matrix / vector conversions - time = formatTimeGrid(time) - state = formatData(state) - control = formatData(control) - - # set initialization for x, u, v - init.state_init = buildFunctionalInit(state, time, state_dim) - init.control_init = buildFunctionalInit(control, time, control_dim) - init.variable_init = buildVectorInit(variable, variable_dim) - - return init - end - - """ - $(TYPEDSIGNATURES) - - Init constructor with arguments grouped as named tuple or dict - """ - function Init(init_data; state_dim=nothing, control_dim=nothing, variable_dim=nothing) - - # trivial case: default init - x_init = nothing - u_init = nothing - v_init = nothing - t_init = nothing - - # parse arguments - if !isnothing(init_data) - for key in keys(init_data) - if key == :state - x_init = init_data[:state] - elseif key == :control - u_init = init_data[:control] - elseif key == :variable - v_init = init_data[:variable] - elseif key == :time - t_init = init_data[:time] - else - error( - "Unknown key in initialization data (allowed: state, control, variable, time, state_dim, control_dim, variable_dim): ", - key, - ) - end - end - end - - # call base constructor - return Init(; - state=x_init, - control=u_init, - variable=v_init, - time=t_init, - state_dim=state_dim, - control_dim=control_dim, - variable_dim=variable_dim, - ) - end - - """ - $(TYPEDSIGNATURES) - - Init constructor with solution as argument (warm start) - """ - function Init(sol::Solution; unused_kwargs...) - return Init(; - state=state(sol), - control=control(sol), - variable=variable(sol), - state_dim=state_dimension(sol), - control_dim=control_dimension(sol), - variable_dim=variable_dimension(sol), - ) - end -end diff --git a/src/init/initial_guess.jl b/src/init/initial_guess.jl new file mode 100644 index 00000000..795f8eaa --- /dev/null +++ b/src/init/initial_guess.jl @@ -0,0 +1,1018 @@ +# ------------------------------------------------------------------------------ +# Initial guess +# ------------------------------------------------------------------------------ +""" +$(TYPEDSIGNATURES) + +Create a pre-initialisation object for an initial guess. + +This function creates an [`OptimalControlPreInit`](@ref) that can later be +processed into a full [`OptimalControlInitialGuess`](@ref). + +# Arguments + +- `state`: Raw state initialisation data (function, vector, matrix, or `nothing`). +- `control`: Raw control initialisation data (function, vector, matrix, or `nothing`). +- `variable`: Raw variable initialisation data (scalar, vector, or `nothing`). + +# Returns + +- `OptimalControlPreInit`: A pre-initialisation container. + +# Example + +```julia-repl +julia> using CTModels + +julia> pre = CTModels.pre_initial_guess(state=t -> [0.0, 0.0], control=t -> [1.0]) +``` +""" +function pre_initial_guess(; state=nothing, control=nothing, variable=nothing) + return OptimalControlPreInit(state, control, variable) +end + +""" +$(TYPEDSIGNATURES) + +Construct a validated initial guess for an optimal control problem. + +Builds an [`OptimalControlInitialGuess`](@ref) from the provided state, control, +and variable data, validating dimensions against the problem definition. + +# Arguments + +- `ocp::AbstractOptimalControlProblem`: The optimal control problem. +- `state`: State initialisation (function `t -> x(t)`, constant, vector, or `nothing`). +- `control`: Control initialisation (function `t -> u(t)`, constant, vector, or `nothing`). +- `variable`: Variable initialisation (scalar, vector, or `nothing`). + +# Returns + +- `OptimalControlInitialGuess`: A validated initial guess. + +# Example + +```julia-repl +julia> using CTModels + +julia> init = CTModels.initial_guess(ocp; state=t -> [0.0, 0.0], control=t -> [1.0]) +``` +""" +function initial_guess( + ocp::AbstractOptimalControlProblem; + state::Union{Nothing,Function,Real,Vector{<:Real}}=nothing, + control::Union{Nothing,Function,Real,Vector{<:Real}}=nothing, + variable::Union{Nothing,Real,Vector{<:Real}}=nothing, +) + x = initial_state(ocp, state) + u = initial_control(ocp, control) + v = initial_variable(ocp, variable) + init = OptimalControlInitialGuess(x, u, v) + return _validate_initial_guess(ocp, init) +end + +""" +$(TYPEDSIGNATURES) + +Return the state function directly when provided as a function. +""" +initial_state(::AbstractOptimalControlProblem, state::Function) = state + +""" +$(TYPEDSIGNATURES) + +Convert a scalar state value to a constant function for 1D state problems. + +Throws `CTBase.IncorrectArgument` if the state dimension is not 1. +""" +function initial_state(ocp::AbstractOptimalControlProblem, state::Real) + dim = state_dimension(ocp) + if dim == 1 + return t -> state + else + msg = "Initial state dimension mismatch: got scalar for state dimension $dim" + throw(CTBase.IncorrectArgument(msg)) + end +end + +""" +$(TYPEDSIGNATURES) + +Build an initialisation function combining block-level and component-level data. + +Merges a base initialisation with per-component overrides. +""" +function _build_block_with_components( + ocp::AbstractOptimalControlProblem, role::Symbol, block_data, comp_data::Dict{Int,Any} +) + dim = role === :state ? state_dimension(ocp) : control_dimension(ocp) + base_fun = begin + if block_data === nothing + if role === :state + initial_state(ocp, nothing) + else + initial_control(ocp, nothing) + end + elseif block_data isa Tuple && length(block_data) == 2 + # Per-block time grid: (time, data) + T, data = block_data + time = _format_time_grid(T) + _build_time_dependent_init(ocp, role, data, time) + else + if role === :state + initial_state(ocp, block_data) + else + initial_control(ocp, block_data) + end + end + end + + if isempty(comp_data) + return base_fun + end + + comp_funs = Dict{Int,Function}() + for (i, data) in comp_data + comp_funs[i] = _build_component_function(data) + end + + return t -> begin + base_val = base_fun(t) + vec = if dim == 1 + if base_val isa AbstractVector + copy(base_val) + else + [base_val] + end + else + if !(base_val isa AbstractVector) || length(base_val) != dim + msg = string( + "Block-level ", + role, + " initial guess produced value of incompatible dimension: got ", + (base_val isa AbstractVector ? length(base_val) : 1), + " instead of ", + dim, + ) + throw(CTBase.IncorrectArgument(msg)) + end + collect(base_val) + end + + for (i, fi) in comp_funs + val = fi(t) + val_scalar = if val isa AbstractVector + if length(val) != 1 + msg = string( + "Component-level ", + role, + " initial guess must be scalar or length-1 vector for index ", + i, + ".", + ) + throw(CTBase.IncorrectArgument(msg)) + end + val[1] + else + val + end + if !(1 <= i <= dim) + msg = string( + "Component index ", + i, + " out of bounds for ", + role, + " dimension ", + dim, + ".", + ) + throw(CTBase.IncorrectArgument(msg)) + end + vec[i] = val_scalar + end + return dim == 1 ? vec[1] : vec + end +end + +""" +$(TYPEDSIGNATURES) + +Build a component-level initialisation function from data. + +Handles both time-dependent `(time, data)` tuples and time-independent data. +""" +function _build_component_function(data) + # Support (time, data) tuples for per-component time grids + if data isa Tuple && length(data) == 2 + T, val = data + time = _format_time_grid(T) + return _build_component_function_with_time(val, time) + else + return _build_component_function_without_time(data) + end +end + +""" +$(TYPEDSIGNATURES) + +Build a component function from time-independent data (scalar, vector, or function). +""" +function _build_component_function_without_time(data) + if data isa Function + return data + elseif data isa Real + return t -> data + elseif data isa AbstractVector{<:Real} + if length(data) == 1 + c = data[1] + return t -> c + else + msg = "Component-level initialization without time must be scalar or length-1 vector." + throw(CTBase.IncorrectArgument(msg)) + end + else + msg = string( + "Unsupported component-level initialization type without time: ", typeof(data) + ) + throw(CTBase.IncorrectArgument(msg)) + end +end + +""" +$(TYPEDSIGNATURES) + +Build a component function from data with an associated time grid. + +Interpolates vector data over the time grid. +""" +function _build_component_function_with_time(data, time::AbstractVector) + if data isa Function + return data + elseif data isa Real + return t -> data + elseif data isa AbstractVector{<:Real} + if length(data) == length(time) + itp = ctinterpolate(time, data) + return t -> itp(t) + elseif length(data) == 1 + c = data[1] + return t -> c + else + msg = string( + "Component-level initialization time-grid mismatch: got ", + length(data), + " samples for ", + length(time), + "-point time grid.", + ) + throw(CTBase.IncorrectArgument(msg)) + end + else + msg = string( + "Unsupported component-level initialization type with time grid: ", typeof(data) + ) + throw(CTBase.IncorrectArgument(msg)) + end +end + +""" +$(TYPEDSIGNATURES) + +Convert a state vector to a constant function. + +Throws `CTBase.IncorrectArgument` if the vector length does not match the state dimension. +""" +function initial_state(ocp::AbstractOptimalControlProblem, state::Vector{<:Real}) + dim = state_dimension(ocp) + if length(state) != dim + msg = string( + "Initial state dimension mismatch: got ", length(state), " instead of ", dim + ) + throw(CTBase.IncorrectArgument(msg)) + end + return t -> state +end + +""" +$(TYPEDSIGNATURES) + +Return a default state initialisation function when no state is provided. + +Returns a constant function yielding `0.1` (scalar) or `fill(0.1, dim)` (vector). +""" +function initial_state(ocp::AbstractOptimalControlProblem, ::Nothing) + dim = state_dimension(ocp) + if dim == 1 + return t -> 0.1 + else + return t -> fill(0.1, dim) + end +end + +""" +$(TYPEDSIGNATURES) + +Return the control function directly when provided as a function. +""" +initial_control(::AbstractOptimalControlProblem, control::Function) = control + +""" +$(TYPEDSIGNATURES) + +Convert a scalar control value to a constant function for 1D control problems. + +Throws `CTBase.IncorrectArgument` if the control dimension is not 1. +""" +function initial_control(ocp::AbstractOptimalControlProblem, control::Real) + dim = control_dimension(ocp) + if dim == 1 + return t -> control + else + msg = "Initial control dimension mismatch: got scalar for control dimension $dim" + throw(CTBase.IncorrectArgument(msg)) + end +end + +""" +$(TYPEDSIGNATURES) + +Convert a control vector to a constant function. + +Throws `CTBase.IncorrectArgument` if the vector length does not match the control dimension. +""" +function initial_control(ocp::AbstractOptimalControlProblem, control::Vector{<:Real}) + dim = control_dimension(ocp) + if length(control) != dim + msg = string( + "Initial control dimension mismatch: got ", length(control), " instead of ", dim + ) + throw(CTBase.IncorrectArgument(msg)) + end + return t -> control +end + +""" +$(TYPEDSIGNATURES) + +Return a default control initialisation function when no control is provided. + +Returns a constant function yielding `0.1` (scalar) or `fill(0.1, dim)` (vector). +""" +function initial_control(ocp::AbstractOptimalControlProblem, ::Nothing) + dim = control_dimension(ocp) + if dim == 1 + return t -> 0.1 + else + return t -> fill(0.1, dim) + end +end + +""" +$(TYPEDSIGNATURES) + +Return a scalar variable value for 1D variable problems. + +Throws `CTBase.IncorrectArgument` if the variable dimension is not 1. +""" +function initial_variable(ocp::AbstractOptimalControlProblem, variable::Real) + dim = variable_dimension(ocp) + if dim == 0 + msg = "Initial variable dimension mismatch: got scalar for variable dimension 0" + throw(CTBase.IncorrectArgument(msg)) + elseif dim == 1 + return variable + else + msg = "Initial variable dimension mismatch: got scalar for variable dimension $dim" + throw(CTBase.IncorrectArgument(msg)) + end +end + +""" +$(TYPEDSIGNATURES) + +Return a variable vector. + +Throws `CTBase.IncorrectArgument` if the vector length does not match the variable dimension. +""" +function initial_variable(ocp::AbstractOptimalControlProblem, variable::Vector{<:Real}) + dim = variable_dimension(ocp) + if length(variable) != dim + msg = string( + "Initial variable dimension mismatch: got ", + length(variable), + " instead of ", + dim, + ) + throw(CTBase.IncorrectArgument(msg)) + end + return variable +end + +""" +$(TYPEDSIGNATURES) + +Return a default variable initialisation when no variable is provided. + +Returns an empty vector if `dim == 0`, `0.1` if `dim == 1`, or `fill(0.1, dim)` otherwise. +""" +function initial_variable(ocp::AbstractOptimalControlProblem, ::Nothing) + dim = variable_dimension(ocp) + if dim == 0 + return Float64[] + else + if dim == 1 + return 0.1 + else + return fill(0.1, dim) + end + end +end + +""" +$(TYPEDSIGNATURES) + +Extract the state trajectory function from an initial guess. +""" +function state(init::OptimalControlInitialGuess{X,<:Function})::X where {X<:Function} + return init.state +end + +""" +$(TYPEDSIGNATURES) + +Extract the control trajectory function from an initial guess. +""" +function control(init::OptimalControlInitialGuess{<:Function,U})::U where {U<:Function} + return init.control +end + +""" +$(TYPEDSIGNATURES) + +Extract the variable value from an initial guess. +""" +function variable( + init::OptimalControlInitialGuess{<: Function,<: Function,V} +)::V where {V<:Union{Real,Vector{<:Real}}} + return init.variable +end + +""" +$(TYPEDSIGNATURES) + +Validate an initial guess against an optimal control problem. + +Checks that the dimensions of state, control, and variable match the problem +definition. Returns the validated initial guess or throws an error. + +# Arguments + +- `ocp::AbstractOptimalControlProblem`: The optimal control problem. +- `init::AbstractOptimalControlInitialGuess`: The initial guess to validate. + +# Returns + +- The validated initial guess. + +# Throws + +- `CTBase.IncorrectArgument` if dimensions do not match. +""" +function validate_initial_guess( + ocp::AbstractOptimalControlProblem, init::AbstractOptimalControlInitialGuess +) + if init isa OptimalControlInitialGuess + return _validate_initial_guess(ocp, init) + else + # For now, only OptimalControlInitialGuess is supported. + return init + end +end + +""" +$(TYPEDSIGNATURES) + +Internal validation of an [`OptimalControlInitialGuess`](@ref). + +Samples the state and control functions at a test time and verifies dimensions. +""" +function _validate_initial_guess( + ocp::AbstractOptimalControlProblem, init::OptimalControlInitialGuess +) + # Dimensions from the OCP + xdim = state_dimension(ocp) + udim = control_dimension(ocp) + vdim = variable_dimension(ocp) + + # Sample evaluation time; for autonomous/non-autonomous problems + # the shape of x(t), u(t) is independent of t. + v0 = variable(init) + tsample = if has_fixed_initial_time(ocp) + initial_time(ocp) + else + initial_time(ocp, v0) + end + + # State + x0 = state(init)(tsample) + if xdim == 1 + if !(x0 isa Real) && !(x0 isa AbstractVector && length(x0) == 1) + msg = "Initial state function must return a scalar or length-1 vector for state dimension 1." + throw(CTBase.IncorrectArgument(msg)) + end + else + if !(x0 isa AbstractVector) || length(x0) != xdim + msg = string( + "Initial state function returns value of incompatible dimension: got ", + (x0 isa AbstractVector ? length(x0) : 1), + " instead of ", + xdim, + ) + throw(CTBase.IncorrectArgument(msg)) + end + end + + # Control + u0 = control(init)(tsample) + if udim == 1 + if !(u0 isa Real) && !(u0 isa AbstractVector && length(u0) == 1) + msg = "Initial control function must return a scalar or length-1 vector for control dimension 1." + throw(CTBase.IncorrectArgument(msg)) + end + else + if !(u0 isa AbstractVector) || length(u0) != udim + msg = string( + "Initial control function returns value of incompatible dimension: got ", + (u0 isa AbstractVector ? length(u0) : 1), + " instead of ", + udim, + ) + throw(CTBase.IncorrectArgument(msg)) + end + end + + # Variable + if vdim == 0 + if v0 isa AbstractVector + if length(v0) != 0 + msg = "Initial variable has non-zero length for problem with no variable." + throw(CTBase.IncorrectArgument(msg)) + end + elseif v0 isa Real + msg = "Initial variable is scalar for problem with no variable." + throw(CTBase.IncorrectArgument(msg)) + end + elseif vdim == 1 + if !(v0 isa Real) && !(v0 isa AbstractVector && length(v0) == 1) + msg = "Initial variable must be a scalar or length-1 vector for variable dimension 1." + throw(CTBase.IncorrectArgument(msg)) + end + else + if !(v0 isa AbstractVector) || length(v0) != vdim + msg = string( + "Initial variable has incompatible dimension: got ", + (v0 isa AbstractVector ? length(v0) : 1), + " instead of ", + vdim, + ) + throw(CTBase.IncorrectArgument(msg)) + end + end + + return init +end + +""" +$(TYPEDSIGNATURES) + +Build an initial guess from various input formats. + +Accepts multiple input types and converts them to an [`OptimalControlInitialGuess`](@ref): +- `nothing` or `()`: Returns default initial guess. +- `AbstractOptimalControlInitialGuess`: Returns as-is. +- `AbstractOptimalControlPreInit`: Converts from pre-initialisation. +- `AbstractSolution`: Warm-starts from a previous solution. +- `NamedTuple`: Parses named fields for state, control, and variable. + +# Arguments + +- `ocp::AbstractOptimalControlProblem`: The optimal control problem. +- `init_data`: The initial guess data in one of the supported formats. + +# Returns + +- `OptimalControlInitialGuess`: A validated initial guess. + +# Example + +```julia-repl +julia> using CTModels + +julia> init = CTModels.build_initial_guess(ocp, (state=t -> [0.0], control=t -> [1.0])) +``` +""" +function build_initial_guess(ocp::AbstractOptimalControlProblem, init_data) + if init_data === nothing || init_data === () + return initial_guess(ocp) + elseif init_data isa AbstractOptimalControlInitialGuess + return init_data + elseif init_data isa AbstractOptimalControlPreInit + return _initial_guess_from_preinit(ocp, init_data) + elseif init_data isa AbstractSolution + return _initial_guess_from_solution(ocp, init_data) + elseif init_data isa NamedTuple + return _initial_guess_from_namedtuple(ocp, init_data) + else + msg = "Unsupported initial guess type: $(typeof(init_data))" + throw(CTBase.IncorrectArgument(msg)) + end +end + +""" +$(TYPEDSIGNATURES) + +Build an initial guess from a previous solution (warm start). + +Extracts state, control, and variable trajectories from the solution and validates +dimensions against the current problem. +""" +function _initial_guess_from_solution( + ocp::AbstractOptimalControlProblem, sol::AbstractSolution +) + # Basic dimensional consistency checks + if state_dimension(ocp) != state_dimension(sol.model) + msg = "Warm start: state dimension mismatch between ocp and solution." + throw(CTBase.IncorrectArgument(msg)) + end + if control_dimension(ocp) != control_dimension(sol.model) + msg = "Warm start: control dimension mismatch between ocp and solution." + throw(CTBase.IncorrectArgument(msg)) + end + if variable_dimension(ocp) != variable_dimension(sol.model) + msg = "Warm start: variable dimension mismatch between ocp and solution." + throw(CTBase.IncorrectArgument(msg)) + end + + state_fun = state(sol) + control_fun = control(sol) + variable_val = variable(sol) + + init = OptimalControlInitialGuess(state_fun, control_fun, variable_val) + return _validate_initial_guess(ocp, init) +end + +""" +$(TYPEDSIGNATURES) + +Build an initial guess from a `NamedTuple`. + +Parses keys for state, control, variable (by name or component) and constructs +the appropriate initialisation functions. +""" +function _initial_guess_from_namedtuple( + ocp::AbstractOptimalControlProblem, init_data::NamedTuple +) + # Names and component maps from the OCP + s_name_sym = Symbol(state_name(ocp)) + u_name_sym = Symbol(control_name(ocp)) + v_name_sym = Symbol(variable_name(ocp)) + + s_comp_syms = Symbol.(state_components(ocp)) + u_comp_syms = Symbol.(control_components(ocp)) + v_comp_syms = Symbol.(variable_components(ocp)) + + s_comp_index = Dict(sym => i for (i, sym) in enumerate(s_comp_syms)) + u_comp_index = Dict(sym => i for (i, sym) in enumerate(u_comp_syms)) + v_comp_index = Dict(sym => i for (i, sym) in enumerate(v_comp_syms)) + + # Block-level and component-level specs + state_block = nothing + control_block = nothing + variable_block = nothing + state_block_set = false + control_block_set = false + variable_block_set = false + state_comp = Dict{Int,Any}() + control_comp = Dict{Int,Any}() + variable_comp = Dict{Int,Any}() + + # Parse keys and enforce uniqueness + for (k, v) in pairs(init_data) + if k == :time + msg = "Global :time in initial guess NamedTuple is not supported. Provide time grids per block or component as (time, data) tuples." + throw(CTBase.IncorrectArgument(msg)) + elseif k == :variable || k == v_name_sym + if variable_block_set || !isempty(variable_comp) + msg = "Variable initial guess specified both at block level and component level, or multiple block-level entries." + throw(CTBase.IncorrectArgument(msg)) + end + variable_block = v + variable_block_set = true + elseif k == :state || k == s_name_sym + if state_block_set || !isempty(state_comp) + msg = "State initial guess specified both at block level and component level, or multiple block-level entries." + throw(CTBase.IncorrectArgument(msg)) + end + state_block = v + state_block_set = true + elseif k == :control || k == u_name_sym + if control_block_set || !isempty(control_comp) + msg = "Control initial guess specified both at block level and component level, or multiple block-level entries." + throw(CTBase.IncorrectArgument(msg)) + end + control_block = v + control_block_set = true + elseif haskey(s_comp_index, k) + if state_block_set + msg = string( + "Cannot mix state block (:state or ", + s_name_sym, + ") and state component ", + k, + " in the same initial guess.", + ) + throw(CTBase.IncorrectArgument(msg)) + end + idx = s_comp_index[k] + if haskey(state_comp, idx) + msg = string( + "State component ", k, " specified more than once in initial guess." + ) + throw(CTBase.IncorrectArgument(msg)) + end + state_comp[idx] = v + elseif haskey(u_comp_index, k) + if control_block_set + msg = string( + "Cannot mix control block (:control or ", + u_name_sym, + ") and control component ", + k, + " in the same initial guess.", + ) + throw(CTBase.IncorrectArgument(msg)) + end + idx = u_comp_index[k] + if haskey(control_comp, idx) + msg = string( + "Control component ", k, " specified more than once in initial guess." + ) + throw(CTBase.IncorrectArgument(msg)) + end + control_comp[idx] = v + elseif haskey(v_comp_index, k) + if variable_block_set + msg = string( + "Cannot mix variable block (:variable or ", + v_name_sym, + ") and variable component ", + k, + " in the same initial guess.", + ) + throw(CTBase.IncorrectArgument(msg)) + end + idx = v_comp_index[k] + if haskey(variable_comp, idx) + msg = string( + "Variable component ", k, " specified more than once in initial guess." + ) + throw(CTBase.IncorrectArgument(msg)) + end + variable_comp[idx] = v + else + msg = string( + "Unknown key ", + k, + " in initial guess NamedTuple. Allowed keys are: time, state, control, variable, ", + s_name_sym, + ", ", + u_name_sym, + ", ", + v_name_sym, + ", and component names of state/control/variable.", + ) + throw(CTBase.IncorrectArgument(msg)) + end + end + + # Build state/control with possible per-component overrides + state_fun = _build_block_with_components(ocp, :state, state_block, state_comp) + control_fun = _build_block_with_components(ocp, :control, control_block, control_comp) + + # Build variable (block-level or per-component) + variable_val = begin + if isempty(variable_comp) + initial_variable(ocp, variable_block) + else + vdim = variable_dimension(ocp) + if vdim == 0 + msg = "Variable components specified for problem with no variable." + throw(CTBase.IncorrectArgument(msg)) + else + # Start from default variable initialization and override components + base = initial_variable(ocp, nothing) + if vdim == 1 + # Single-component variable: override index 1 if provided + if haskey(variable_comp, 1) + data = variable_comp[1] + val = if data isa AbstractVector{<:Real} + if length(data) != 1 + msg = "Variable component initial guess must be scalar or length-1 vector for variable dimension 1." + throw(CTBase.IncorrectArgument(msg)) + end + data[1] + elseif data isa Real + data + else + msg = string( + "Unsupported variable component initialization type without time: ", + typeof(data), + ) + throw(CTBase.IncorrectArgument(msg)) + end + val + else + # No specific component provided: keep default base + base + end + else + # vdim > 1: base should be a vector of length vdim + vec = if base isa AbstractVector + if length(base) != vdim + msg = string( + "Default variable initialization has incompatible dimension: got ", + length(base), + " instead of ", + vdim, + ".", + ) + throw(CTBase.IncorrectArgument(msg)) + end + collect(base) + elseif base isa Real + fill(base, vdim) + else + msg = string( + "Unsupported default variable initialization type: ", + typeof(base), + ) + throw(CTBase.IncorrectArgument(msg)) + end + # Override provided components; missing ones keep default + for (i, data) in variable_comp + if !(1 <= i <= vdim) + msg = string( + "Variable component index ", + i, + " out of bounds for variable dimension ", + vdim, + ".", + ) + throw(CTBase.IncorrectArgument(msg)) + end + val_scalar = if data isa AbstractVector{<:Real} + if length(data) != 1 + msg = string( + "Variable component index ", + i, + " initial guess must be scalar or length-1 vector.", + ) + throw(CTBase.IncorrectArgument(msg)) + end + data[1] + elseif data isa Real + data + else + msg = string( + "Unsupported variable component initialization type without time: ", + typeof(data), + ) + throw(CTBase.IncorrectArgument(msg)) + end + vec[i] = val_scalar + end + vec + end + end + end + end + + init = OptimalControlInitialGuess(state_fun, control_fun, variable_val) + return _validate_initial_guess(ocp, init) +end + +""" +$(TYPEDSIGNATURES) + +Convert a [`OptimalControlPreInit`](@ref) to an initial guess. +""" +function _initial_guess_from_preinit( + ocp::AbstractOptimalControlProblem, preinit::OptimalControlPreInit +) + nt = (state=preinit.state, control=preinit.control, variable=preinit.variable) + return _initial_guess_from_namedtuple(ocp, nt) +end + +""" +$(TYPEDSIGNATURES) + +Normalise time grid data to a vector format. +""" +function _format_time_grid(time_data) + if time_data === nothing + return nothing + elseif time_data isa AbstractVector + return time_data + elseif time_data isa AbstractArray + return vec(time_data) + else + msg = string( + "Invalid time grid type for initial guess: ", + typeof(time_data), + ". Expected a vector or array.", + ) + throw(CTBase.IncorrectArgument(msg)) + end +end + +""" +$(TYPEDSIGNATURES) + +Convert matrix data to vector-of-vectors format for time-grid interpolation. +""" +function _format_init_data_for_grid(data) + if data isa AbstractMatrix + return matrix2vec(data, 1) + else + return data + end +end + +""" +$(TYPEDSIGNATURES) + +Build a time-dependent initialisation function from data and a time grid. + +Interpolates the provided data over the time grid to create a callable function. +""" +function _build_time_dependent_init( + ocp::AbstractOptimalControlProblem, role::Symbol, data, time::AbstractVector +) + dim = role === :state ? state_dimension(ocp) : control_dimension(ocp) + if data === nothing + return role === :state ? initial_state(ocp, nothing) : initial_control(ocp, nothing) + end + if data isa Function + return data + end + data_fmt = _format_init_data_for_grid(data) + if data_fmt isa AbstractVector{<:Real} + if length(data_fmt) == length(time) + itp = ctinterpolate(time, data_fmt) + return t -> itp(t) + else + return if role === :state + initial_state(ocp, data_fmt) + else + initial_control(ocp, data_fmt) + end + end + elseif data_fmt isa AbstractVector && + !isempty(data_fmt) && + (data_fmt[1] isa AbstractVector) + if length(data_fmt) != length(time) + msg = string( + "Time-grid ", + role, + " initialization mismatch: got ", + length(data_fmt), + " samples for ", + length(time), + "-point time grid.", + ) + throw(CTBase.IncorrectArgument(msg)) + end + itp = ctinterpolate(time, data_fmt) + sample = itp(first(time)) + if !(sample isa AbstractVector) || length(sample) != dim + msg = string( + "Time-grid ", + role, + " initialization has incompatible dimension: got ", + (sample isa AbstractVector ? length(sample) : 1), + " instead of ", + dim, + ) + throw(CTBase.IncorrectArgument(msg)) + end + return t -> itp(t) + else + msg = string( + "Unsupported ", + role, + " initialization type for time-grid based initial guess: ", + typeof(data), + ) + throw(CTBase.IncorrectArgument(msg)) + end +end \ No newline at end of file diff --git a/src/nlp/discretized_ocp.jl b/src/nlp/discretized_ocp.jl new file mode 100644 index 00000000..47505ed1 --- /dev/null +++ b/src/nlp/discretized_ocp.jl @@ -0,0 +1,111 @@ +# ------------------------------------------------------------------------------ # +# Discretized optimal control problem +# +# This file implements helper methods that operate on +# [`DiscretizedOptimalControlProblem`](@ref) and its associated +# back-end builders (`ADNLPSolutionBuilder`, `ExaSolutionBuilder`, +# `OCPBackendBuilders`), which are part of the +# [`AbstractOCPTool`](@ref)-based optimization interface. +# ------------------------------------------------------------------------------ # +# Helpers +""" +$(TYPEDSIGNATURES) + +Invoke the ADNLPModels solution builder to convert NLP execution statistics +into an optimal control solution. +""" +function (builder::ADNLPSolutionBuilder)(nlp_solution::SolverCore.AbstractExecutionStats) + return builder.f(nlp_solution) +end + +""" +$(TYPEDSIGNATURES) + +Invoke the ExaModels solution builder to convert NLP execution statistics +into an optimal control solution. +""" +function (builder::ExaSolutionBuilder)(nlp_solution::SolverCore.AbstractExecutionStats) + return builder.f(nlp_solution) +end + +# Problem +""" +$(TYPEDSIGNATURES) + +Return the original optimal control problem from a discretised problem. + +# Arguments + +- `prob::DiscretizedOptimalControlProblem`: The discretised problem. + +# Returns + +- The underlying [`Model`](@ref CTModels.Model) (optimal control problem). +""" +function ocp_model(prob::DiscretizedOptimalControlProblem) + return prob.optimal_control_problem +end + +""" +$(TYPEDSIGNATURES) + +Retrieve the ADNLPModels model builder from a discretised problem. + +Throws `ArgumentError` if no `:adnlp` backend is registered. +""" +function get_adnlp_model_builder(prob::DiscretizedOptimalControlProblem) + for (name, builders) in pairs(prob.backend_builders) + if name === :adnlp + return builders.model + end + end + throw(ArgumentError("no :adnlp model builder registered")) +end + +""" +$(TYPEDSIGNATURES) + +Retrieve the ExaModels model builder from a discretised problem. + +Throws `ArgumentError` if no `:exa` backend is registered. +""" +function get_exa_model_builder(prob::DiscretizedOptimalControlProblem) + for (name, builders) in pairs(prob.backend_builders) + if name === :exa + return builders.model + end + end + throw(ArgumentError("no :exa model builder registered")) +end + +""" +$(TYPEDSIGNATURES) + +Retrieve the ADNLPModels solution builder from a discretised problem. + +Throws `ArgumentError` if no `:adnlp` backend is registered. +""" +function get_adnlp_solution_builder(prob::DiscretizedOptimalControlProblem) + for (name, builders) in pairs(prob.backend_builders) + if name === :adnlp + return builders.solution + end + end + throw(ArgumentError("no :adnlp solution builder registered")) +end + +""" +$(TYPEDSIGNATURES) + +Retrieve the ExaModels solution builder from a discretised problem. + +Throws `ArgumentError` if no `:exa` backend is registered. +""" +function get_exa_solution_builder(prob::DiscretizedOptimalControlProblem) + for (name, builders) in pairs(prob.backend_builders) + if name === :exa + return builders.solution + end + end + throw(ArgumentError("no :exa solution builder registered")) +end \ No newline at end of file diff --git a/src/nlp/model_api.jl b/src/nlp/model_api.jl new file mode 100644 index 00000000..044a17fc --- /dev/null +++ b/src/nlp/model_api.jl @@ -0,0 +1,90 @@ +# ------------------------------------------------------------------------------ +# NLP Model and Solution builders +# ------------------------------------------------------------------------------ +""" +$(TYPEDSIGNATURES) + +Build an NLP model from an optimisation problem using the specified modeler. + +# Arguments + +- `prob::AbstractOptimizationProblem`: The optimisation problem. +- `initial_guess`: Initial guess for the NLP solver. +- `modeler::AbstractOptimizationModeler`: The modeler (e.g., `ADNLPModeler`, `ExaModeler`). + +# Returns + +- An NLP model suitable for the chosen backend. +""" +function build_model( + prob::AbstractOptimizationProblem, initial_guess, modeler::AbstractOptimizationModeler +) + return modeler(prob, initial_guess) +end + +""" +$(TYPEDSIGNATURES) + +Build an NLP model from a discretised optimal control problem. + +# Arguments + +- `prob::DiscretizedOptimalControlProblem`: The discretised OCP. +- `initial_guess`: Initial guess for the NLP solver. +- `modeler::AbstractOptimizationModeler`: The modeler to use. + +# Returns + +- `NLPModels.AbstractNLPModel`: The NLP model. +""" +function nlp_model( + prob::DiscretizedOptimalControlProblem, + initial_guess, + modeler::AbstractOptimizationModeler, +)::NLPModels.AbstractNLPModel + return build_model(prob, initial_guess, modeler) +end + +""" +$(TYPEDSIGNATURES) + +Build a solution from NLP execution statistics using the specified modeler. + +# Arguments + +- `prob::AbstractOptimizationProblem`: The optimisation problem. +- `model_solution`: NLP solver output (execution statistics). +- `modeler::AbstractOptimizationModeler`: The modeler used for building. + +# Returns + +- A solution object appropriate for the problem type. +""" +function build_solution( + prob::AbstractOptimizationProblem, model_solution, modeler::AbstractOptimizationModeler +) + return modeler(prob, model_solution) +end + +""" +$(TYPEDSIGNATURES) + +Build an optimal control solution from NLP execution statistics. + +# Arguments + +- `docp::DiscretizedOptimalControlProblem`: The discretised OCP. +- `model_solution::SolverCore.AbstractExecutionStats`: NLP solver output. +- `modeler::AbstractOptimizationModeler`: The modeler used. + +# Returns + +- `AbstractOptimalControlSolution`: The OCP solution. +""" +function ocp_solution( + docp::DiscretizedOptimalControlProblem, + model_solution::SolverCore.AbstractExecutionStats, + modeler::AbstractOptimizationModeler, +)::AbstractOptimalControlSolution + return build_solution(docp, model_solution, modeler) +end \ No newline at end of file diff --git a/src/nlp/nlp_backends.jl b/src/nlp/nlp_backends.jl new file mode 100644 index 00000000..8f90eb34 --- /dev/null +++ b/src/nlp/nlp_backends.jl @@ -0,0 +1,300 @@ +# ------------------------------------------------------------------------------ +# Model backends +# ------------------------------------------------------------------------------ + +# ------------------------------------------------------------------------------ +# ADNLPModels +# ------------------------------------------------------------------------------ +""" +$(TYPEDSIGNATURES) + +Return the default value for the `show_time` option of [`ADNLPModeler`](@ref). + +Default is `false`. +""" +__adnlp_model_show_time() = false + +""" +$(TYPEDSIGNATURES) + +Return the default automatic differentiation backend for [`ADNLPModeler`](@ref). + +Default is `:optimized`. +""" +__adnlp_model_backend() = :optimized + +""" +$(TYPEDSIGNATURES) + +Return the option specifications for [`ADNLPModeler`](@ref). + +Defines options: `show_time` (Bool) and `backend` (Symbol). +""" +function _option_specs(::Type{<:ADNLPModeler}) + return ( + show_time=OptionSpec(; + type=Bool, + default=__adnlp_model_show_time(), + description="Whether to show timing information while building the ADNLP model.", + ), + backend=OptionSpec(; + type=Symbol, + default=__adnlp_model_backend(), + description="Automatic differentiation backend used by ADNLPModels.", + ), + ) +end + +""" +$(TYPEDSIGNATURES) + +Construct an [`ADNLPModeler`](@ref) with the given options. + +# Keyword Arguments + +- `show_time::Bool`: Whether to show timing information (default: `false`). +- `backend::Symbol`: AD backend to use (default: `:optimized`). + +# Returns + +- `ADNLPModeler`: A configured modeler instance. +""" +function ADNLPModeler(; kwargs...) + values, sources = _build_ocp_tool_options(ADNLPModeler; kwargs..., strict_keys=false) + return ADNLPModeler{typeof(values),typeof(sources)}(values, sources) +end + +""" +$(TYPEDSIGNATURES) + +Build an ADNLPModel from an optimisation problem and initial guess. +""" +function (modeler::ADNLPModeler)( + prob::AbstractOptimizationProblem, initial_guess +)::ADNLPModels.ADNLPModel + vals = _options_values(modeler) + builder = get_adnlp_model_builder(prob) + return builder(initial_guess; vals...) +end + +""" +$(TYPEDSIGNATURES) + +Build an OCP solution from NLP execution statistics using ADNLPModels. +""" +function (modeler::ADNLPModeler)( + prob::AbstractOptimizationProblem, nlp_solution::SolverCore.AbstractExecutionStats +) + builder = get_adnlp_solution_builder(prob) + return builder(nlp_solution) +end + +# ------------------------------------------------------------------------------ +# ExaModels +# ------------------------------------------------------------------------------ +""" +$(TYPEDSIGNATURES) + +Return the default floating-point type for [`ExaModeler`](@ref). + +Default is `Float64`. +""" +__exa_model_base_type() = Float64 + +""" +$(TYPEDSIGNATURES) + +Return the default execution backend for [`ExaModeler`](@ref). + +Default is `nothing` (CPU). +""" +__exa_model_backend() = nothing + +""" +$(TYPEDSIGNATURES) + +Return the option specifications for [`ExaModeler`](@ref). + +Defines options: `base_type`, `minimize`, and `backend`. +""" +function _option_specs(::Type{<:ExaModeler}) + return ( + base_type=OptionSpec(; + type=Type{<:AbstractFloat}, + default=__exa_model_base_type(), + description="Base floating-point type used by ExaModels.", + ), + minimize=OptionSpec(; + type=Bool, + default=missing, + description="Whether to minimize (true) or maximize (false) the objective.", + ), + backend=OptionSpec(; + type=Union{Nothing,KernelAbstractions.Backend}, + default=__exa_model_backend(), + description="Execution backend for ExaModels (CPU, GPU, etc.).", + ), + ) +end + +""" +$(TYPEDSIGNATURES) + +Construct an [`ExaModeler`](@ref) with the given options. + +# Keyword Arguments + +- `base_type::Type{<:AbstractFloat}`: Floating-point type (default: `Float64`). +- `minimize::Bool`: Whether to minimise (default from problem). +- `backend`: Execution backend (default: `nothing` for CPU). + +# Returns + +- `ExaModeler`: A configured modeler instance. +""" +function ExaModeler(; kwargs...) + values, sources = _build_ocp_tool_options(ExaModeler; kwargs..., strict_keys=true) + BaseType = values.base_type + + # base_type is only needed to fix the type parameter; it does not need to + # remain part of the exposed options NamedTuples. + filtered_vals = _filter_options(values, (:base_type,)) + filtered_srcs = _filter_options(sources, (:base_type,)) + + return ExaModeler{BaseType,typeof(filtered_vals),typeof(filtered_srcs)}( + filtered_vals, filtered_srcs + ) +end + +""" +$(TYPEDSIGNATURES) + +Build an ExaModel from an optimisation problem and initial guess. +""" +function (modeler::ExaModeler{BaseType})( + prob::AbstractOptimizationProblem, initial_guess +)::ExaModels.ExaModel{BaseType} where {BaseType<:AbstractFloat} + vals = _options_values(modeler) + backend = vals.backend + builder = get_exa_model_builder(prob) + return builder(BaseType, initial_guess; backend=backend, vals...) +end + +""" +$(TYPEDSIGNATURES) + +Build an OCP solution from NLP execution statistics using ExaModels. +""" +function (modeler::ExaModeler)( + prob::AbstractOptimizationProblem, nlp_solution::SolverCore.AbstractExecutionStats +) + builder = get_exa_solution_builder(prob) + return builder(nlp_solution) +end + +# ------------------------------------------------------------------------------ +# Registration +# ------------------------------------------------------------------------------ + +""" +$(TYPEDSIGNATURES) + +Return the symbol identifier for [`ADNLPModeler`](@ref). + +Returns `:adnlp`. +""" +get_symbol(::Type{<:ADNLPModeler}) = :adnlp + +""" +$(TYPEDSIGNATURES) + +Return the symbol identifier for [`ExaModeler`](@ref). + +Returns `:exa`. +""" +get_symbol(::Type{<:ExaModeler}) = :exa + +""" +$(TYPEDSIGNATURES) + +Return the package name for [`ADNLPModeler`](@ref). + +Returns `"ADNLPModels"`. +""" +tool_package_name(::Type{<:ADNLPModeler}) = "ADNLPModels" + +""" +$(TYPEDSIGNATURES) + +Return the package name for [`ExaModeler`](@ref). + +Returns `"ExaModels"`. +""" +tool_package_name(::Type{<:ExaModeler}) = "ExaModels" + +""" +Tuple of all registered modeler types. + +Currently contains `(ADNLPModeler, ExaModeler)`. +""" +const REGISTERED_MODELERS = (ADNLPModeler, ExaModeler) + +""" +$(TYPEDSIGNATURES) + +Return the tuple of all registered modeler types. +""" +registered_modeler_types() = REGISTERED_MODELERS + +""" +$(TYPEDSIGNATURES) + +Return a tuple of symbols for all registered modelers. + +Returns `(:adnlp, :exa)`. +""" +modeler_symbols() = Tuple(get_symbol(T) for T in REGISTERED_MODELERS) + +""" +$(TYPEDSIGNATURES) + +Look up a modeler type from its symbol identifier. + +Throws `CTBase.IncorrectArgument` if the symbol is unknown. +""" +function _modeler_type_from_symbol(sym::Symbol) + for T in REGISTERED_MODELERS + if get_symbol(T) === sym + return T + end + end + msg = "Unknown NLP model symbol $(sym). Supported symbols: $(modeler_symbols())." + throw(CTBase.IncorrectArgument(msg)) +end + +""" +$(TYPEDSIGNATURES) + +Construct a modeler from its symbol identifier. + +# Arguments + +- `sym::Symbol`: The modeler symbol (`:adnlp` or `:exa`). +- `kwargs...`: Options to pass to the modeler constructor. + +# Returns + +- An instance of the corresponding modeler type. + +# Example + +```julia-repl +julia> using CTModels + +julia> modeler = CTModels.build_modeler_from_symbol(:adnlp) +``` +""" +function build_modeler_from_symbol(sym::Symbol; kwargs...) + T = _modeler_type_from_symbol(sym) + return T(; kwargs...) +end \ No newline at end of file diff --git a/src/nlp/options_schema.jl b/src/nlp/options_schema.jl new file mode 100644 index 00000000..6884e25a --- /dev/null +++ b/src/nlp/options_schema.jl @@ -0,0 +1,580 @@ +# Internal metadata schema for backend and discretizer options. + +""" +$(TYPEDSIGNATURES) + +Return a short `Symbol` identifying the package or implementation used by a +given [`AbstractOCPTool`](@ref). + +Concrete tool types are expected to specialize this method on their own type, +for example `get_symbol(::Type{<:MyTool}) = :mytool`. +""" +function get_symbol(tool::AbstractOCPTool) + return get_symbol(typeof(tool)) +end + +""" +$(TYPEDSIGNATURES) + +Default implementation that throws `CTBase.NotImplemented`. + +Concrete tool types must specialize this method. +""" +function get_symbol(::Type{T}) where {T<:AbstractOCPTool} + throw(CTBase.NotImplemented("get_symbol not implemented for $(T)")) +end + +""" +$(TYPEDSIGNATURES) + +Return the package name associated with a tool instance. +""" +function tool_package_name(tool::AbstractOCPTool) + return tool_package_name(typeof(tool)) +end + +""" +$(TYPEDSIGNATURES) + +Return the package name for a tool type. + +Default implementation returns `missing`. +""" +function tool_package_name(::Type{T}) where {T<:AbstractOCPTool} + return missing +end + +# --------------------------------------------------------------------------- +# Internal options API overview +# +# For each tool T<:AbstractOCPTool: +# - _option_specs(T) :: NamedTuple of OptionSpec describing option keys. +# - default_options(T) :: NamedTuple of default values taken from specs +# (only options with non-missing defaults are included). +# - _build_ocp_tool_options(T; kwargs..., strict_keys=false) :: (values, sources) +# merges default options with user kwargs and tracks provenance +# (:ct_default or :user) in a parallel NamedTuple. +# - Concrete tools store `options_values` and `options_sources` fields and +# are accessed via _options_values(tool) and _option_sources(tool). +# +# OptionSpec fields: +# - type : expected Julia type for validation (or `missing`). +# - default : default value at the tool level (or `missing` if none). +# - description : short human-readable description (or `missing`). +# --------------------------------------------------------------------------- + +function OptionSpec(; type=missing, default=missing, description=missing) + OptionSpec(type, default, description) +end + +# Default: no metadata for a given tool type. +""" +$(TYPEDSIGNATURES) + +Return the option metadata specification for a concrete +[`AbstractOCPTool`](@ref) subtype. + +Concrete tools typically specialize this method on their own type and return a +`NamedTuple` whose fields correspond to option names and whose values are +[`OptionSpec`](@ref) instances. + +The default implementation returns `missing`, meaning that no option metadata +is available for the given tool type. +""" +function _option_specs(::Type{T}) where {T<:AbstractOCPTool} + return missing +end + +""" +$(TYPEDSIGNATURES) + +Convenience overload to accept tool instances. +""" +_option_specs(x::AbstractOCPTool) = _option_specs(typeof(x)) + +""" +$(TYPEDSIGNATURES) + +Return the current option values for a tool instance. +""" +function _options_values(tool::AbstractOCPTool) + return tool.options_values +end + +""" +$(TYPEDSIGNATURES) + +Return the option sources (`:ct_default` or `:user`) for a tool instance. +""" +function _option_sources(tool::AbstractOCPTool) + return tool.options_sources +end + +""" +$(TYPEDSIGNATURES) + +Return the list of known option keys for a tool type. + +Returns `missing` if no option metadata is available. +""" +function options_keys(tool_type::Type{<:AbstractOCPTool}) + specs = _option_specs(tool_type) + specs === missing && return missing + return propertynames(specs) +end + +""" +$(TYPEDSIGNATURES) + +Convenience overload for tool instances. +""" +options_keys(x::AbstractOCPTool) = options_keys(typeof(x)) + +""" +$(TYPEDSIGNATURES) + +Check if `key` is a valid option key for the given tool type. + +Returns `missing` if no option metadata is available. +""" +function is_an_option_key(key::Symbol, tool_type::Type{<:AbstractOCPTool}) + specs = _option_specs(tool_type) + specs === missing && return missing + return key in propertynames(specs) +end + +""" +$(TYPEDSIGNATURES) + +Convenience overload for tool instances. +""" +is_an_option_key(key::Symbol, x::AbstractOCPTool) = is_an_option_key(key, typeof(x)) + +""" +$(TYPEDSIGNATURES) + +Return the expected type for an option key. + +Returns `missing` if the key is unknown or no type is specified. +""" +function option_type(key::Symbol, tool_type::Type{<:AbstractOCPTool}) + specs = _option_specs(tool_type) + specs === missing && return missing + if !(haskey(specs, key)) + return missing + end + spec = getfield(specs, key)::OptionSpec + return spec.type +end + +""" +$(TYPEDSIGNATURES) + +Convenience overload for tool instances. +""" +option_type(key::Symbol, x::AbstractOCPTool) = option_type(key, typeof(x)) + +""" +$(TYPEDSIGNATURES) + +Return the description for an option key. + +Returns `missing` if the key is unknown or no description is available. +""" +function option_description(key::Symbol, tool_type::Type{<:AbstractOCPTool}) + specs = _option_specs(tool_type) + specs === missing && return missing + if !(haskey(specs, key)) + return missing + end + spec = getfield(specs, key)::OptionSpec + return spec.description +end + +""" +$(TYPEDSIGNATURES) + +Convenience overload for tool instances. +""" +option_description(key::Symbol, x::AbstractOCPTool) = option_description(key, typeof(x)) + +""" +$(TYPEDSIGNATURES) + +Return the default value for an option key. + +Returns `missing` if the key is unknown or no default is specified. +""" +function option_default(key::Symbol, tool_type::Type{<:AbstractOCPTool}) + specs = _option_specs(tool_type) + specs === missing && return missing + if !(haskey(specs, key)) + return missing + end + spec = getfield(specs, key)::OptionSpec + return spec.default +end + +""" +$(TYPEDSIGNATURES) + +Convenience overload for tool instances. +""" +option_default(key::Symbol, x::AbstractOCPTool) = option_default(key, typeof(x)) + +""" +$(TYPEDSIGNATURES) + +Return a `NamedTuple` of default option values for a tool type. + +Only options with non-missing defaults are included. +""" +function default_options(tool_type::Type{<:AbstractOCPTool}) + specs = _option_specs(tool_type) + specs === missing && return NamedTuple() + pairs = Pair{Symbol,Any}[] + for name in propertynames(specs) + spec = getfield(specs, name)::OptionSpec + if spec.default !== missing + push!(pairs, name => spec.default) + end + end + return (; pairs...) +end + +""" +$(TYPEDSIGNATURES) + +Convenience overload for tool instances. +""" +default_options(x::AbstractOCPTool) = default_options(typeof(x)) + +""" +$(TYPEDSIGNATURES) + +Filter a `NamedTuple` by excluding specified keys. +""" +function _filter_options(nt::NamedTuple, exclude) + return (; (k => v for (k, v) in pairs(nt) if !(k in exclude))...) +end + +""" +$(TYPEDSIGNATURES) + +Compute the Levenshtein distance between two strings. + +Used for suggesting similar option names when a typo is detected. +""" +function _string_distance(a::AbstractString, b::AbstractString) + m = lastindex(a) + n = lastindex(b) + # Use 1-based indices over code units for simplicity; option keys are short. + da = collect(codeunits(a)) + db = collect(codeunits(b)) + # dp[i+1, j+1] = distance between first i chars of a and first j chars of b + dp = Array{Int}(undef, m + 1, n + 1) + for i in 0:m + dp[i + 1, 1] = i + end + for j in 0:n + dp[1, j + 1] = j + end + for i in 1:m + for j in 1:n + cost = da[i] == db[j] ? 0 : 1 + dp[i + 1, j + 1] = min( + dp[i, j + 1] + 1, # deletion + dp[i + 1, j] + 1, # insertion + dp[i, j] + cost, # substitution + ) + end + end + return dp[m + 1, n + 1] +end + +""" +$(TYPEDSIGNATURES) + +Suggest up to `max_suggestions` closest option keys for a tool type. + +Used to provide helpful error messages when an unknown option is specified. +""" +function _suggest_option_keys( + key::Symbol, tool_type::Type{<:AbstractOCPTool}; max_suggestions::Int=3 +) + specs = _option_specs(tool_type) + specs === missing && return Symbol[] + names = collect(propertynames(specs)) + distances = [(_string_distance(String(key), String(n)), n) for n in names] + sort!(distances; by=first) + take = min(max_suggestions, length(distances)) + return [distances[i][2] for i in 1:take] +end + +""" +$(TYPEDSIGNATURES) + +Convenience overload for tool instances. +""" +function _suggest_option_keys(key::Symbol, x::AbstractOCPTool; max_suggestions::Int=3) + _suggest_option_keys(key, typeof(x); max_suggestions=max_suggestions) +end + +# --------------------------------------------------------------------------- +# High-level getters for option value/source/default on instantiated tools. +# These helpers validate the option key and reuse the suggestion machinery +# used when parsing user keyword arguments. +# --------------------------------------------------------------------------- + +""" +$(TYPEDSIGNATURES) + +Generate and throw an error for an unknown option key with suggestions. +""" +function _unknown_option_error( + key::Symbol, tool_type::Type{<:AbstractOCPTool}, context::AbstractString +) + suggestions = _suggest_option_keys(key, tool_type; max_suggestions=3) + tool_name = string(nameof(tool_type)) + msg = "Unknown option $(key) for $(tool_name) when querying the $(context)." + if !isempty(suggestions) + msg *= " Did you mean " * join(string.(suggestions), " or ") * "?" + end + msg *= " Use show_options($(tool_name)) to list all available options." + throw(CTBase.IncorrectArgument(msg)) +end + +""" +$(TYPEDSIGNATURES) + +Get the current value of an option for a tool instance. + +Throws an error if the option is unknown or has no value. +""" +function get_option_value(tool::AbstractOCPTool, key::Symbol) + vals = _options_values(tool) + if haskey(vals, key) + return vals[key] + end + + tool_type = typeof(tool) + specs = _option_specs(tool_type) + if specs === missing || !haskey(specs, key) + return _unknown_option_error(key, tool_type, "value") + end + + tool_name = string(nameof(tool_type)) + msg = + "Option $(key) is defined for $(tool_name) but has no value: " * + "no default was provided and the option was not set by the user." + throw(CTBase.IncorrectArgument(msg)) +end + +""" +$(TYPEDSIGNATURES) + +Get the source (`:ct_default` or `:user`) of an option value. + +Throws an error if the option is unknown. +""" +function get_option_source(tool::AbstractOCPTool, key::Symbol) + srcs = _option_sources(tool) + if haskey(srcs, key) + return srcs[key] + end + + tool_type = typeof(tool) + specs = _option_specs(tool_type) + if specs === missing || !haskey(specs, key) + return _unknown_option_error(key, tool_type, "source") + end + + tool_name = string(nameof(tool_type)) + msg = "Option $(key) is defined for $(tool_name) but has no recorded source." + throw(CTBase.IncorrectArgument(msg)) +end + +""" +$(TYPEDSIGNATURES) + +Get the default value of an option for a tool instance. + +Throws an error if the option is unknown. +""" +function get_option_default(tool::AbstractOCPTool, key::Symbol) + tool_type = typeof(tool) + specs = _option_specs(tool_type) + if specs === missing || !haskey(specs, key) + return _unknown_option_error(key, tool_type, "default") + end + return option_default(key, tool_type) +end + +""" +$(TYPEDSIGNATURES) + +Print a human-readable listing of options and their metadata for a tool type. +""" +function _show_options(tool_type::Type{<:AbstractOCPTool}) + specs = _option_specs(tool_type) + if specs === missing + println("No option metadata available for ", tool_type, ".") + return nothing + end + println("Options for ", tool_type, ":") + for name in propertynames(specs) + spec = getfield(specs, name)::OptionSpec + T = spec.type === missing ? "Any" : string(spec.type) + desc = spec.description === missing ? "" : " — " * String(spec.description) + println(" - ", name, " :: ", T, desc) + end +end + +""" +$(TYPEDSIGNATURES) + +Convenience overload for tool instances. +""" +function _show_options(x::AbstractOCPTool) + return _show_options(typeof(x)) +end + +""" +$(TYPEDSIGNATURES) + +Display available options for a tool type. + +Prints option names, types, and descriptions to stdout. +""" +function show_options(tool_type::Type{<:AbstractOCPTool}) + return _show_options(tool_type) +end + +""" +$(TYPEDSIGNATURES) + +Convenience overload for tool instances. +""" +function show_options(x::AbstractOCPTool) + return _show_options(typeof(x)) +end + +""" +$(TYPEDSIGNATURES) + +Validate user-supplied keyword options against tool metadata. + +If `strict_keys` is `true`, unknown keys trigger an error. If `false`, unknown +keys are accepted and only known keys are type-checked. +""" +function _validate_option_kwargs( + user_nt::NamedTuple, tool_type::Type{<:AbstractOCPTool}; strict_keys::Bool=false +) + specs = _option_specs(tool_type) + specs === missing && return nothing + + known_keys = propertynames(specs) + + # Unknown keys + if strict_keys + unknown = Symbol[] + for k in keys(user_nt) + if !(k in known_keys) + push!(unknown, k) + end + end + if !isempty(unknown) + # Only report the first unknown key with suggestions. + k = first(unknown) + suggestions = _suggest_option_keys(k, tool_type; max_suggestions=3) + tool_name = string(nameof(tool_type)) + msg = "Unknown option $(k) for $(tool_name)." + if !isempty(suggestions) + msg *= " Did you mean " * join(string.(suggestions), " or ") * "?" + end + msg *= " Use show_options($(tool_name)) to list all available options." + throw(CTBase.IncorrectArgument(msg)) + end + end + + # Type checks for known keys where a type is provided. + for k in keys(user_nt) + if !(k in known_keys) + continue + end + T = option_type(k, tool_type) + T === missing && continue + v = user_nt[k] + if !(v isa T) + tool_name = string(nameof(tool_type)) + msg = + "Invalid type for option $(k) of $(tool_name). " * + "Expected value of type $(T), got value of type $(typeof(v))." + throw(CTBase.IncorrectArgument(msg)) + end + end + + return nothing +end + +""" +$(TYPEDSIGNATURES) + +Convenience overload for tool instances. +""" +function _validate_option_kwargs( + user_nt::NamedTuple, x::AbstractOCPTool; strict_keys::Bool=false +) + _validate_option_kwargs(user_nt, typeof(x); strict_keys=strict_keys) +end + +""" +$(TYPEDSIGNATURES) + +Build a normalized pair of option `values` and `sources` for a concrete +[`AbstractOCPTool`](@ref) subtype. + +This helper is typically used in the keyword-only constructor of a tool type, +for example `MyTool(; kwargs...) = MyTool(_build_ocp_tool_options(MyTool; kwargs...)...)`. + +# Arguments + +- `::Type{T}`: concrete subtype of `AbstractOCPTool`. +- `strict_keys::Bool`: if `true`, unknown option keys are rejected with a + detailed error; if `false`, unknown keys are accepted. +- `kwargs...`: user-supplied option values. + +# Returns + +A pair `(values, sources)` where: + +- `values::NamedTuple`: effective option values after merging tool defaults + (from [`default_options`](@ref)) with the user keywords. +- `sources::NamedTuple`: for each option name, either `:ct_default` or + `:user` indicating whether the value comes from the tool defaults or from + user input. +""" +function _build_ocp_tool_options( + ::Type{T}; strict_keys::Bool=false, kwargs... +) where {T<:AbstractOCPTool} + # Normalize user-supplied keyword arguments to a NamedTuple. + user_nt = NamedTuple(kwargs) + + # Validate option keys and types against the tool metadata. + _validate_option_kwargs(user_nt, T; strict_keys=strict_keys) + + # Merge tool-level default options with user overrides (user wins). + defaults = default_options(T) + values = merge(defaults, user_nt) + + # Build a parallel NamedTuple recording the provenance of each option + # (:ct_default for defaults coming from the tool, :user for overrides). + src_pairs = Pair{Symbol,Symbol}[] + for name in keys(values) + src = haskey(user_nt, name) ? :user : :ct_default + push!(src_pairs, name => src) + end + sources = (; src_pairs...) + + return values, sources +end \ No newline at end of file diff --git a/src/nlp/problem_core.jl b/src/nlp/problem_core.jl new file mode 100644 index 00000000..32d9547a --- /dev/null +++ b/src/nlp/problem_core.jl @@ -0,0 +1,94 @@ +# builders of NLP models +""" +$(TYPEDSIGNATURES) + +Invoke the ADNLPModels model builder to construct an NLP model from an initial guess. +""" +function (builder::ADNLPModelBuilder)(initial_guess; kwargs...)::ADNLPModels.ADNLPModel + return builder.f(initial_guess; kwargs...) +end + +""" +$(TYPEDSIGNATURES) + +Invoke the ExaModels model builder to construct an NLP model from an initial guess. + +The `BaseType` parameter specifies the floating-point type for the model. +""" +function (builder::ExaModelBuilder)( + ::Type{BaseType}, initial_guess; kwargs... +)::ExaModels.ExaModel where {BaseType<:AbstractFloat} + return builder.f(BaseType, initial_guess; kwargs...) +end + +# helpers to build solutions + +# problem + +""" +$(TYPEDSIGNATURES) + +Interface method for [`AbstractOptimizationProblem`](@ref). + +Concrete problem types that support the ExaModels back-end must +specialize this function to return the [`ExaModelBuilder`](@ref) used to +construct the corresponding NLP model. The default implementation throws +`CTBase.NotImplemented`. +""" +function get_exa_model_builder(prob::AbstractOptimizationProblem) + throw( + CTBase.NotImplemented("get_exa_model_builder not implemented for $(typeof(prob))"), + ) +end + +""" +$(TYPEDSIGNATURES) + +Interface method for [`AbstractOptimizationProblem`](@ref). + +Concrete problem types that support the ADNLPModels back-end must +specialize this function to return the [`ADNLPModelBuilder`](@ref) used +to construct the corresponding NLP model. The default implementation +throws `CTBase.NotImplemented`. +""" +function get_adnlp_model_builder(prob::AbstractOptimizationProblem) + throw( + CTBase.NotImplemented("get_adnlp_model_builder not implemented for $(typeof(prob))"), + ) +end + +""" +$(TYPEDSIGNATURES) + +Interface method for [`AbstractOptimizationProblem`](@ref). + +Concrete problem types that support ADNLPModels must specialize this +function to return the [`ADNLPSolutionBuilder`](@ref) used to convert NLP +solutions into the desired representation. The default implementation +throws `CTBase.NotImplemented`. +""" +function get_adnlp_solution_builder(prob::AbstractOptimizationProblem) + throw( + CTBase.NotImplemented( + "get_adnlp_solution_builder not implemented for $(typeof(prob))", + ), + ) +end + +""" +$(TYPEDSIGNATURES) + +Interface method for [`AbstractOptimizationProblem`](@ref). + +Concrete problem types that support ExaModels must specialize this +function to return the [`ExaSolutionBuilder`](@ref) used to convert NLP +solutions into the desired representation. The default implementation +throws `CTBase.NotImplemented`. +""" +function get_exa_solution_builder(prob::AbstractOptimizationProblem) + throw( + CTBase.NotImplemented( + "get_exa_solution_builder not implemented for $(typeof(prob))", + ), + ) +end \ No newline at end of file diff --git a/src/constraints.jl b/src/ocp/constraints.jl similarity index 96% rename from src/constraints.jl rename to src/ocp/constraints.jl index 93ac4581..5f32ef2e 100644 --- a/src/constraints.jl +++ b/src/ocp/constraints.jl @@ -254,15 +254,60 @@ function constraint!( ) end +""" + as_vector(::Nothing) -> Nothing + +Return `nothing` unchanged. +""" as_vector(::Nothing) = nothing + +""" + as_vector(x::T) -> Vector{T} where {T<:ctNumber} + +Wrap a scalar number into a single-element vector. +""" (as_vector(x::T)::Vector{T}) where {T<:ctNumber} = [x] + +""" + as_vector(x::Vector{T}) -> Vector{T} where {T<:ctNumber} + +Return a vector unchanged. +""" as_vector(x::Vector{T}) where {T<:ctNumber} = x +""" + as_range(::Nothing) -> Nothing + +Return `nothing` unchanged. +""" as_range(::Nothing) = nothing + +""" + as_range(r::Int) -> UnitRange{Int} + +Convert a scalar integer to a single-element range `r:r`. +""" as_range(r::T) where {T<:Int} = r:r + +""" + as_range(r::OrdinalRange{Int}) -> OrdinalRange{Int} + +Return an ordinal range unchanged. +""" as_range(r::OrdinalRange{T}) where {T<:Int} = r +""" + discretize(constraint::Function, grid::Vector{T}) -> Vector where {T<:ctNumber} + +Discretise a constraint function over a time grid. +""" discretize(constraint::Function, grid::Vector{T}) where {T<:ctNumber} = constraint.(grid) + +""" + discretize(::Nothing, grid::Vector{T}) -> Nothing where {T<:ctNumber} + +Return `nothing` when discretising a missing constraint. +""" discretize(::Nothing, grid::Vector{T}) where {T<:ctNumber} = nothing # ------------------------------------------------------------------------------ # diff --git a/src/control.jl b/src/ocp/control.jl similarity index 100% rename from src/control.jl rename to src/ocp/control.jl diff --git a/src/definition.jl b/src/ocp/definition.jl similarity index 66% rename from src/definition.jl rename to src/ocp/definition.jl index a2548be8..8961df62 100644 --- a/src/definition.jl +++ b/src/ocp/definition.jl @@ -7,6 +7,14 @@ $(TYPEDSIGNATURES) Set the model definition of the optimal control problem. +# Arguments + +- `ocp::PreModel`: The pre-model to modify. +- `definition::Expr`: The symbolic expression defining the problem. + +# Returns + +- `Nothing` """ function definition!(ocp::PreModel, definition::Expr)::Nothing ocp.definition = definition @@ -22,6 +30,13 @@ $(TYPEDSIGNATURES) Return the model definition of the optimal control problem. +# Arguments + +- `ocp::Model`: The built optimal control problem model. + +# Returns + +- `Expr`: The symbolic expression defining the problem. """ function definition(ocp::Model)::Expr return ocp.definition @@ -32,6 +47,13 @@ $(TYPEDSIGNATURES) Return the model definition of the optimal control problem or `nothing`. +# Arguments + +- `ocp::PreModel`: The pre-model (may not have a definition set). + +# Returns + +- `Union{Expr, Nothing}`: The symbolic expression or `nothing` if not set. """ function definition(ocp::PreModel) return ocp.definition diff --git a/src/dual_model.jl b/src/ocp/dual_model.jl similarity index 100% rename from src/dual_model.jl rename to src/ocp/dual_model.jl diff --git a/src/dynamics.jl b/src/ocp/dynamics.jl similarity index 100% rename from src/dynamics.jl rename to src/ocp/dynamics.jl diff --git a/src/model.jl b/src/ocp/model.jl similarity index 98% rename from src/model.jl rename to src/ocp/model.jl index 8718975c..05b62fd2 100644 --- a/src/model.jl +++ b/src/ocp/model.jl @@ -325,7 +325,7 @@ end """ $(TYPEDSIGNATURES) -Return `true`. +Return `true` for an autonomous model. """ function is_autonomous( ::Model{ @@ -346,7 +346,7 @@ end """ $(TYPEDSIGNATURES) -Return `true`. +Return `false` for a non-autonomous model. """ function is_autonomous( ::Model{ @@ -544,10 +544,20 @@ function time_name(ocp::Model)::String end # Initial time +""" +$(TYPEDSIGNATURES) + +Throw an error for unsupported initial time access. +""" function initial_time(ocp::AbstractModel) throw(CTBase.UnauthorizedCall("You cannot get the initial time with this function.")) end +""" +$(TYPEDSIGNATURES) + +Throw an error for unsupported initial time access with variable. +""" function initial_time(ocp::AbstractModel, variable::AbstractVector) throw(CTBase.UnauthorizedCall("You cannot get the initial time with this function.")) end @@ -648,6 +658,7 @@ end """ $(TYPEDSIGNATURES) +Throw an error for unsupported final time access. """ function final_time(ocp::AbstractModel) throw(CTBase.UnauthorizedCall("You cannot get the final time with this function.")) @@ -656,6 +667,7 @@ end """ $(TYPEDSIGNATURES) +Throw an error for unsupported final time access with variable. """ function final_time(ocp::AbstractModel, variable::AbstractVector) throw(CTBase.UnauthorizedCall("You cannot get the final time with this function.")) @@ -785,6 +797,11 @@ function criterion(ocp::Model)::Symbol end # Mayer +""" +$(TYPEDSIGNATURES) + +Throw an error when accessing Mayer cost on a model without one. +""" function mayer(ocp::AbstractModel) throw(CTBase.UnauthorizedCall("This ocp has no Mayer objective.")) end @@ -841,6 +858,11 @@ function has_mayer_cost(ocp::Model)::Bool end # Lagrange +""" +$(TYPEDSIGNATURES) + +Throw an error when accessing Lagrange cost on a model without one. +""" function lagrange(ocp::AbstractModel) throw(CTBase.UnauthorizedCall("This ocp has no Lagrange objective.")) end diff --git a/src/objective.jl b/src/ocp/objective.jl similarity index 100% rename from src/objective.jl rename to src/ocp/objective.jl diff --git a/src/print.jl b/src/ocp/print.jl similarity index 92% rename from src/print.jl rename to src/ocp/print.jl index a6695df5..648d4dcf 100644 --- a/src/print.jl +++ b/src/ocp/print.jl @@ -1,6 +1,17 @@ # ------------------------------------------------------------------------------ # # PRINT # ------------------------------------------------------------------------------ # +""" +$(TYPEDSIGNATURES) + +Print an expression with indentation. + +# Arguments + +- `e::Expr`: The expression to print. +- `io::IO`: The output stream. +- `l::Int`: The indentation level (number of spaces). +""" function __print(e::Expr, io::IO, l::Int) @match e begin :(($a, $b)) => println(io, " "^l, a, ", ", b) @@ -8,6 +19,20 @@ function __print(e::Expr, io::IO, l::Int) end end +""" +$(TYPEDSIGNATURES) + +Print the abstract definition of an optimal control problem. + +# Arguments + +- `io::IO`: The output stream. +- `ocp::Union{Model,PreModel}`: The optimal control problem. + +# Returns + +- `Bool`: `true` if something was printed. +""" function __print_abstract_definition(io::IO, ocp::Union{Model,PreModel}) @assert hasproperty(definition(ocp), :head) printstyled(io, "Abstract definition:\n\n"; bold=true) @@ -20,6 +45,18 @@ function __print_abstract_definition(io::IO, ocp::Union{Model,PreModel}) return true end +""" +$(TYPEDSIGNATURES) + +Print the mathematical definition of an optimal control problem. + +Displays the problem in standard mathematical notation with objective, +dynamics, and constraints. + +# Returns + +- `Bool`: `true` if something was printed. +""" function __print_mathematical_definition( io::IO, some_printing::Bool, @@ -289,6 +326,13 @@ function Base.show(io::IO, ::MIME"text/plain", ocp::Model) return nothing end +""" +$(TYPEDSIGNATURES) + +Default show method for a [`Model`](@ref CTModels.Model). + +Prints only the type name. +""" function Base.show_default(io::IO, ocp::Model) return print(io, typeof(ocp)) end @@ -386,6 +430,9 @@ end """ $(TYPEDSIGNATURES) +Default show method for a [`PreModel`](@ref). + +Prints only the type name. """ function Base.show_default(io::IO, ocp::PreModel) return print(io, typeof(ocp)) diff --git a/src/solution.jl b/src/ocp/solution.jl similarity index 98% rename from src/solution.jl rename to src/ocp/solution.jl index 64e579dc..4b6dba78 100644 --- a/src/solution.jl +++ b/src/ocp/solution.jl @@ -25,6 +25,7 @@ Build a solution from the optimal control problem, the time grid, the state, con - `control_constraints_ub_dual::Matrix{Float64}`: the upper bound dual of the control constraints. - `variable_constraints_lb_dual::Vector{Float64}`: the lower bound dual of the variable constraints. - `variable_constraints_ub_dual::Vector{Float64}`: the upper bound dual of the variable constraints. +- `infos::Dict{Symbol,Any}`: additional solver information dictionary. # Returns @@ -52,6 +53,7 @@ function build_solution( control_constraints_ub_dual::Union{Matrix{Float64},Nothing}=__constraints(), variable_constraints_lb_dual::Union{Vector{Float64},Nothing}=__constraints(), variable_constraints_ub_dual::Union{Vector{Float64},Nothing}=__constraints(), + infos::Dict{Symbol,Any}=Dict{Symbol,Any}(), ) where { TX<:Union{Matrix{Float64},Function}, TU<:Union{Matrix{Float64},Function}, @@ -106,8 +108,7 @@ function build_solution( fp = (dim_x == 1) ? deepcopy(t -> p(t)[1]) : deepcopy(t -> p(t)) var = (dim_v == 1) ? v[1] : v - # misc infos - infos = Dict{Symbol,Any}() + # misc infos (use provided infos or empty dict) # nonlinear constraints and dual variables path_constraints_dual_fun = if isnothing(path_constraints_dual) @@ -573,6 +574,7 @@ end """ $(TYPEDSIGNATURES) +Return the dual model containing all constraint multipliers. """ function dual_model( sol::Solution{ @@ -674,6 +676,7 @@ end """ $(TYPEDSIGNATURES) +Return the optimal control problem model associated with the solution. """ function model( sol::Solution{ diff --git a/src/state.jl b/src/ocp/state.jl similarity index 100% rename from src/state.jl rename to src/ocp/state.jl diff --git a/src/time_dependence.jl b/src/ocp/time_dependence.jl similarity index 100% rename from src/time_dependence.jl rename to src/ocp/time_dependence.jl diff --git a/src/times.jl b/src/ocp/times.jl similarity index 100% rename from src/times.jl rename to src/ocp/times.jl diff --git a/src/variable.jl b/src/ocp/variable.jl similarity index 100% rename from src/variable.jl rename to src/ocp/variable.jl diff --git a/src/types.jl b/src/types.jl deleted file mode 100644 index 6915db52..00000000 --- a/src/types.jl +++ /dev/null @@ -1,643 +0,0 @@ -# ------------------------------------------------------------------------------ # -""" -$(TYPEDEF) -""" -abstract type TimeDependence end - -""" -$(TYPEDEF) -""" -abstract type Autonomous<:TimeDependence end - -""" -$(TYPEDEF) -""" -abstract type NonAutonomous<:TimeDependence end - -# ------------------------------------------------------------------------------ # -""" -$(TYPEDEF) -""" -abstract type AbstractStateModel end - -""" -$(TYPEDEF) - -**Fields** - -$(TYPEDFIELDS) -""" -struct StateModel <: AbstractStateModel - name::String - components::Vector{String} -end - -""" -$(TYPEDEF) - -**Fields** - -$(TYPEDFIELDS) -""" -struct StateModelSolution{TS<:Function} <: AbstractStateModel - name::String - components::Vector{String} - value::TS -end - -# ------------------------------------------------------------------------------ # -""" -$(TYPEDEF) -""" -abstract type AbstractControlModel end - -""" -$(TYPEDEF) - -**Fields** - -$(TYPEDFIELDS) -""" -struct ControlModel <: AbstractControlModel - name::String - components::Vector{String} -end - -""" -$(TYPEDEF) - -**Fields** - -$(TYPEDFIELDS) -""" -struct ControlModelSolution{TS<:Function} <: AbstractControlModel - name::String - components::Vector{String} - value::TS -end - -# ------------------------------------------------------------------------------ # -""" -$(TYPEDEF) -""" -abstract type AbstractVariableModel end - -""" -$(TYPEDEF) - -**Fields** - -$(TYPEDFIELDS) -""" -struct VariableModel <: AbstractVariableModel - name::String - components::Vector{String} -end - -""" -$(TYPEDEF) - -**Fields** - -$(TYPEDFIELDS) -""" -struct EmptyVariableModel <: AbstractVariableModel end - -""" -$(TYPEDEF) - -**Fields** - -$(TYPEDFIELDS) -""" -struct VariableModelSolution{TS<:Union{ctNumber,ctVector}} <: AbstractVariableModel - name::String - components::Vector{String} - value::TS -end - -# ------------------------------------------------------------------------------ # -""" -$(TYPEDEF) -""" -abstract type AbstractTimeModel end - -""" -$(TYPEDEF) - -**Fields** - -$(TYPEDFIELDS) -""" -struct FixedTimeModel{T<:Time} <: AbstractTimeModel - time::T - name::String -end - -""" -$(TYPEDEF) - -**Fields** - -$(TYPEDFIELDS) -""" -struct FreeTimeModel <: AbstractTimeModel - index::Int - name::String -end - -""" -$(TYPEDEF) -""" -abstract type AbstractTimesModel end - -""" -$(TYPEDEF) - -**Fields** - -$(TYPEDFIELDS) -""" -struct TimesModel{TI<:AbstractTimeModel,TF<:AbstractTimeModel} <: AbstractTimesModel - initial::TI - final::TF - time_name::String -end - -# ------------------------------------------------------------------------------ # -""" -$(TYPEDEF) -""" -abstract type AbstractObjectiveModel end - -""" -$(TYPEDEF) - -**Fields** - -$(TYPEDFIELDS) -""" -struct MayerObjectiveModel{TM<:Function} <: AbstractObjectiveModel - mayer::TM - criterion::Symbol -end - -""" -$(TYPEDEF) - -**Fields** - -$(TYPEDFIELDS) -""" -struct LagrangeObjectiveModel{TL<:Function} <: AbstractObjectiveModel - lagrange::TL - criterion::Symbol -end - -""" -$(TYPEDEF) - -**Fields** - -$(TYPEDFIELDS) -""" -struct BolzaObjectiveModel{TM<:Function,TL<:Function} <: AbstractObjectiveModel - mayer::TM - lagrange::TL - criterion::Symbol -end - -# ------------------------------------------------------------------------------ # -# Constraints -# ------------------------------------------------------------------------------ # -""" -$(TYPEDEF) -""" -abstract type AbstractConstraintsModel end - -""" -$(TYPEDEF) - -**Fields** - -$(TYPEDFIELDS) -""" -struct ConstraintsModel{TP<:Tuple,TB<:Tuple,TS<:Tuple,TC<:Tuple,TV<:Tuple} <: - AbstractConstraintsModel - path_nl::TP - boundary_nl::TB - state_box::TS - control_box::TC - variable_box::TV -end - -# ------------------------------------------------------------------------------ # -# Model -# ------------------------------------------------------------------------------ # -""" -$(TYPEDEF) -""" -abstract type AbstractModel end - -""" -$(TYPEDEF) - -**Fields** - -$(TYPEDFIELDS) -""" -struct Model{ - TD<:TimeDependence, - TimesModelType<:AbstractTimesModel, - StateModelType<:AbstractStateModel, - ControlModelType<:AbstractControlModel, - VariableModelType<:AbstractVariableModel, - DynamicsModelType<:Function, - ObjectiveModelType<:AbstractObjectiveModel, - ConstraintsModelType<:AbstractConstraintsModel, - BuildExaModelType<:Union{Function,Nothing}, -} <: AbstractModel - times::TimesModelType - state::StateModelType - control::ControlModelType - variable::VariableModelType - dynamics::DynamicsModelType - objective::ObjectiveModelType - constraints::ConstraintsModelType - definition::Expr - build_examodel::BuildExaModelType - - function Model{TD}( # TD must be specified explicitly - times::AbstractTimesModel, - state::AbstractStateModel, - control::AbstractControlModel, - variable::AbstractVariableModel, - dynamics::Function, - objective::AbstractObjectiveModel, - constraints::AbstractConstraintsModel, - definition::Expr, - build_examodel::Union{Function,Nothing}, - ) where {TD<:TimeDependence} - return new{ - TD, - typeof(times), - typeof(state), - typeof(control), - typeof(variable), - typeof(dynamics), - typeof(objective), - typeof(constraints), - typeof(build_examodel), - }( - times, - state, - control, - variable, - dynamics, - objective, - constraints, - definition, - build_examodel, - ) - end -end - -""" -$(TYPEDSIGNATURES) - -""" -__is_times_set(ocp::Model)::Bool = true - -""" -$(TYPEDSIGNATURES) - -""" -__is_state_set(ocp::Model)::Bool = true - -""" -$(TYPEDSIGNATURES) - -""" -__is_control_set(ocp::Model)::Bool = true - -""" -$(TYPEDSIGNATURES) - -""" -__is_variable_set(ocp::Model)::Bool = true - -""" -$(TYPEDSIGNATURES) - -""" -__is_dynamics_set(ocp::Model)::Bool = true - -""" -$(TYPEDSIGNATURES) - -""" -__is_objective_set(ocp::Model)::Bool = true - -""" -$(TYPEDSIGNATURES) - -""" -__is_definition_set(ocp::Model)::Bool = true - -""" -$(TYPEDEF) - -**Fields** - -$(TYPEDFIELDS) -""" -@with_kw mutable struct PreModel <: AbstractModel - times::Union{AbstractTimesModel,Nothing} = nothing - state::Union{AbstractStateModel,Nothing} = nothing - control::Union{AbstractControlModel,Nothing} = nothing - variable::AbstractVariableModel = EmptyVariableModel() - dynamics::Union{Function,Vector{<:Tuple{<:AbstractRange{<:Int},<:Function}},Nothing} = - nothing - objective::Union{AbstractObjectiveModel,Nothing} = nothing - constraints::ConstraintsDictType = ConstraintsDictType() - definition::Union{Expr,Nothing} = nothing - autonomous::Union{Bool,Nothing} = nothing -end - -""" -$(TYPEDSIGNATURES) - -""" -__is_set(x) = !isnothing(x) - -""" -$(TYPEDSIGNATURES) - -""" -__is_autonomous_set(ocp::PreModel)::Bool = __is_set(ocp.autonomous) - -""" -$(TYPEDSIGNATURES) - -""" -__is_times_set(ocp::PreModel)::Bool = __is_set(ocp.times) - -""" -$(TYPEDSIGNATURES) - -""" -__is_state_set(ocp::PreModel)::Bool = __is_set(ocp.state) - -""" -$(TYPEDSIGNATURES) - -""" -__is_control_set(ocp::PreModel)::Bool = __is_set(ocp.control) - -""" -$(TYPEDSIGNATURES) - -""" -__is_variable_empty(v) = v isa EmptyVariableModel - -""" -$(TYPEDSIGNATURES) - -""" -__is_variable_set(ocp::PreModel)::Bool = !__is_variable_empty(ocp.variable) - -""" -$(TYPEDSIGNATURES) - -""" -__is_dynamics_set(ocp::PreModel)::Bool = __is_set(ocp.dynamics) - -""" -$(TYPEDSIGNATURES) - -""" -__is_objective_set(ocp::PreModel)::Bool = __is_set(ocp.objective) - -""" -$(TYPEDSIGNATURES) - -""" -__is_definition_set(ocp::PreModel)::Bool = __is_set(ocp.definition) - -""" -$(TYPEDSIGNATURES) - -""" -function state_dimension(ocp::PreModel)::Dimension - @ensure(__is_state_set(ocp), CTBase.UnauthorizedCall("the state must be set.")) - return length(ocp.state.components) -end - -""" -$(TYPEDSIGNATURES) - -""" -function __is_dynamics_complete(ocp::PreModel)::Bool - if isnothing(ocp.dynamics) - return false - elseif ocp.dynamics isa Function - return true - else # ocp.dynamics isa Vector{<:Tuple{<:AbstractRange{<:Int},<:Function}} - @ensure(__is_state_set(ocp), CTBase.UnauthorizedCall("the state must be set.")) - n = state_dimension(ocp) - covered = falses(n) - for (range, _) in ocp.dynamics - for i in range - if 1 <= i <= n - covered[i] = true - else - throw( - CTBase.UnauthorizedCall( - "Dynamics index $i out of bounds for state of size $n." - ), - ) - end - end - end - return all(covered) - end -end - -""" -$(TYPEDSIGNATURES) - -Return true if all the required fields are set in the PreModel. -""" -function __is_consistent(ocp::PreModel)::Bool - return __is_times_set(ocp) && - __is_state_set(ocp) && - __is_control_set(ocp) && - __is_dynamics_complete(ocp) && - __is_objective_set(ocp) && - __is_autonomous_set(ocp) -end - -""" -$(TYPEDSIGNATURES) - -Return true if the PreModel can be built into a Model. -""" -function __is_complete(ocp::PreModel)::Bool - return __is_times_set(ocp) && - __is_state_set(ocp) && - __is_control_set(ocp) && - __is_dynamics_complete(ocp) && - __is_objective_set(ocp) && - __is_definition_set(ocp) && - __is_autonomous_set(ocp) -end - -""" -$(TYPEDSIGNATURES) - -Return true if nothing has been set. -""" -function __is_empty(ocp::PreModel)::Bool - return !__is_times_set(ocp) && - !__is_state_set(ocp) && - !__is_control_set(ocp) && - !__is_dynamics_set(ocp) && - !__is_objective_set(ocp) && - !__is_definition_set(ocp) && - !__is_variable_set(ocp) && - !__is_autonomous_set(ocp) && - Base.isempty(ocp.constraints) -end - -# ------------------------------------------------------------------------------ # -""" -$(TYPEDEF) -""" -abstract type AbstractTimeGridModel end - -""" -$(TYPEDEF) - -**Fields** - -$(TYPEDFIELDS) -""" -struct TimeGridModel{T<:TimesDisc} <: AbstractTimeGridModel - value::T -end - -""" -$(TYPEDEF) - -**Fields** - -$(TYPEDFIELDS) -""" -struct EmptyTimeGridModel <: AbstractTimeGridModel end - -is_empty(model::EmptyTimeGridModel)::Bool = true -is_empty(model::TimeGridModel)::Bool = false - -# ------------------------------------------------------------------------------ # -# Solver infos -""" -$(TYPEDEF) -""" -abstract type AbstractSolverInfos end - -""" -$(TYPEDEF) - -**Fields** - -$(TYPEDFIELDS) -""" -struct SolverInfos{TI<:Dict{Symbol,Any}} <: AbstractSolverInfos - iterations::Int # number of iterations - status::Symbol # the status criterion - message::String # the message corresponding to the status criterion - successful::Bool # whether or not the method has finished successfully: CN1, stagnation vs iterations max - constraints_violation::Float64 # the constraints violation - infos::TI # additional information -end - -# ------------------------------------------------------------------------------ # -# Constraints and dual variables for the solutions -""" -$(TYPEDEF) -""" -abstract type AbstractDualModel end - -""" - -$(TYPEDEF) - -**Fields** - -$(TYPEDFIELDS) -""" -struct DualModel{ - PC_Dual<:Union{Function,Nothing}, - BC_Dual<:Union{ctVector,Nothing}, - SC_LB_Dual<:Union{Function,Nothing}, - SC_UB_Dual<:Union{Function,Nothing}, - CC_LB_Dual<:Union{Function,Nothing}, - CC_UB_Dual<:Union{Function,Nothing}, - VC_LB_Dual<:Union{ctVector,Nothing}, - VC_UB_Dual<:Union{ctVector,Nothing}, -} <: AbstractDualModel - path_constraints_dual::PC_Dual - boundary_constraints_dual::BC_Dual - state_constraints_lb_dual::SC_LB_Dual - state_constraints_ub_dual::SC_UB_Dual - control_constraints_lb_dual::CC_LB_Dual - control_constraints_ub_dual::CC_UB_Dual - variable_constraints_lb_dual::VC_LB_Dual - variable_constraints_ub_dual::VC_UB_Dual -end - -# ------------------------------------------------------------------------------ # -# Solution -# ------------------------------------------------------------------------------ # -""" -$(TYPEDEF) -""" -abstract type AbstractSolution end - -""" -$(TYPEDEF) - -**Fields** - -$(TYPEDFIELDS) -""" -struct Solution{ - TimeGridModelType<:AbstractTimeGridModel, - TimesModelType<:AbstractTimesModel, - StateModelType<:AbstractStateModel, - ControlModelType<:AbstractControlModel, - VariableModelType<:AbstractVariableModel, - CostateModelType<:Function, - ObjectiveValueType<:ctNumber, - DualModelType<:AbstractDualModel, - SolverInfosType<:AbstractSolverInfos, - ModelType<:AbstractModel, -} <: AbstractSolution - time_grid::TimeGridModelType - times::TimesModelType - state::StateModelType - control::ControlModelType - variable::VariableModelType - costate::CostateModelType - objective::ObjectiveValueType - dual::DualModelType - solver_infos::SolverInfosType - model::ModelType -end - -""" -$(TYPEDSIGNATURES) - -Check if the time grid is empty from the solution. -""" -is_empty_time_grid(sol::Solution)::Bool = is_empty(sol.time_grid) diff --git a/test/Project.toml b/test/Project.toml index 2f04b64b..95086f75 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -2,23 +2,25 @@ ADNLPModels = "54578032-b7ea-4c30-94aa-7cbd1cce6c9a" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" CTBase = "54762871-cc72-4466-b8e8-f6c8b58076cd" -CTDirect = "790bbbee-bee9-49ee-8912-a9de031322d5" -CTParser = "32681960-a1b1-40db-9bff-a1ca817385d1" +ExaModels = "1037b233-b668-4ce9-9b63-f9f681f55dd2" JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819" JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1" -NLPModelsIpopt = "f4238b75-b362-5c4c-b852-0801c9a21d71" +NLPModels = "a4795742-8479-5a88-8948-cc11e1c8c1a6" +OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +SolverCore = "ff4d7338-4cf1-434d-91df-b86cb86fb843" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] ADNLPModels = "0.8" Aqua = "0.8" -CTBase = "0.16" -CTDirect = "0.17" -CTParser = "0.7" -JLD2 = "0.6" -JSON3 = "1" -NLPModelsIpopt = "0.11" +CTBase = "0.17" +ExaModels = "0.9" +NLPModels = "0.21" +OrderedCollections = "1.8" Plots = "1" +Random = "1" +SolverCore = "0.3" Test = "1" julia = "1.10" diff --git a/test/core/test_default.jl b/test/core/test_default.jl new file mode 100644 index 00000000..9e4543c3 --- /dev/null +++ b/test/core/test_default.jl @@ -0,0 +1,55 @@ +function test_default() + # TODO: add tests for src/core/default.jl (default options, etc.). + + Test.@testset "constraints and format defaults" verbose=VERBOSE showtiming=SHOWTIMING begin + Test.@test CTModels.__constraints() === nothing + Test.@test CTModels.__format() == :JLD + + label1 = CTModels.__constraint_label() + label2 = CTModels.__constraint_label() + Test.@test label1 isa Symbol + Test.@test label2 isa Symbol + Test.@test label1 != label2 + Test.@test startswith(String(label1), "##unnamed") + Test.@test startswith(String(label2), "##unnamed") + end + + Test.@testset "state and control naming defaults" verbose=VERBOSE showtiming=SHOWTIMING begin + Test.@test CTModels.__state_name() == "x" + Test.@test CTModels.__control_name() == "u" + + comps_state_1 = CTModels.__state_components(1, "x") + comps_state_3 = CTModels.__state_components(3, "x") + Test.@test comps_state_1 == ["x"] + Test.@test comps_state_3 == ["x" * CTBase.ctindices(i) for i in 1:3] + + comps_control_1 = CTModels.__control_components(1, "u") + comps_control_3 = CTModels.__control_components(3, "u") + Test.@test comps_control_1 == ["u"] + Test.@test comps_control_3 == ["u" * CTBase.ctindices(i) for i in 1:3] + end + + Test.@testset "time and criterion defaults" verbose=VERBOSE showtiming=SHOWTIMING begin + Test.@test CTModels.__time_name() == "t" + Test.@test CTModels.__criterion_type() == :min + end + + Test.@testset "variable naming defaults" verbose=VERBOSE showtiming=SHOWTIMING begin + Test.@test CTModels.__variable_name(0) == "" + Test.@test CTModels.__variable_name(1) == "v" + Test.@test CTModels.__variable_name(3) == "v" + + comps_var_0 = CTModels.__variable_components(0, "v") + comps_var_1 = CTModels.__variable_components(1, "v") + comps_var_3 = CTModels.__variable_components(3, "v") + + Test.@test comps_var_0 == String[] + Test.@test comps_var_1 == ["v"] + Test.@test comps_var_3 == ["v" * CTBase.ctindices(i) for i in 1:3] + end + + Test.@testset "matrix and filename defaults" verbose=VERBOSE showtiming=SHOWTIMING begin + Test.@test CTModels.__matrix_dimension_storage() == 1 + Test.@test CTModels.__filename_export_import() == "solution" + end +end diff --git a/test/core/test_initial_guess_types.jl b/test/core/test_initial_guess_types.jl new file mode 100644 index 00000000..c81bb04a --- /dev/null +++ b/test/core/test_initial_guess_types.jl @@ -0,0 +1,62 @@ +function test_initial_guess_types() + # TODO: add tests for src/core/types/initial_guess.jl. + + # ======================================================================== + # Unit tests – core initial guess types + # ======================================================================== + + Test.@testset "OptimalControlInitialGuess structure" verbose=VERBOSE showtiming=SHOWTIMING begin + state_fun = t -> [t] + control_fun = t -> [-t] + variable_vec = [1.0, 2.0] + + ig = CTModels.OptimalControlInitialGuess(state_fun, control_fun, variable_vec) + + Test.@test ig.state === state_fun + Test.@test ig.control === control_fun + Test.@test ig.variable === variable_vec + + # Type parameters should reflect the concrete field types + Test.@test ig isa + CTModels.OptimalControlInitialGuess{typeof(state_fun),typeof(control_fun),typeof(variable_vec)} + end + + Test.@testset "OptimalControlPreInit structure" verbose=VERBOSE showtiming=SHOWTIMING begin + sx = :state_spec + su = :control_spec + sv = :variable_spec + + pre = CTModels.OptimalControlPreInit(sx, su, sv) + + Test.@test pre.state === sx + Test.@test pre.control === su + Test.@test pre.variable === sv + end + + # ======================================================================== + # Integration-style tests – fake consumer of initial guesses + # ======================================================================== + + Test.@testset "fake consumer of OptimalControlInitialGuess" verbose=VERBOSE showtiming=SHOWTIMING begin + state_fun = t -> 2t + control_fun = t -> -3t + variable_val = 1.23 + + ig = CTModels.OptimalControlInitialGuess(state_fun, control_fun, variable_val) + + # Simple fake consumer that only relies on the fields of the type + function consume_initial_guess(ig_local) + y = ig_local.state(0.5) + u = ig_local.control(0.5) + v = ig_local.variable + return y, u, v + end + + y, u, v = consume_initial_guess(ig) + + Test.@test y == 2 * 0.5 + Test.@test u == -3 * 0.5 + Test.@test v == variable_val + end +end + diff --git a/test/core/test_nlp_types.jl b/test/core/test_nlp_types.jl new file mode 100644 index 00000000..fdb7145a --- /dev/null +++ b/test/core/test_nlp_types.jl @@ -0,0 +1,30 @@ +function test_nlp_types() + # ---------------------------------------------------------------------- + # Type hierarchy for builders and optimization problems + # (moved from test/nlp/test_problem_core.jl) + # ---------------------------------------------------------------------- + Test.@testset "type hierarchy" verbose=VERBOSE showtiming=SHOWTIMING begin + Test.@test isabstracttype(CTModels.AbstractBuilder) + Test.@test isabstracttype(CTModels.AbstractModelBuilder) + Test.@test isabstracttype(CTModels.AbstractSolutionBuilder) + Test.@test isabstracttype(CTModels.AbstractOptimizationProblem) + + Test.@test CTModels.ADNLPModelBuilder <: CTModels.AbstractModelBuilder + Test.@test CTModels.ExaModelBuilder <: CTModels.AbstractModelBuilder + end + + # ---------------------------------------------------------------------- + # Type hierarchy for OCP solution builders + # (moved from test/nlp/test_discretized_ocp.jl) + # ---------------------------------------------------------------------- + Test.@testset "type hierarchy" verbose=VERBOSE showtiming=SHOWTIMING begin + # AbstractOCPSolutionBuilder should be abstract and inherit from AbstractSolutionBuilder + Test.@test isabstracttype(CTModels.AbstractOCPSolutionBuilder) + Test.@test CTModels.AbstractOCPSolutionBuilder <: CTModels.AbstractSolutionBuilder + + # Concrete solution builders should inherit from AbstractOCPSolutionBuilder + Test.@test CTModels.ADNLPSolutionBuilder <: CTModels.AbstractOCPSolutionBuilder + Test.@test CTModels.ExaSolutionBuilder <: CTModels.AbstractOCPSolutionBuilder + end +end + diff --git a/test/core/test_ocp_components.jl b/test/core/test_ocp_components.jl new file mode 100644 index 00000000..2cb419bb --- /dev/null +++ b/test/core/test_ocp_components.jl @@ -0,0 +1,64 @@ +function test_ocp_components() + # TODO: add tests for src/core/types/ocp_components.jl. + + Test.@testset "state/control/variable models" verbose=VERBOSE showtiming=SHOWTIMING begin + state = CTModels.StateModel("y", ["u", "v"]) + Test.@test CTModels.dimension(state) == 2 + Test.@test CTModels.name(state) == "y" + Test.@test CTModels.components(state) == ["u", "v"] + + control = CTModels.ControlModel("u", ["u₁", "u₂"]) + Test.@test CTModels.dimension(control) == 2 + Test.@test CTModels.name(control) == "u" + Test.@test CTModels.components(control) == ["u₁", "u₂"] + + variable = CTModels.VariableModel("v", ["v₁", "v₂"]) + Test.@test CTModels.dimension(variable) == 2 + Test.@test CTModels.name(variable) == "v" + Test.@test CTModels.components(variable) == ["v₁", "v₂"] + end + + Test.@testset "time models" verbose=VERBOSE showtiming=SHOWTIMING begin + Test.@test isabstracttype(CTModels.AbstractTimeModel) + + t0 = CTModels.FixedTimeModel(0.0, "t₀") + tf = CTModels.FixedTimeModel(1.0, "t_f") + Test.@test t0.time == 0.0 + Test.@test t0.name == "t₀" + Test.@test tf.time == 1.0 + Test.@test tf.name == "t_f" + + free_t0 = CTModels.FreeTimeModel(1, "t₀") + free_tf = CTModels.FreeTimeModel(2, "t_f") + Test.@test free_t0.index == 1 + Test.@test free_t0.name == "t₀" + Test.@test free_tf.index == 2 + Test.@test free_tf.name == "t_f" + + times = CTModels.TimesModel(t0, tf, "t") + Test.@test times.initial === t0 + Test.@test times.final === tf + Test.@test times.time_name == "t" + end + + Test.@testset "objective and constraints models" verbose=VERBOSE showtiming=SHOWTIMING begin + mayer_f = (x0, xf, v) -> x0[1] + xf[1] + lagrange_f = (t, x, u, v) -> u[1]^2 + + mayer = CTModels.MayerObjectiveModel(mayer_f, :min) + lagrange = CTModels.LagrangeObjectiveModel(lagrange_f, :max) + bolza = CTModels.BolzaObjectiveModel(mayer_f, lagrange_f, :min) + + Test.@test mayer.criterion == :min + Test.@test lagrange.criterion == :max + Test.@test bolza.criterion == :min + + # Simple construction of an empty ConstraintsModel + constraints = CTModels.ConstraintsModel((), (), (), (), ()) + Test.@test constraints.path_nl == () + Test.@test constraints.boundary_nl == () + Test.@test constraints.state_box == () + Test.@test constraints.control_box == () + Test.@test constraints.variable_box == () + end +end diff --git a/test/core/test_ocp_model_types.jl b/test/core/test_ocp_model_types.jl new file mode 100644 index 00000000..1590c5c3 --- /dev/null +++ b/test/core/test_ocp_model_types.jl @@ -0,0 +1,143 @@ +function test_ocp_model_types() + # TODO: add tests for src/core/types/ocp_model.jl. + + # ======================================================================== + # Unit tests – core OCP model types + # ======================================================================== + + Test.@testset "Model and PreModel hierarchy" verbose=VERBOSE showtiming=SHOWTIMING begin + Test.@test isabstracttype(CTModels.AbstractModel) + Test.@test CTModels.Model <: CTModels.AbstractModel + Test.@test CTModels.PreModel <: CTModels.AbstractModel + end + + Test.@testset "__is_* predicates on Model" verbose=VERBOSE showtiming=SHOWTIMING begin + times = CTModels.TimesModel( + CTModels.FixedTimeModel(0.0, "t₀"), CTModels.FixedTimeModel(1.0, "t_f"), "t" + ) + state = CTModels.StateModel("x", ["x"]) + control = CTModels.ControlModel("u", ["u"]) + variable = CTModels.VariableModel("v", ["v"]) + dynamics = (r, t, x, u, v) -> nothing + objective = CTModels.MayerObjectiveModel((x0, xf, v) -> 0.0, :min) + constraints = CTModels.ConstraintsModel((), (), (), (), ()) + definition = quote end + build_examodel = nothing + + ocp = CTModels.Model{CTModels.Autonomous}( + times, + state, + control, + variable, + dynamics, + objective, + constraints, + definition, + build_examodel, + ) + + # Type parameters should follow the concrete component types + Test.@test ocp isa CTModels.Model{ + CTModels.Autonomous, + typeof(times), + typeof(state), + typeof(control), + typeof(variable), + typeof(dynamics), + typeof(objective), + typeof(constraints), + typeof(build_examodel), + } + + Test.@test CTModels.__is_times_set(ocp) + Test.@test CTModels.__is_state_set(ocp) + Test.@test CTModels.__is_control_set(ocp) + Test.@test CTModels.__is_variable_set(ocp) + Test.@test CTModels.__is_dynamics_set(ocp) + Test.@test CTModels.__is_objective_set(ocp) + Test.@test CTModels.__is_definition_set(ocp) + end + + Test.@testset "__is_* predicates on PreModel" verbose=VERBOSE showtiming=SHOWTIMING begin + ocp = CTModels.PreModel() + + # Fresh PreModel should be empty + Test.@test CTModels.__is_empty(ocp) + Test.@test !CTModels.__is_times_set(ocp) + Test.@test !CTModels.__is_state_set(ocp) + Test.@test !CTModels.__is_control_set(ocp) + Test.@test !CTModels.__is_dynamics_set(ocp) + Test.@test !CTModels.__is_objective_set(ocp) + Test.@test !CTModels.__is_definition_set(ocp) + + times = CTModels.TimesModel( + CTModels.FixedTimeModel(0.0, "t₀"), CTModels.FixedTimeModel(1.0, "t_f"), "t" + ) + state = CTModels.StateModel("x", ["x"]) + control = CTModels.ControlModel("u", ["u"]) + variable = CTModels.VariableModel("v", ["v"]) + dynamics = (r, t, x, u, v) -> nothing + objective = CTModels.MayerObjectiveModel((x0, xf, v) -> 0.0, :min) + + ocp.times = times + ocp.state = state + ocp.control = control + ocp.variable = variable + ocp.dynamics = dynamics + ocp.objective = objective + ocp.autonomous = true + + Test.@test CTModels.__is_times_set(ocp) + Test.@test CTModels.__is_state_set(ocp) + Test.@test CTModels.__is_control_set(ocp) + Test.@test CTModels.__is_variable_set(ocp) + Test.@test CTModels.__is_dynamics_set(ocp) + Test.@test CTModels.__is_objective_set(ocp) + Test.@test CTModels.__is_autonomous_set(ocp) + + # At this stage the model is consistent but not yet complete + Test.@test CTModels.__is_consistent(ocp) + Test.@test !CTModels.__is_complete(ocp) + + ocp.definition = quote end + + Test.@test CTModels.__is_definition_set(ocp) + Test.@test CTModels.__is_complete(ocp) + Test.@test !CTModels.__is_empty(ocp) + end + + # ======================================================================== + # Integration-style tests – fake buildability check + # ======================================================================== + + Test.@testset "fake PreModel buildability" verbose=VERBOSE showtiming=SHOWTIMING begin + function can_build(ocp_local) + return CTModels.__is_complete(ocp_local) + end + + empty_ocp = CTModels.PreModel() + Test.@test !can_build(empty_ocp) + + times = CTModels.TimesModel( + CTModels.FixedTimeModel(0.0, "t₀"), CTModels.FixedTimeModel(1.0, "t_f"), "t" + ) + state = CTModels.StateModel("x", ["x"]) + control = CTModels.ControlModel("u", ["u"]) + variable = CTModels.VariableModel("v", ["v"]) + dynamics = (r, t, x, u, v) -> nothing + objective = CTModels.MayerObjectiveModel((x0, xf, v) -> 0.0, :min) + + ocp = CTModels.PreModel() + ocp.times = times + ocp.state = state + ocp.control = control + ocp.variable = variable + ocp.dynamics = dynamics + ocp.objective = objective + ocp.definition = quote end + ocp.autonomous = true + + Test.@test can_build(ocp) + end +end + diff --git a/test/core/test_ocp_solution_types.jl b/test/core/test_ocp_solution_types.jl new file mode 100644 index 00000000..31bdf7d7 --- /dev/null +++ b/test/core/test_ocp_solution_types.jl @@ -0,0 +1,235 @@ +function test_ocp_solution_types() + # TODO: add tests for src/core/types/ocp_solution.jl. + + # ======================================================================== + # Unit tests – core solution-related types + # ======================================================================== + + Test.@testset "TimeGridModel and is_empty" verbose=VERBOSE showtiming=SHOWTIMING begin + grid = CTModels.TimeGridModel([0.0, 0.5, 1.0]) + empty_grid = CTModels.EmptyTimeGridModel() + + Test.@test CTModels.is_empty(empty_grid) + Test.@test !CTModels.is_empty(grid) + end + + Test.@testset "SolverInfos structure" verbose=VERBOSE showtiming=SHOWTIMING begin + extra_infos = Dict(:foo => 1, :bar => "x") + infos = CTModels.SolverInfos(10, :ok, "message", true, 1e-3, extra_infos) + + Test.@test infos.iterations == 10 + Test.@test infos.status == :ok + Test.@test infos.message == "message" + Test.@test infos.successful + Test.@test infos.constraints_violation ≈ 1e-3 + Test.@test infos.infos === extra_infos + Test.@test infos isa CTModels.AbstractSolverInfos + end + + Test.@testset "DualModel structure" verbose=VERBOSE showtiming=SHOWTIMING begin + pc = t -> [1.0, 2.0] + bc = [3.0, 4.0] + sc_lb = t -> [0.0] + sc_ub = t -> [1.0] + cc_lb = t -> [0.0] + cc_ub = t -> [1.0] + vc_lb = [5.0] + vc_ub = [6.0] + + dual = CTModels.DualModel(pc, bc, sc_lb, sc_ub, cc_lb, cc_ub, vc_lb, vc_ub) + + Test.@test dual.path_constraints_dual === pc + Test.@test dual.boundary_constraints_dual === bc + Test.@test dual.state_constraints_lb_dual === sc_lb + Test.@test dual.state_constraints_ub_dual === sc_ub + Test.@test dual.control_constraints_lb_dual === cc_lb + Test.@test dual.control_constraints_ub_dual === cc_ub + Test.@test dual.variable_constraints_lb_dual === vc_lb + Test.@test dual.variable_constraints_ub_dual === vc_ub + end + + Test.@testset "Solution structure and empty time grid" verbose=VERBOSE showtiming=SHOWTIMING begin + times = CTModels.TimesModel( + CTModels.FixedTimeModel(0.0, "t₀"), CTModels.FixedTimeModel(1.0, "t_f"), "t" + ) + state = CTModels.StateModel("x", ["x"]) + control = CTModels.ControlModel("u", ["u"]) + variable = CTModels.VariableModel("v", ["v"]) + + costate_fun = t -> [0.0] + objective_val = 0.0 + + dual = CTModels.DualModel( + nothing, + nothing, + nothing, + nothing, + nothing, + nothing, + nothing, + nothing, + ) + + infos = CTModels.SolverInfos(0, :unknown, "", false, 0.0, Dict{Symbol,Any}()) + + dynamics = (r, t, x, u, v) -> nothing + objective = CTModels.MayerObjectiveModel((x0, xf, v) -> 0.0, :min) + constraints = CTModels.ConstraintsModel((), (), (), (), ()) + definition = quote end + build_examodel = nothing + + model = CTModels.Model{CTModels.Autonomous}( + times, + state, + control, + variable, + dynamics, + objective, + constraints, + definition, + build_examodel, + ) + + grid_full = CTModels.TimeGridModel([0.0, 0.5, 1.0]) + grid_empty = CTModels.EmptyTimeGridModel() + + sol_full = CTModels.Solution( + grid_full, + times, + state, + control, + variable, + costate_fun, + objective_val, + dual, + infos, + model, + ) + + sol_empty = CTModels.Solution( + grid_empty, + times, + state, + control, + variable, + costate_fun, + objective_val, + dual, + infos, + model, + ) + + # Type parameters should reflect the underlying component types + Test.@test sol_full isa CTModels.Solution{ + typeof(grid_full), + typeof(times), + typeof(state), + typeof(control), + typeof(variable), + typeof(costate_fun), + typeof(objective_val), + typeof(dual), + typeof(infos), + typeof(model), + } + + Test.@test sol_empty isa CTModels.Solution{ + typeof(grid_empty), + typeof(times), + typeof(state), + typeof(control), + typeof(variable), + typeof(costate_fun), + typeof(objective_val), + typeof(dual), + typeof(infos), + typeof(model), + } + + Test.@test !CTModels.is_empty_time_grid(sol_full) + Test.@test CTModels.is_empty_time_grid(sol_empty) + end + + # ======================================================================== + # Integration-style tests – fake post-processing of a Solution + # ======================================================================== + + Test.@testset "fake Solution summary" verbose=VERBOSE showtiming=SHOWTIMING begin + times = CTModels.TimesModel( + CTModels.FixedTimeModel(0.0, "t₀"), CTModels.FixedTimeModel(1.0, "t_f"), "t" + ) + state = CTModels.StateModel("x", ["x"]) + control = CTModels.ControlModel("u", ["u"]) + variable = CTModels.VariableModel("v", ["v"]) + + costate_fun = t -> [0.0] + objective_val = 42.0 + + dual = CTModels.DualModel( + nothing, + nothing, + nothing, + nothing, + nothing, + nothing, + nothing, + nothing, + ) + + infos = CTModels.SolverInfos( + 15, + :converged, + "ok", + true, + 0.0, + Dict(:nit => 15), + ) + + dynamics = (r, t, x, u, v) -> nothing + objective = CTModels.MayerObjectiveModel((x0, xf, v) -> 0.0, :min) + constraints = CTModels.ConstraintsModel((), (), (), (), ()) + definition = quote end + build_examodel = nothing + + model = CTModels.Model{CTModels.Autonomous}( + times, + state, + control, + variable, + dynamics, + objective, + constraints, + definition, + build_examodel, + ) + + grid = CTModels.TimeGridModel([0.0, 1.0]) + sol = CTModels.Solution( + grid, + times, + state, + control, + variable, + costate_fun, + objective_val, + dual, + infos, + model, + ) + + function extract_summary(sol_local) + return ( + iterations=sol_local.solver_infos.iterations, + status=sol_local.solver_infos.status, + objective=sol_local.objective, + ) + end + + summary = extract_summary(sol) + + Test.@test summary.iterations == 15 + Test.@test summary.status == :converged + Test.@test summary.objective == 42.0 + end +end + diff --git a/test/core/test_types.jl b/test/core/test_types.jl new file mode 100644 index 00000000..b977cab5 --- /dev/null +++ b/test/core/test_types.jl @@ -0,0 +1,34 @@ +function test_types() + # TODO: add tests for src/core/types.jl (type includes and basic consistency). + + Test.@testset "OCP model and solution core types" verbose=VERBOSE showtiming=SHOWTIMING begin + # Abstract/model hierarchy + Test.@test isabstracttype(CTModels.AbstractModel) + Test.@test CTModels.Model <: CTModels.AbstractModel + Test.@test CTModels.PreModel <: CTModels.AbstractModel + + # Solution hierarchy + Test.@test isabstracttype(CTModels.AbstractSolution) + Test.@test CTModels.Solution <: CTModels.AbstractSolution + + # Time grid and dual/infos hierarchy + Test.@test isabstracttype(CTModels.AbstractTimeGridModel) + Test.@test CTModels.TimeGridModel <: CTModels.AbstractTimeGridModel + + Test.@test isabstracttype(CTModels.AbstractDualModel) + Test.@test CTModels.DualModel <: CTModels.AbstractDualModel + + Test.@test isabstracttype(CTModels.AbstractSolverInfos) + Test.@test CTModels.SolverInfos <: CTModels.AbstractSolverInfos + end + + Test.@testset "Initial guess core types" verbose=VERBOSE showtiming=SHOWTIMING begin + Test.@test isabstracttype(CTModels.AbstractOptimalControlInitialGuess) + Test.@test CTModels.OptimalControlInitialGuess <: + CTModels.AbstractOptimalControlInitialGuess + + Test.@test isabstracttype(CTModels.AbstractOptimalControlPreInit) + Test.@test CTModels.OptimalControlPreInit <: + CTModels.AbstractOptimalControlPreInit + end +end diff --git a/test/test_utils.jl b/test/core/test_utils.jl similarity index 100% rename from test/test_utils.jl rename to test/core/test_utils.jl diff --git a/test/extras/test_manual.jl b/test/extras/test_manual.jl index c9df5032..7c1b1c64 100644 --- a/test/extras/test_manual.jl +++ b/test/extras/test_manual.jl @@ -4,9 +4,9 @@ using CTModels using Plots import CTParser: CTParser, @def CTParser.set_prefix(:CTModels); # code generated by @def is prefixed by CTModels (not by OptimalControl - the default) -include("../solution_example_path_constraints.jl") +include("../solution_example_dual.jl") -ocp, sol = solution_example_path_constraints() +ocp, sol = solution_example_dual() # @test plot(sol; time=:default) isa Plots.Plot diff --git a/test/init/test_initial_guess.jl b/test/init/test_initial_guess.jl new file mode 100644 index 00000000..87b0b815 --- /dev/null +++ b/test/init/test_initial_guess.jl @@ -0,0 +1,532 @@ +# Unit tests for CTModels initial guess construction and validation. +struct DummyOCP1DNoVar <: CTModels.AbstractModel end +struct DummyOCP1DVar <: CTModels.AbstractModel end +struct DummyOCP1D2Var <: CTModels.AbstractModel end + +CTModels.state_dimension(::DummyOCP1DNoVar) = 1 +CTModels.control_dimension(::DummyOCP1DNoVar) = 1 +CTModels.variable_dimension(::DummyOCP1DNoVar) = 0 + +CTModels.has_fixed_initial_time(::DummyOCP1DNoVar) = true +CTModels.initial_time(::DummyOCP1DNoVar) = 0.0 + +CTModels.state_name(::DummyOCP1DNoVar) = "x" +CTModels.state_components(::DummyOCP1DNoVar) = ["x"] +CTModels.control_name(::DummyOCP1DNoVar) = "u" +CTModels.control_components(::DummyOCP1DNoVar) = ["u"] +CTModels.variable_name(::DummyOCP1DNoVar) = "v" +CTModels.variable_components(::DummyOCP1DNoVar) = String[] + +CTModels.state_dimension(::DummyOCP1DVar) = 1 +CTModels.control_dimension(::DummyOCP1DVar) = 1 +CTModels.variable_dimension(::DummyOCP1DVar) = 1 + +CTModels.has_fixed_initial_time(::DummyOCP1DVar) = true +CTModels.initial_time(::DummyOCP1DVar) = 0.0 + +CTModels.state_name(::DummyOCP1DVar) = "x" +CTModels.state_components(::DummyOCP1DVar) = ["x"] +CTModels.control_name(::DummyOCP1DVar) = "u" +CTModels.control_components(::DummyOCP1DVar) = ["u"] +CTModels.variable_name(::DummyOCP1DVar) = "v" +CTModels.variable_components(::DummyOCP1DVar) = ["v"] + +CTModels.state_dimension(::DummyOCP1D2Var) = 1 +CTModels.control_dimension(::DummyOCP1D2Var) = 1 +CTModels.variable_dimension(::DummyOCP1D2Var) = 2 + +CTModels.has_fixed_initial_time(::DummyOCP1D2Var) = true +CTModels.initial_time(::DummyOCP1D2Var) = 0.0 + +CTModels.state_name(::DummyOCP1D2Var) = "x" +CTModels.state_components(::DummyOCP1D2Var) = ["x"] +CTModels.control_name(::DummyOCP1D2Var) = "u" +CTModels.control_components(::DummyOCP1D2Var) = ["u"] +CTModels.variable_name(::DummyOCP1D2Var) = "w" +CTModels.variable_components(::DummyOCP1D2Var) = ["tf", "a"] + +struct DummyOCP2DNoVar <: CTModels.AbstractModel end + +CTModels.state_dimension(::DummyOCP2DNoVar) = 2 +CTModels.control_dimension(::DummyOCP2DNoVar) = 0 +CTModels.variable_dimension(::DummyOCP2DNoVar) = 0 + +CTModels.has_fixed_initial_time(::DummyOCP2DNoVar) = true +CTModels.initial_time(::DummyOCP2DNoVar) = 0.0 + +CTModels.state_name(::DummyOCP2DNoVar) = "x" +CTModels.state_components(::DummyOCP2DNoVar) = ["x1", "x2"] +CTModels.control_name(::DummyOCP2DNoVar) = "u" +CTModels.control_components(::DummyOCP2DNoVar) = String[] +CTModels.variable_name(::DummyOCP2DNoVar) = "v" +CTModels.variable_components(::DummyOCP2DNoVar) = String[] + +struct DummyOCP1D2Control <: CTModels.AbstractModel end + +CTModels.state_dimension(::DummyOCP1D2Control) = 1 +CTModels.control_dimension(::DummyOCP1D2Control) = 2 +CTModels.variable_dimension(::DummyOCP1D2Control) = 0 + +CTModels.has_fixed_initial_time(::DummyOCP1D2Control) = true +CTModels.initial_time(::DummyOCP1D2Control) = 0.0 + +CTModels.state_name(::DummyOCP1D2Control) = "x" +CTModels.state_components(::DummyOCP1D2Control) = ["x"] +CTModels.control_name(::DummyOCP1D2Control) = "u" +CTModels.control_components(::DummyOCP1D2Control) = ["u1", "u2"] +CTModels.variable_name(::DummyOCP1D2Control) = "v" +CTModels.variable_components(::DummyOCP1D2Control) = String[] + +struct DummySolution1DVar <: CTModels.AbstractSolution + model + xfun::Function + ufun::Function + v +end + +CTModels.state(sol::DummySolution1DVar) = sol.xfun +CTModels.control(sol::DummySolution1DVar) = sol.ufun +CTModels.variable(sol::DummySolution1DVar) = sol.v + +function test_initial_guess() + Test.@testset "basic construction and validation" verbose=VERBOSE showtiming=SHOWTIMING begin + # Simple 1D dummy problem: scalar x,u, no variable (dim(x)=dim(u)=1, dim(v)=0) + ocp1 = DummyOCP1DNoVar() + + # Scalar initial guess consistent with dimension 1 + init1 = CTModels.initial_guess(ocp1; state=0.2, control=-0.1) + Test.@test init1 isa CTModels.AbstractOptimalControlInitialGuess + # validate_initial_guess should not throw + CTModels.validate_initial_guess(ocp1, init1) + + # Incorrect vector initial guess for state (dim 1 but length 2) + bad_state = [0.1, 0.2] + Test.@test_throws CTBase.IncorrectArgument CTModels.initial_guess( + ocp1; state=bad_state + ) + + # Scalar control init is OK, but a function returning a length-2 vector must be rejected + bad_control_fun = t -> [t, 2t] + init_bad_ctrl = CTModels.OptimalControlInitialGuess( + CTModels.state(init1), bad_control_fun, Float64[] + ) + Test.@test_throws CTBase.IncorrectArgument CTModels.validate_initial_guess( + ocp1, init_bad_ctrl + ) + + # For a multi-dimensional state (dim(x)=2), passing a scalar state should trigger + # the dimension-mismatch error in initial_state(ocp, ::Real). + ocp_state2 = DummyOCP2DNoVar() + Test.@test_throws CTBase.IncorrectArgument CTModels.initial_guess( + ocp_state2; state=0.1 + ) + + # For a multi-dimensional control (dim(u)=2), passing a scalar control should + # trigger the analogous error in initial_control(ocp, ::Real). + ocp_ctrl2 = DummyOCP1D2Control() + Test.@test_throws CTBase.IncorrectArgument CTModels.initial_guess( + ocp_ctrl2; control=0.1 + ) + end + + Test.@testset "variable dimension handling" verbose=VERBOSE showtiming=SHOWTIMING begin + # Dummy problem with scalar variable (dim(x)=dim(u)=dim(v)=1) + ocp2 = DummyOCP1DVar() + + # Scalar variable consistent with dimension 1 + init2 = CTModels.initial_guess(ocp2; variable=0.5) + CTModels.validate_initial_guess(ocp2, init2) + + # Variable as a length-2 vector for dimension 1 must throw + Test.@test_throws CTBase.IncorrectArgument CTModels.initial_guess( + ocp2; variable=[0.1, 0.2] + ) + + # Problem without variable: dim(v) == 0 + beam_data = Beam() # beam has no variable in its initial guess + ocp3 = beam_data.ocp + # Providing a scalar variable must throw + Test.@test_throws CTBase.IncorrectArgument CTModels.initial_guess( + ocp3; variable=1.0 + ) + end + + Test.@testset "2D variable block and components" verbose=VERBOSE showtiming=SHOWTIMING begin + ocp = DummyOCP1D2Var() + + # Full block specification for variable w + init_block = (w=[1.0, 2.0],) + ig_block = CTModels.build_initial_guess(ocp, init_block) + Test.@test ig_block isa CTModels.AbstractOptimalControlInitialGuess + CTModels.validate_initial_guess(ocp, ig_block) + v_block = CTModels.variable(ig_block) + Test.@test length(v_block) == 2 + Test.@test v_block[1] ≈ 1.0 + Test.@test v_block[2] ≈ 2.0 + + # Only the tf component (first component) + init_tf = (tf=1.0,) + ig_tf = CTModels.build_initial_guess(ocp, init_tf) + Test.@test ig_tf isa CTModels.AbstractOptimalControlInitialGuess + CTModels.validate_initial_guess(ocp, ig_tf) + v_tf = CTModels.variable(ig_tf) + Test.@test length(v_tf) == 2 + Test.@test v_tf[1] ≈ 1.0 + Test.@test v_tf[2] ≈ 0.1 # default value coming from initial_variable(ocp, nothing) + + # Only the a component (second component) + init_a = (a=0.5,) + ig_a = CTModels.build_initial_guess(ocp, init_a) + Test.@test ig_a isa CTModels.AbstractOptimalControlInitialGuess + CTModels.validate_initial_guess(ocp, ig_a) + v_a = CTModels.variable(ig_a) + Test.@test length(v_a) == 2 + Test.@test v_a[1] ≈ 0.1 + Test.@test v_a[2] ≈ 0.5 + + # Both components specified + init_both = (tf=1.0, a=0.5) + ig_both = CTModels.build_initial_guess(ocp, init_both) + Test.@test ig_both isa CTModels.AbstractOptimalControlInitialGuess + CTModels.validate_initial_guess(ocp, ig_both) + v_both = CTModels.variable(ig_both) + Test.@test length(v_both) == 2 + Test.@test v_both[1] ≈ 1.0 + Test.@test v_both[2] ≈ 0.5 + end + + Test.@testset "build_initial_guess from NamedTuple" verbose=VERBOSE showtiming=SHOWTIMING begin + beam_data2 = Beam() + ocp = beam_data2.ocp + + # Consistent NamedTuple + init_named = (state=[0.05, 0.1], control=0.1, variable=Float64[]) + ig = CTModels.build_initial_guess(ocp, init_named) + Test.@test ig isa CTModels.AbstractOptimalControlInitialGuess + CTModels.validate_initial_guess(ocp, ig) + + # NamedTuple with incorrect state dimension must throw + bad_named = (state=[0.1, 0.2, 0.3], control=0.1, variable=Float64[]) + Test.@test_throws CTBase.IncorrectArgument CTModels.build_initial_guess( + ocp, bad_named + ) + end + + Test.@testset "build_initial_guess generic inputs" verbose=VERBOSE showtiming=SHOWTIMING begin + ocp = DummyOCP1DNoVar() + + ig_default = CTModels.build_initial_guess(ocp, nothing) + Test.@test ig_default isa CTModels.AbstractOptimalControlInitialGuess + CTModels.validate_initial_guess(ocp, ig_default) + + init1 = CTModels.initial_guess(ocp; state=0.2, control=-0.1) + ig_passthrough = CTModels.build_initial_guess(ocp, init1) + Test.@test ig_passthrough === init1 + CTModels.validate_initial_guess(ocp, ig_passthrough) + + Test.@test_throws CTBase.IncorrectArgument CTModels.build_initial_guess(ocp, 42) + end + + Test.@testset "PreInit handling" verbose=VERBOSE showtiming=SHOWTIMING begin + ocp1 = DummyOCP1DNoVar() + ocp2 = DummyOCP1DVar() + + pre1 = CTModels.pre_initial_guess(state=0.2, control=-0.1) + ig1 = CTModels.build_initial_guess(ocp1, pre1) + Test.@test ig1 isa CTModels.AbstractOptimalControlInitialGuess + CTModels.validate_initial_guess(ocp1, ig1) + + pre_bad_state = CTModels.pre_initial_guess(state=[0.1, 0.2]) + Test.@test_throws CTBase.IncorrectArgument CTModels.build_initial_guess( + ocp1, pre_bad_state + ) + + pre2 = CTModels.pre_initial_guess(variable=0.5) + ig2 = CTModels.build_initial_guess(ocp2, pre2) + Test.@test ig2 isa CTModels.AbstractOptimalControlInitialGuess + CTModels.validate_initial_guess(ocp2, ig2) + + pre_bad_var = CTModels.pre_initial_guess(variable=[0.1, 0.2]) + Test.@test_throws CTBase.IncorrectArgument CTModels.build_initial_guess( + ocp2, pre_bad_var + ) + end + + Test.@testset "time-grid NamedTuple (per-block tuples)" verbose=VERBOSE showtiming=SHOWTIMING begin + ocp = DummyOCP1DNoVar() + + time = [0.0, 0.5, 1.0] + state_samples = [[0.0], [0.5], [1.0]] + control_samples = [0.0, 0.0, 1.0] + + init_nt = ( + state=(time, state_samples), control=(time, control_samples), variable=Float64[] + ) + ig = CTModels.build_initial_guess(ocp, init_nt) + Test.@test ig isa CTModels.AbstractOptimalControlInitialGuess + CTModels.validate_initial_guess(ocp, ig) + + xfun = CTModels.state(ig) + ufun = CTModels.control(ig) + + x0 = xfun(0.0); + x1 = xfun(1.0) + u0 = ufun(0.0); + u1 = ufun(1.0) + + x0_val = x0 isa AbstractVector ? x0[1] : x0 + x1_val = x1 isa AbstractVector ? x1[1] : x1 + u0_val = u0 isa AbstractVector ? u0[1] : u0 + u1_val = u1 isa AbstractVector ? u1[1] : u1 + + Test.@test isapprox(x0_val, 0.0; atol=1e-12) + Test.@test isapprox(x1_val, 1.0; atol=1e-12) + Test.@test isapprox(u0_val, 0.0; atol=1e-12) + Test.@test isapprox(u1_val, 1.0; atol=1e-12) + + # Same test but using a matrix for the state samples (time-grid + matrix2vec path) + state_matrix = [0.0; 0.5; 1.0] + init_nt_mat = ( + state=(time, state_matrix), control=(time, control_samples), variable=Float64[] + ) + ig_mat = CTModels.build_initial_guess(ocp, init_nt_mat) + Test.@test ig_mat isa CTModels.AbstractOptimalControlInitialGuess + CTModels.validate_initial_guess(ocp, ig_mat) + + # Edge case: (time, nothing) for state should fall back to default initial_state + init_nt_state_nothing = ( + state=(time, nothing), control=(time, control_samples), variable=Float64[] + ) + ig_state_nothing = CTModels.build_initial_guess(ocp, init_nt_state_nothing) + Test.@test ig_state_nothing isa CTModels.AbstractOptimalControlInitialGuess + CTModels.validate_initial_guess(ocp, ig_state_nothing) + + # Edge case: (time, nothing) for control should fall back to default initial_control + init_nt_control_nothing = ( + state=(time, state_samples), control=(time, nothing), variable=Float64[] + ) + ig_control_nothing = CTModels.build_initial_guess(ocp, init_nt_control_nothing) + Test.@test ig_control_nothing isa CTModels.AbstractOptimalControlInitialGuess + CTModels.validate_initial_guess(ocp, ig_control_nothing) + + bad_state_samples = [[0.0], [1.0]] + bad_nt = ( + state=(time, bad_state_samples), + control=(time, control_samples), + variable=Float64[], + ) + Test.@test_throws CTBase.IncorrectArgument CTModels.build_initial_guess( + ocp, bad_nt + ) + end + + Test.@testset "time-grid NamedTuple with 2D state matrix" verbose=VERBOSE showtiming=SHOWTIMING begin + ocp = DummyOCP2DNoVar() + + time = [0.0, 0.5, 1.0] + # Each row corresponds to a time sample, columns to state components (x1, x2) + state_matrix = [ + 0.0 1.0; + 0.5 1.5; + 1.0 2.0 + ] + + init_nt = (state=(time, state_matrix),) + ig = CTModels.build_initial_guess(ocp, init_nt) + Test.@test ig isa CTModels.AbstractOptimalControlInitialGuess + CTModels.validate_initial_guess(ocp, ig) + + xfun = CTModels.state(ig) + x0 = xfun(0.0) + x1 = xfun(1.0) + + Test.@test x0[1] ≈ 0.0 + Test.@test x0[2] ≈ 1.0 + Test.@test x1[1] ≈ 1.0 + Test.@test x1[2] ≈ 2.0 + end + + Test.@testset "time-grid PreInit via tuples" verbose=VERBOSE showtiming=SHOWTIMING begin + ocp = DummyOCP1DNoVar() + time = [0.0, 0.5, 1.0] + state_samples = [[0.0], [0.5], [1.0]] + + pre = CTModels.pre_initial_guess(state=(time, state_samples)) + ig = CTModels.build_initial_guess(ocp, pre) + Test.@test ig isa CTModels.AbstractOptimalControlInitialGuess + CTModels.validate_initial_guess(ocp, ig) + + xfun = CTModels.state(ig) + x0 = xfun(0.0); + x1 = xfun(1.0) + x0_val = x0 isa AbstractVector ? x0[1] : x0 + x1_val = x1 isa AbstractVector ? x1[1] : x1 + Test.@test isapprox(x0_val, 0.0; atol=1e-12) + Test.@test isapprox(x1_val, 1.0; atol=1e-12) + end + + Test.@testset "per-component state init without time" verbose=VERBOSE showtiming=SHOWTIMING begin + ocp = DummyOCP2DNoVar() + + # Init only via components x1, x2 + init_nt = (x1=0.0, x2=1.0) + ig = CTModels.build_initial_guess(ocp, init_nt) + Test.@test ig isa CTModels.AbstractOptimalControlInitialGuess + CTModels.validate_initial_guess(ocp, ig) + + xfun = CTModels.state(ig) + x = xfun(0.0) + Test.@test x[1] ≈ 0.0 + Test.@test x[2] ≈ 1.0 + end + + Test.@testset "per-component state init with time" verbose=VERBOSE showtiming=SHOWTIMING begin + ocp = DummyOCP2DNoVar() + time = [0.0, 1.0] + init_nt = (x1=(time, [0.0, 1.0]), x2=(time, [1.0, 2.0])) + ig = CTModels.build_initial_guess(ocp, init_nt) + Test.@test ig isa CTModels.AbstractOptimalControlInitialGuess + CTModels.validate_initial_guess(ocp, ig) + + xfun = CTModels.state(ig) + x0 = xfun(0.0); + x1 = xfun(1.0) + Test.@test x0[1] ≈ 0.0 + Test.@test x0[2] ≈ 1.0 + Test.@test x1[1] ≈ 1.0 + Test.@test x1[2] ≈ 2.0 + end + + Test.@testset "uniqueness between block and component specs" verbose=VERBOSE showtiming=SHOWTIMING begin + ocp = DummyOCP2DNoVar() + bad_nt = (state=[0.0, 0.0], x1=1.0) + Test.@test_throws CTBase.IncorrectArgument CTModels.build_initial_guess( + ocp, bad_nt + ) + end + + Test.@testset "warm-start from AbstractSolution" verbose=VERBOSE showtiming=SHOWTIMING begin + ocp = DummyOCP1DVar() + + xfun = t -> 0.1 + ufun = t -> -0.2 + v = 0.5 + + sol_ok = DummySolution1DVar(ocp, xfun, ufun, v) + ig = CTModels.build_initial_guess(ocp, sol_ok) + Test.@test ig isa CTModels.AbstractOptimalControlInitialGuess + CTModels.validate_initial_guess(ocp, ig) + + model_bad_var = DummyOCP1DNoVar() + sol_bad_var = DummySolution1DVar(model_bad_var, xfun, ufun, v) + Test.@test_throws CTBase.IncorrectArgument CTModels.build_initial_guess( + ocp, sol_bad_var + ) + + model_bad_state = DummyOCP2DNoVar() + sol_bad_state = DummySolution1DVar(model_bad_state, xfun, ufun, v) + Test.@test_throws CTBase.IncorrectArgument CTModels.build_initial_guess( + ocp, sol_bad_state + ) + end + + Test.@testset "NamedTuple alias keys from OCP names" verbose=VERBOSE showtiming=SHOWTIMING begin + ocp1 = DummyOCP1DNoVar() + + init_nt1 = (x=0.2, u=-0.1) + ig1 = CTModels.build_initial_guess(ocp1, init_nt1) + Test.@test ig1 isa CTModels.AbstractOptimalControlInitialGuess + CTModels.validate_initial_guess(ocp1, ig1) + + time = [0.0, 0.5, 1.0] + state_samples = [[0.0], [0.5], [1.0]] + control_samples = [0.0, 0.0, 1.0] + + init_nt2 = (x=(time, state_samples), u=(time, control_samples), variable=Float64[]) + ig2 = CTModels.build_initial_guess(ocp1, init_nt2) + Test.@test ig2 isa CTModels.AbstractOptimalControlInitialGuess + CTModels.validate_initial_guess(ocp1, ig2) + end + + Test.@testset "NamedTuple error cases" verbose=VERBOSE showtiming=SHOWTIMING begin + ocp1 = DummyOCP1DNoVar() + + bad_unknown = (state=0.1, foo=1.0) + Test.@test_throws CTBase.IncorrectArgument CTModels.build_initial_guess( + ocp1, bad_unknown + ) + + bad_time = (time=[0.0, 1.0], state=0.1) + Test.@test_throws CTBase.IncorrectArgument CTModels.build_initial_guess( + ocp1, bad_time + ) + + ocp2 = DummyOCP2DNoVar() + + bad_comp_vector = (x1=[0.0, 1.0]) + Test.@test_throws CTBase.IncorrectArgument CTModels.build_initial_guess( + ocp2, bad_comp_vector + ) + + time = [0.0, 1.0, 2.0] + bad_comp_time = (x1=(time, [0.0, 1.0])) + Test.@test_throws CTBase.IncorrectArgument CTModels.build_initial_guess( + ocp2, bad_comp_time + ) + + ocp3 = DummyOCP2DNoVar() + bad_state_fun = t -> [0.0] + bad_nt_state_fun = (state=bad_state_fun,) + Test.@test_throws CTBase.IncorrectArgument CTModels.build_initial_guess( + ocp3, bad_nt_state_fun + ) + end + + Test.@testset "per-component control init without time" verbose=VERBOSE showtiming=SHOWTIMING begin + ocp = DummyOCP1D2Control() + + init_nt = (u1=0.0, u2=1.0) + ig = CTModels.build_initial_guess(ocp, init_nt) + Test.@test ig isa CTModels.AbstractOptimalControlInitialGuess + CTModels.validate_initial_guess(ocp, ig) + + ufun = CTModels.control(ig) + u = ufun(0.0) + Test.@test length(u) == 2 + Test.@test u[1] ≈ 0.0 + Test.@test u[2] ≈ 1.0 + end + + Test.@testset "per-component control init with time" verbose=VERBOSE showtiming=SHOWTIMING begin + ocp = DummyOCP1D2Control() + time = [0.0, 1.0] + + init_nt = (u1=(time, [0.0, 1.0]), u2=(time, [1.0, 2.0])) + ig = CTModels.build_initial_guess(ocp, init_nt) + Test.@test ig isa CTModels.AbstractOptimalControlInitialGuess + CTModels.validate_initial_guess(ocp, ig) + + ufun = CTModels.control(ig) + u0 = ufun(0.0) + u1 = ufun(1.0) + + Test.@test u0[1] ≈ 0.0 + Test.@test u0[2] ≈ 1.0 + Test.@test u1[1] ≈ 1.0 + Test.@test u1[2] ≈ 2.0 + end + + Test.@testset "uniqueness between control block and component specs" verbose=VERBOSE showtiming=SHOWTIMING begin + ocp = DummyOCP1D2Control() + + bad_nt1 = (control=[0.0, 1.0], u1=1.0) + Test.@test_throws CTBase.IncorrectArgument CTModels.build_initial_guess( + ocp, bad_nt1 + ) + + bad_nt2 = (u=[0.0, 1.0], u1=1.0) + Test.@test_throws CTBase.IncorrectArgument CTModels.build_initial_guess( + ocp, bad_nt2 + ) + end +end \ No newline at end of file diff --git a/test/io/test_export_import.jl b/test/io/test_export_import.jl new file mode 100644 index 00000000..71a03031 --- /dev/null +++ b/test/io/test_export_import.jl @@ -0,0 +1,454 @@ +using JLD2 +using JSON3 + +# ============================================================================ +# TEST HELPERS +# ============================================================================ + +function remove_if_exists(filename::String) + isfile(filename) && rm(filename) +end + +# ============================================================================ +# MAIN TEST FUNCTION +# ============================================================================ + +function test_export_import() + + # ======================================================================== + # Integration tests – basic round-trip with solution_example + # ======================================================================== + + Test.@testset "JSON round-trip: solution_example (matrix)" verbose=VERBOSE showtiming=SHOWTIMING begin + ocp, sol = solution_example() + + CTModels.export_ocp_solution(sol; filename="solution_test", format=:JSON) + sol_reloaded = CTModels.import_ocp_solution(ocp; filename="solution_test", format=:JSON) + + @test CTModels.objective(sol) ≈ CTModels.objective(sol_reloaded) atol=1e-8 + @test CTModels.iterations(sol) == CTModels.iterations(sol_reloaded) + @test CTModels.successful(sol) == CTModels.successful(sol_reloaded) + @test CTModels.status(sol) == CTModels.status(sol_reloaded) + + remove_if_exists("solution_test.json") + end + + Test.@testset "JSON round-trip: solution_example (function)" verbose=VERBOSE showtiming=SHOWTIMING begin + ocp, sol = solution_example(; fun=true) + + CTModels.export_ocp_solution(sol; filename="solution_test_fun", format=:JSON) + sol_reloaded = CTModels.import_ocp_solution(ocp; filename="solution_test_fun", format=:JSON) + + @test CTModels.objective(sol) ≈ CTModels.objective(sol_reloaded) atol=1e-8 + @test CTModels.iterations(sol) == CTModels.iterations(sol_reloaded) + + remove_if_exists("solution_test_fun.json") + end + + Test.@testset "JLD round-trip: solution_example" verbose=VERBOSE showtiming=SHOWTIMING begin + ocp, sol = solution_example() + + # Suppress JLD2 warnings about anonymous functions (expected behaviour) + Base.CoreLogging.with_logger(Base.CoreLogging.NullLogger()) do + CTModels.export_ocp_solution(sol; filename="solution_test") # default is :JLD + end + sol_reloaded = CTModels.import_ocp_solution(ocp; filename="solution_test", format=:JLD) + + @test CTModels.objective(sol) ≈ CTModels.objective(sol_reloaded) atol=1e-8 + @test CTModels.iterations(sol) == CTModels.iterations(sol_reloaded) + + remove_if_exists("solution_test.jld2") + end + + # ======================================================================== + # Comprehensive JSON tests – all fields with solution_example_dual + # ======================================================================== + + Test.@testset "JSON comprehensive: all fields preserved" verbose=VERBOSE showtiming=SHOWTIMING begin + # Use solution_example_dual which has all duals populated + ocp, sol = solution_example_dual() + + # Export + CTModels.export_ocp_solution(sol; filename="solution_full", format=:JSON) + + # Read raw JSON to verify structure + json_string = read("solution_full.json", String) + blob = JSON3.read(json_string) + + # Verify all expected keys are present + expected_keys = [ + "time_grid", "state", "control", "variable", "costate", + "objective", "iterations", "constraints_violation", + "message", "status", "successful", + "path_constraints_dual", + "state_constraints_lb_dual", "state_constraints_ub_dual", + "control_constraints_lb_dual", "control_constraints_ub_dual", + "boundary_constraints_dual", + "variable_constraints_lb_dual", "variable_constraints_ub_dual", + ] + for key in expected_keys + @test haskey(blob, key) + end + + # Verify scalar fields + @test blob["objective"] ≈ CTModels.objective(sol) atol=1e-10 + @test blob["iterations"] == CTModels.iterations(sol) + @test blob["constraints_violation"] ≈ CTModels.constraints_violation(sol) atol=1e-10 + @test blob["message"] == CTModels.message(sol) + @test blob["status"] == string(CTModels.status(sol)) + @test blob["successful"] == CTModels.successful(sol) + + # Verify time_grid + T_orig = CTModels.time_grid(sol) + T_json = Vector{Float64}(blob["time_grid"]) + @test length(T_json) == length(T_orig) + @test T_json ≈ T_orig atol=1e-10 + + # Verify variable + v_orig = CTModels.variable(sol) + v_json = if isempty(blob["variable"]) + Float64[] + else + Vector{Float64}(blob["variable"]) + end + @test v_json ≈ v_orig atol=1e-10 + + # Verify state discretization + state_json = blob["state"] + @test length(state_json) == length(T_orig) + x_func = CTModels.state(sol) + for (i, t) in enumerate(T_orig) + x_expected = x_func(t) + x_from_json = if state_json[i] isa Number + state_json[i] + else + Vector{Float64}(state_json[i]) + end + @test x_from_json ≈ x_expected atol=1e-8 + end + + # Verify control discretization + control_json = blob["control"] + @test length(control_json) == length(T_orig) + u_func = CTModels.control(sol) + for (i, t) in enumerate(T_orig) + u_expected = u_func(t) + u_from_json = if control_json[i] isa Number + control_json[i] + else + Vector{Float64}(control_json[i]) + end + @test u_from_json ≈ u_expected atol=1e-8 + end + + # Verify costate discretization + costate_json = blob["costate"] + @test length(costate_json) == length(T_orig) + p_func = CTModels.costate(sol) + for (i, t) in enumerate(T_orig) + p_expected = p_func(t) + p_from_json = if costate_json[i] isa Number + costate_json[i] + else + Vector{Float64}(costate_json[i]) + end + @test p_from_json ≈ p_expected atol=1e-8 + end + + # Verify path_constraints_dual if present + pcd = CTModels.path_constraints_dual(sol) + if !isnothing(pcd) + pcd_json = blob["path_constraints_dual"] + @test !isnothing(pcd_json) + @test length(pcd_json) == length(T_orig) + for (i, t) in enumerate(T_orig) + pcd_expected = pcd(t) + pcd_from_json = Vector{Float64}(pcd_json[i]) + @test pcd_from_json ≈ pcd_expected atol=1e-8 + end + end + + # Verify boundary_constraints_dual if present + bcd = CTModels.boundary_constraints_dual(sol) + if !isnothing(bcd) + bcd_json = blob["boundary_constraints_dual"] + @test !isnothing(bcd_json) + bcd_from_json = Vector{Float64}(bcd_json) + @test bcd_from_json ≈ bcd atol=1e-10 + end + + # Verify variable_constraints_lb_dual if present + vclbd = CTModels.variable_constraints_lb_dual(sol) + if !isnothing(vclbd) + vclbd_json = blob["variable_constraints_lb_dual"] + @test !isnothing(vclbd_json) + vclbd_from_json = Vector{Float64}(vclbd_json) + @test vclbd_from_json ≈ vclbd atol=1e-10 + end + + # Verify variable_constraints_ub_dual if present + vcubd = CTModels.variable_constraints_ub_dual(sol) + if !isnothing(vcubd) + vcubd_json = blob["variable_constraints_ub_dual"] + @test !isnothing(vcubd_json) + vcubd_from_json = Vector{Float64}(vcubd_json) + @test vcubd_from_json ≈ vcubd atol=1e-10 + end + + remove_if_exists("solution_full.json") + end + + Test.@testset "JSON import: all fields reconstructed" verbose=VERBOSE showtiming=SHOWTIMING begin + ocp, sol = solution_example_dual() + + CTModels.export_ocp_solution(sol; filename="solution_import_test", format=:JSON) + sol_reloaded = CTModels.import_ocp_solution(ocp; filename="solution_import_test", format=:JSON) + + # Scalar fields + @test CTModels.objective(sol_reloaded) ≈ CTModels.objective(sol) atol=1e-8 + @test CTModels.iterations(sol_reloaded) == CTModels.iterations(sol) + @test CTModels.constraints_violation(sol_reloaded) ≈ CTModels.constraints_violation(sol) atol=1e-8 + @test CTModels.message(sol_reloaded) == CTModels.message(sol) + @test CTModels.status(sol_reloaded) == CTModels.status(sol) + @test CTModels.successful(sol_reloaded) == CTModels.successful(sol) + + # Time grid + @test CTModels.time_grid(sol_reloaded) ≈ CTModels.time_grid(sol) atol=1e-10 + + # Metadata: dimensions, names, components and time labels + @test CTModels.state_dimension(sol_reloaded) == CTModels.state_dimension(sol) + @test CTModels.control_dimension(sol_reloaded) == CTModels.control_dimension(sol) + @test CTModels.variable_dimension(sol_reloaded) == CTModels.variable_dimension(sol) + + @test CTModels.state_name(sol_reloaded) == CTModels.state_name(sol) + @test CTModels.control_name(sol_reloaded) == CTModels.control_name(sol) + @test CTModels.variable_name(sol_reloaded) == CTModels.variable_name(sol) + + @test CTModels.state_components(sol_reloaded) == CTModels.state_components(sol) + @test CTModels.control_components(sol_reloaded) == CTModels.control_components(sol) + @test CTModels.variable_components(sol_reloaded) == CTModels.variable_components(sol) + + @test CTModels.initial_time_name(sol_reloaded) == CTModels.initial_time_name(sol) + @test CTModels.final_time_name(sol_reloaded) == CTModels.final_time_name(sol) + @test CTModels.time_name(sol_reloaded) == CTModels.time_name(sol) + + # Variable + @test CTModels.variable(sol_reloaded) ≈ CTModels.variable(sol) atol=1e-10 + + # State at sample times + T = CTModels.time_grid(sol) + x_orig = CTModels.state(sol) + x_reload = CTModels.state(sol_reloaded) + for t in T + @test x_reload(t) ≈ x_orig(t) atol=1e-8 + end + + # Control at sample times + u_orig = CTModels.control(sol) + u_reload = CTModels.control(sol_reloaded) + for t in T + @test u_reload(t) ≈ u_orig(t) atol=1e-8 + end + + # Costate at sample times + p_orig = CTModels.costate(sol) + p_reload = CTModels.costate(sol_reloaded) + for t in T + @test p_reload(t) ≈ p_orig(t) atol=1e-8 + end + + # Path constraints dual + pcd_orig = CTModels.path_constraints_dual(sol) + pcd_reload = CTModels.path_constraints_dual(sol_reloaded) + if !isnothing(pcd_orig) + @test !isnothing(pcd_reload) + for t in T + @test pcd_reload(t) ≈ pcd_orig(t) atol=1e-8 + end + else + @test isnothing(pcd_reload) + end + + # Boundary constraints dual + bcd_orig = CTModels.boundary_constraints_dual(sol) + bcd_reload = CTModels.boundary_constraints_dual(sol_reloaded) + if !isnothing(bcd_orig) + @test !isnothing(bcd_reload) + @test bcd_reload ≈ bcd_orig atol=1e-10 + else + @test isnothing(bcd_reload) + end + + # State constraints lb dual + sclbd_orig = CTModels.state_constraints_lb_dual(sol) + sclbd_reload = CTModels.state_constraints_lb_dual(sol_reloaded) + if !isnothing(sclbd_orig) + @test !isnothing(sclbd_reload) + for t in T + @test sclbd_reload(t) ≈ sclbd_orig(t) atol=1e-8 + end + else + @test isnothing(sclbd_reload) + end + + # State constraints ub dual + scubd_orig = CTModels.state_constraints_ub_dual(sol) + scubd_reload = CTModels.state_constraints_ub_dual(sol_reloaded) + if !isnothing(scubd_orig) + @test !isnothing(scubd_reload) + for t in T + @test scubd_reload(t) ≈ scubd_orig(t) atol=1e-8 + end + else + @test isnothing(scubd_reload) + end + + # Control constraints lb dual + cclbd_orig = CTModels.control_constraints_lb_dual(sol) + cclbd_reload = CTModels.control_constraints_lb_dual(sol_reloaded) + if !isnothing(cclbd_orig) + @test !isnothing(cclbd_reload) + for t in T + @test cclbd_reload(t) ≈ cclbd_orig(t) atol=1e-8 + end + else + @test isnothing(cclbd_reload) + end + + # Control constraints ub dual + ccubd_orig = CTModels.control_constraints_ub_dual(sol) + ccubd_reload = CTModels.control_constraints_ub_dual(sol_reloaded) + if !isnothing(ccubd_orig) + @test !isnothing(ccubd_reload) + for t in T + @test ccubd_reload(t) ≈ ccubd_orig(t) atol=1e-8 + end + else + @test isnothing(ccubd_reload) + end + + # Variable constraints lb dual + vclbd_orig = CTModels.variable_constraints_lb_dual(sol) + vclbd_reload = CTModels.variable_constraints_lb_dual(sol_reloaded) + if !isnothing(vclbd_orig) + @test !isnothing(vclbd_reload) + @test vclbd_reload ≈ vclbd_orig atol=1e-10 + else + @test isnothing(vclbd_reload) + end + + # Variable constraints ub dual + vcubd_orig = CTModels.variable_constraints_ub_dual(sol) + vcubd_reload = CTModels.variable_constraints_ub_dual(sol_reloaded) + if !isnothing(vcubd_orig) + @test !isnothing(vcubd_reload) + @test vcubd_reload ≈ vcubd_orig atol=1e-10 + else + @test isnothing(vcubd_reload) + end + + remove_if_exists("solution_import_test.json") + end + + # ======================================================================== + # Edge cases + # ======================================================================== + + Test.@testset "JSON: solution with all duals nothing" verbose=VERBOSE showtiming=SHOWTIMING begin + # solution_example has no duals + ocp, sol = solution_example() + + CTModels.export_ocp_solution(sol; filename="solution_no_duals", format=:JSON) + + # Read raw JSON + json_string = read("solution_no_duals.json", String) + blob = JSON3.read(json_string) + + # Verify dual fields are null + @test isnothing(blob["path_constraints_dual"]) + @test isnothing(blob["boundary_constraints_dual"]) + @test isnothing(blob["state_constraints_lb_dual"]) + @test isnothing(blob["state_constraints_ub_dual"]) + @test isnothing(blob["control_constraints_lb_dual"]) + @test isnothing(blob["control_constraints_ub_dual"]) + @test isnothing(blob["variable_constraints_lb_dual"]) + @test isnothing(blob["variable_constraints_ub_dual"]) + + # Import and verify duals are nothing + sol_reloaded = CTModels.import_ocp_solution(ocp; filename="solution_no_duals", format=:JSON) + @test isnothing(CTModels.path_constraints_dual(sol_reloaded)) + @test isnothing(CTModels.boundary_constraints_dual(sol_reloaded)) + @test isnothing(CTModels.state_constraints_lb_dual(sol_reloaded)) + @test isnothing(CTModels.state_constraints_ub_dual(sol_reloaded)) + @test isnothing(CTModels.control_constraints_lb_dual(sol_reloaded)) + @test isnothing(CTModels.control_constraints_ub_dual(sol_reloaded)) + @test isnothing(CTModels.variable_constraints_lb_dual(sol_reloaded)) + @test isnothing(CTModels.variable_constraints_ub_dual(sol_reloaded)) + + remove_if_exists("solution_no_duals.json") + end + + Test.@testset "JSON: solver infos dict preserved" verbose=VERBOSE showtiming=SHOWTIMING begin + # Create a solution with custom infos + ocp, sol_base = solution_example() + T = CTModels.time_grid(sol_base) + + # Build a new solution with custom infos + x = CTModels.state(sol_base) + u = CTModels.control(sol_base) + p = CTModels.costate(sol_base) + v = CTModels.variable(sol_base) + + custom_infos = Dict{Symbol,Any}( + :solver_name => "TestSolver", + :tolerance => 1e-6, + :max_iterations => 1000, + :converged => true, + :residuals => [1e-3, 1e-5, 1e-8], + :nested => Dict{Symbol,Any}(:a => 1, :b => "test"), + ) + + sol = CTModels.build_solution( + ocp, + Vector{Float64}(T), + x, + u, + isa(v, Number) ? [v] : v, + p; + objective=CTModels.objective(sol_base), + iterations=CTModels.iterations(sol_base), + constraints_violation=CTModels.constraints_violation(sol_base), + message=CTModels.message(sol_base), + status=CTModels.status(sol_base), + successful=CTModels.successful(sol_base), + infos=custom_infos, + ) + + # Verify infos is set correctly + @test CTModels.infos(sol)[:solver_name] == "TestSolver" + @test CTModels.infos(sol)[:tolerance] == 1e-6 + + # Export and import + CTModels.export_ocp_solution(sol; filename="solution_with_infos", format=:JSON) + sol_reloaded = CTModels.import_ocp_solution(ocp; filename="solution_with_infos", format=:JSON) + + # Verify infos is preserved + reloaded_infos = CTModels.infos(sol_reloaded) + @test reloaded_infos[:solver_name] == "TestSolver" + @test reloaded_infos[:tolerance] == 1e-6 + @test reloaded_infos[:max_iterations] == 1000 + @test reloaded_infos[:converged] == true + @test reloaded_infos[:residuals] == [1e-3, 1e-5, 1e-8] + @test reloaded_infos[:nested][:a] == 1 + @test reloaded_infos[:nested][:b] == "test" + + # Verify JSON structure + json_string = read("solution_with_infos.json", String) + blob = JSON3.read(json_string) + @test haskey(blob, "infos") + @test blob["infos"]["solver_name"] == "TestSolver" + @test blob["infos"]["tolerance"] == 1e-6 + + remove_if_exists("solution_with_infos.json") + end +end diff --git a/test/test_ext_exceptions.jl b/test/io/test_ext_exceptions.jl similarity index 100% rename from test/test_ext_exceptions.jl rename to test/io/test_ext_exceptions.jl diff --git a/test/meta/test_CTModels.jl b/test/meta/test_CTModels.jl new file mode 100644 index 00000000..c71d7515 --- /dev/null +++ b/test/meta/test_CTModels.jl @@ -0,0 +1,48 @@ +struct CTMDummySol <: CTModels.AbstractSolution end +struct CTMDummyModelTop <: CTModels.AbstractModel end + +function test_CTModels() + # TODO: add tests for the CTModels.jl top-level module file. + + # ======================================================================== + # Unit tests – basic aliases and tags + # ======================================================================== + + Test.@testset "type aliases and tags" verbose=VERBOSE showtiming=SHOWTIMING begin + Test.@test CTModels.Dimension == Int + Test.@test CTModels.ctNumber == Real + Test.@test CTModels.Time === CTModels.ctNumber + + # For parametric aliases, test mutual <: rather than strict identity + Test.@test CTModels.ctVector <: AbstractVector{<:CTModels.ctNumber} + Test.@test AbstractVector{<:CTModels.ctNumber} <: CTModels.ctVector + + Test.@test CTModels.Times <: AbstractVector{<:CTModels.Time} + Test.@test AbstractVector{<:CTModels.Time} <: CTModels.Times + + Test.@test CTModels.JLD2Tag <: CTModels.AbstractTag + Test.@test CTModels.JSON3Tag <: CTModels.AbstractTag + + # Aliases towards CTSolvers usage + Test.@test CTModels.AbstractOptimalControlProblem === CTModels.AbstractModel + Test.@test CTModels.AbstractOptimalControlSolution === CTModels.AbstractSolution + end + + # ======================================================================== + # Integration-style tests – export/import format guards + # ======================================================================== + + Test.@testset "export/import format guards" verbose=VERBOSE showtiming=SHOWTIMING begin + sol = CTMDummySol() + ocp = CTMDummyModelTop() + + # Unknown format should trigger an IncorrectArgument without touching extensions. + Test.@test_throws CTBase.IncorrectArgument CTModels.export_ocp_solution( + sol; format=:FOO + ) + Test.@test_throws CTBase.IncorrectArgument CTModels.import_ocp_solution( + ocp; format=:FOO + ) + end +end + diff --git a/test/test_aqua.jl b/test/meta/test_aqua.jl similarity index 100% rename from test/test_aqua.jl rename to test/meta/test_aqua.jl diff --git a/test/nlp/test_discretized_ocp.jl b/test/nlp/test_discretized_ocp.jl new file mode 100644 index 00000000..2e178c15 --- /dev/null +++ b/test/nlp/test_discretized_ocp.jl @@ -0,0 +1,480 @@ +# Unit tests for CTModels discretized optimal control problems and solution builders. +# ============================================================================ +# TEST HELPER TYPES +# ============================================================================ + +# Dummy stats types for testing solution builders +struct DummyStatsDiscretizedOCP <: SolverCore.AbstractExecutionStats + value::Int +end + +struct DummyStatsDiscretizedOCP2 <: SolverCore.AbstractExecutionStats + value::String +end + +struct DummyStatsDiscretizedOCP3 <: SolverCore.AbstractExecutionStats + status::Symbol +end + +struct DummyStatsDiscretizedOCP4 <: SolverCore.AbstractExecutionStats end + +# Dummy OCP types for testing DiscretizedOptimalControlProblem +struct DummyOCPDiscretized <: CTModels.AbstractModel end +struct DummyOCPDiscretized2 <: CTModels.AbstractModel end +struct DummyOCPDiscretized3 <: CTModels.AbstractModel + data::String +end +struct DummyOCPDiscretized4 <: CTModels.AbstractModel end +struct DummyOCPDiscretized5 <: CTModels.AbstractModel end +struct DummyOCPDiscretized6 <: CTModels.AbstractModel end +struct DummyOCPDiscretized7 <: CTModels.AbstractModel end +struct DummyOCPDiscretized8 <: CTModels.AbstractModel + name::String +end +struct DummyOCPDiscretized9 <: CTModels.AbstractModel end + +struct SimpleOCPDiscretized <: CTModels.AbstractModel + dim::Int +end + +struct ComplexOCPDiscretized <: CTModels.AbstractModel + state_dim::Int + control_dim::Int + constraints::Vector{String} +end + +# ============================================================================ +# TEST FUNCTION +# ============================================================================ + +function test_discretized_ocp() + + # ============================================================================ + # SOLUTION BUILDERS - UNIT TESTS + # ============================================================================ + + Test.@testset "ADNLPSolutionBuilder" verbose=VERBOSE showtiming=SHOWTIMING begin + # Test constructor: wrap a function + call_count = Ref(0) + last_arg = Ref{Any}(nothing) + + function test_adnlp_builder_fn(stats) + call_count[] += 1 + last_arg[] = stats + return (:adnlp_result, stats) + end + + builder = CTModels.ADNLPSolutionBuilder(test_adnlp_builder_fn) + + # Verify the function is stored + Test.@test builder.f === test_adnlp_builder_fn + Test.@test builder isa CTModels.ADNLPSolutionBuilder + + # Test call operator: should invoke the wrapped function + stats = DummyStatsDiscretizedOCP(42) + + result = builder(stats) + + # Verify the wrapped function was called with correct argument + Test.@test call_count[] == 1 + Test.@test last_arg[] === stats + Test.@test result == (:adnlp_result, stats) + end + + Test.@testset "ExaSolutionBuilder" verbose=VERBOSE showtiming=SHOWTIMING begin + # Test constructor: wrap a function + call_count = Ref(0) + last_arg = Ref{Any}(nothing) + + function test_exa_builder_fn(stats) + call_count[] += 1 + last_arg[] = stats + return (:exa_result, stats) + end + + builder = CTModels.ExaSolutionBuilder(test_exa_builder_fn) + + # Verify the function is stored + Test.@test builder.f === test_exa_builder_fn + Test.@test builder isa CTModels.ExaSolutionBuilder + + # Test call operator: should invoke the wrapped function + stats = DummyStatsDiscretizedOCP2("test") + + result = builder(stats) + + # Verify the wrapped function was called with correct argument + Test.@test call_count[] == 1 + Test.@test last_arg[] === stats + Test.@test result == (:exa_result, stats) + end + + # ============================================================================ + # DISCRETIZED OCP - CONSTRUCTORS + # ============================================================================ + + Test.@testset "DiscretizedOptimalControlProblem - tuple constructor" verbose=VERBOSE showtiming=SHOWTIMING begin + # Create a dummy OCP (we need an AbstractOptimalControlProblem) + ocp = DummyOCPDiscretized() + + # Create dummy model builders + adnlp_model_builder = CTModels.ADNLPModelBuilder(x -> error("unused")) + exa_model_builder = CTModels.ExaModelBuilder((T, x; kwargs...) -> error("unused")) + + # Create dummy solution builders + adnlp_solution_builder = CTModels.ADNLPSolutionBuilder(s -> s) + exa_solution_builder = CTModels.ExaSolutionBuilder(s -> s) + + # Build using tuple constructor with backend builder bundles + backend_builders = ( + :adnlp => + CTModels.OCPBackendBuilders(adnlp_model_builder, adnlp_solution_builder), + :exa => CTModels.OCPBackendBuilders(exa_model_builder, exa_solution_builder), + ) + + docp = CTModels.DiscretizedOptimalControlProblem(ocp, backend_builders) + + # Verify the problem was constructed correctly + Test.@test docp isa CTModels.DiscretizedOptimalControlProblem + Test.@test docp.optimal_control_problem === ocp + + # The Tuple-of-Pairs inputs should have been converted to a NamedTuple of OCPBackendBuilders + expected_backend_builders = (; + adnlp=CTModels.OCPBackendBuilders(adnlp_model_builder, adnlp_solution_builder), + exa=CTModels.OCPBackendBuilders(exa_model_builder, exa_solution_builder), + ) + + Test.@test docp.backend_builders == expected_backend_builders + Test.@test docp.backend_builders.adnlp.model === adnlp_model_builder + Test.@test docp.backend_builders.adnlp.solution === adnlp_solution_builder + Test.@test docp.backend_builders.exa.model === exa_model_builder + Test.@test docp.backend_builders.exa.solution === exa_solution_builder + end + + Test.@testset "DiscretizedOptimalControlProblem - individual args constructor" verbose=VERBOSE showtiming=SHOWTIMING begin + # Create a dummy OCP + ocp = DummyOCPDiscretized2() + + # Create builders + adnlp_model_builder = CTModels.ADNLPModelBuilder(x -> error("unused")) + exa_model_builder = CTModels.ExaModelBuilder((T, x; kwargs...) -> error("unused")) + adnlp_solution_builder = CTModels.ADNLPSolutionBuilder(s -> (:adnlp_sol, s)) + exa_solution_builder = CTModels.ExaSolutionBuilder(s -> (:exa_sol, s)) + + # Build using individual args constructor + docp = CTModels.DiscretizedOptimalControlProblem( + ocp, + adnlp_model_builder, + exa_model_builder, + adnlp_solution_builder, + exa_solution_builder, + ) + + # Verify the problem was constructed correctly + Test.@test docp isa CTModels.DiscretizedOptimalControlProblem + Test.@test docp.optimal_control_problem === ocp + + # Verify the builders were converted to the expected backend_builders representation + expected_backend_builders = (; + adnlp=CTModels.OCPBackendBuilders(adnlp_model_builder, adnlp_solution_builder), + exa=CTModels.OCPBackendBuilders(exa_model_builder, exa_solution_builder), + ) + + Test.@test docp.backend_builders == expected_backend_builders + Test.@test docp.backend_builders.adnlp.model === adnlp_model_builder + Test.@test docp.backend_builders.adnlp.solution === adnlp_solution_builder + Test.@test docp.backend_builders.exa.model === exa_model_builder + Test.@test docp.backend_builders.exa.solution === exa_solution_builder + end + + # ============================================================================ + # ACCESSOR FUNCTIONS + # ============================================================================ + + Test.@testset "ocp_model" verbose=VERBOSE showtiming=SHOWTIMING begin + # Create a DOCP with a specific OCP + ocp = DummyOCPDiscretized3("test_data") + + adnlp_model_builder = CTModels.ADNLPModelBuilder(x -> error("unused")) + exa_model_builder = CTModels.ExaModelBuilder((T, x; kwargs...) -> error("unused")) + adnlp_solution_builder = CTModels.ADNLPSolutionBuilder(s -> s) + exa_solution_builder = CTModels.ExaSolutionBuilder(s -> s) + + docp = CTModels.DiscretizedOptimalControlProblem( + ocp, + adnlp_model_builder, + exa_model_builder, + adnlp_solution_builder, + exa_solution_builder, + ) + + # Test ocp_model accessor + retrieved_ocp = CTModels.ocp_model(docp) + Test.@test retrieved_ocp === ocp + Test.@test retrieved_ocp.data == "test_data" + end + + Test.@testset "get_adnlp_model_builder" verbose=VERBOSE showtiming=SHOWTIMING begin + ocp = DummyOCPDiscretized4() + + # Create a specific builder to verify retrieval + function my_adnlp_builder(x) + return :my_adnlp_model + end + adnlp_model_builder = CTModels.ADNLPModelBuilder(my_adnlp_builder) + exa_model_builder = CTModels.ExaModelBuilder((T, x; kwargs...) -> error("unused")) + adnlp_solution_builder = CTModels.ADNLPSolutionBuilder(s -> s) + exa_solution_builder = CTModels.ExaSolutionBuilder(s -> s) + + docp = CTModels.DiscretizedOptimalControlProblem( + ocp, + adnlp_model_builder, + exa_model_builder, + adnlp_solution_builder, + exa_solution_builder, + ) + + # Test get_adnlp_model_builder accessor + retrieved_builder = CTModels.get_adnlp_model_builder(docp) + Test.@test retrieved_builder === adnlp_model_builder + Test.@test retrieved_builder.f === my_adnlp_builder + end + + Test.@testset "get_exa_model_builder" verbose=VERBOSE showtiming=SHOWTIMING begin + ocp = DummyOCPDiscretized5() + + # Create a specific builder to verify retrieval + function my_exa_builder(::Type{T}, x; kwargs...) where {T} + return :my_exa_model + end + adnlp_model_builder = CTModels.ADNLPModelBuilder(x -> error("unused")) + exa_model_builder = CTModels.ExaModelBuilder(my_exa_builder) + adnlp_solution_builder = CTModels.ADNLPSolutionBuilder(s -> s) + exa_solution_builder = CTModels.ExaSolutionBuilder(s -> s) + + docp = CTModels.DiscretizedOptimalControlProblem( + ocp, + adnlp_model_builder, + exa_model_builder, + adnlp_solution_builder, + exa_solution_builder, + ) + + # Test get_exa_model_builder accessor + retrieved_builder = CTModels.get_exa_model_builder(docp) + Test.@test retrieved_builder === exa_model_builder + Test.@test retrieved_builder.f === my_exa_builder + end + + Test.@testset "get_adnlp_solution_builder" verbose=VERBOSE showtiming=SHOWTIMING begin + ocp = DummyOCPDiscretized6() + + # Create a specific solution builder to verify retrieval + function my_adnlp_solution_builder(stats) + return (:my_adnlp_solution, stats) + end + adnlp_model_builder = CTModels.ADNLPModelBuilder(x -> error("unused")) + exa_model_builder = CTModels.ExaModelBuilder((T, x; kwargs...) -> error("unused")) + adnlp_solution_builder = CTModels.ADNLPSolutionBuilder(my_adnlp_solution_builder) + exa_solution_builder = CTModels.ExaSolutionBuilder(s -> s) + + docp = CTModels.DiscretizedOptimalControlProblem( + ocp, + adnlp_model_builder, + exa_model_builder, + adnlp_solution_builder, + exa_solution_builder, + ) + + # Test get_adnlp_solution_builder accessor + retrieved_builder = CTModels.get_adnlp_solution_builder(docp) + Test.@test retrieved_builder === adnlp_solution_builder + Test.@test retrieved_builder.f === my_adnlp_solution_builder + end + + Test.@testset "get_exa_solution_builder" verbose=VERBOSE showtiming=SHOWTIMING begin + ocp = DummyOCPDiscretized7() + + # Create a specific solution builder to verify retrieval + function my_exa_solution_builder(stats) + return (:my_exa_solution, stats) + end + adnlp_model_builder = CTModels.ADNLPModelBuilder(x -> error("unused")) + exa_model_builder = CTModels.ExaModelBuilder((T, x; kwargs...) -> error("unused")) + adnlp_solution_builder = CTModels.ADNLPSolutionBuilder(s -> s) + exa_solution_builder = CTModels.ExaSolutionBuilder(my_exa_solution_builder) + + docp = CTModels.DiscretizedOptimalControlProblem( + ocp, + adnlp_model_builder, + exa_model_builder, + adnlp_solution_builder, + exa_solution_builder, + ) + + # Test get_exa_solution_builder accessor + retrieved_builder = CTModels.get_exa_solution_builder(docp) + Test.@test retrieved_builder === exa_solution_builder + Test.@test retrieved_builder.f === my_exa_solution_builder + end + + # ============================================================================ + # INTEGRATION TESTS + # ============================================================================ + + Test.@testset "end-to-end workflow" verbose=VERBOSE showtiming=SHOWTIMING begin + # Create a complete DOCP and verify the full workflow + ocp = DummyOCPDiscretized8("integration_test") + + # Track calls to verify builders are invoked correctly + adnlp_model_calls = Ref(0) + exa_model_calls = Ref(0) + adnlp_solution_calls = Ref(0) + exa_solution_calls = Ref(0) + + function adnlp_model_fn(x; kwargs...) + adnlp_model_calls[] += 1 + # Minimal ADNLPModel construction, similar to test_ctmodels_problem_core + f(z) = sum(z .^ 2) + return ADNLPModels.ADNLPModel(f, x) + end + + function exa_model_fn(::Type{T}, x; kwargs...) where {T} + exa_model_calls[] += 1 + return (:exa_model, T, x) + end + + function adnlp_solution_fn(stats) + adnlp_solution_calls[] += 1 + return (:adnlp_solution, stats) + end + + function exa_solution_fn(stats) + exa_solution_calls[] += 1 + return (:exa_solution, stats) + end + + # Create DOCP + docp = CTModels.DiscretizedOptimalControlProblem( + ocp, + CTModels.ADNLPModelBuilder(adnlp_model_fn), + CTModels.ExaModelBuilder(exa_model_fn), + CTModels.ADNLPSolutionBuilder(adnlp_solution_fn), + CTModels.ExaSolutionBuilder(exa_solution_fn), + ) + + # Verify OCP retrieval + Test.@test CTModels.ocp_model(docp).name == "integration_test" + + # Retrieve and use model builders + adnlp_builder = CTModels.get_adnlp_model_builder(docp) + exa_builder = CTModels.get_exa_model_builder(docp) + + # Calling the ADNLPModelBuilder should produce a valid ADNLPModels.ADNLPModel + nlp = adnlp_builder([1.0, 2.0]) + Test.@test nlp isa ADNLPModels.ADNLPModel + Test.@test adnlp_model_calls[] == 1 + + # For ExaModelBuilder, constructing a full ExaModels.ExaModel is non-trivial. + # As in test_ctmodels_problem_core, we limit ourselves to checking that the + # correct builder was retrieved and that its wrapped callable is exa_model_fn. + Test.@test exa_builder isa CTModels.ExaModelBuilder + Test.@test exa_builder.f === exa_model_fn + + # Retrieve and use solution builders + adnlp_sol_builder = CTModels.get_adnlp_solution_builder(docp) + exa_sol_builder = CTModels.get_exa_solution_builder(docp) + + stats = DummyStatsDiscretizedOCP3(:success) + + Test.@test adnlp_sol_builder(stats) == (:adnlp_solution, stats) + Test.@test adnlp_solution_calls[] == 1 + + Test.@test exa_sol_builder(stats) == (:exa_solution, stats) + Test.@test exa_solution_calls[] == 1 + end + + # ============================================================================ + # EDGE CASES + # ============================================================================ + + Test.@testset "solution builder that throws" verbose=VERBOSE showtiming=SHOWTIMING begin + # Test that errors in solution builders are propagated correctly + ocp = DummyOCPDiscretized9() + + function throwing_builder(stats) + error("Intentional error in solution builder") + end + + docp = CTModels.DiscretizedOptimalControlProblem( + ocp, + CTModels.ADNLPModelBuilder(x -> error("unused")), + CTModels.ExaModelBuilder((T, x; kwargs...) -> error("unused")), + CTModels.ADNLPSolutionBuilder(throwing_builder), + CTModels.ExaSolutionBuilder(s -> s), + ) + + builder = CTModels.get_adnlp_solution_builder(docp) + + stats = DummyStatsDiscretizedOCP4() + + # Verify the error is propagated + Test.@test_throws ErrorException builder(stats) + end + + Test.@testset "missing backend errors" verbose=VERBOSE showtiming=SHOWTIMING begin + ocp = DummyOCPDiscretized() + + # Construct a DOCP with only an :adnlp backend registered. + adnlp_model_builder = CTModels.ADNLPModelBuilder(x -> :ad_model) + adnlp_solution_builder = CTModels.ADNLPSolutionBuilder(s -> s) + adnlp_bundle = CTModels.OCPBackendBuilders( + adnlp_model_builder, adnlp_solution_builder + ) + + docp_ad_only = CTModels.DiscretizedOptimalControlProblem( + ocp, (:adnlp => adnlp_bundle,) + ) + + Test.@test_throws ArgumentError CTModels.get_exa_model_builder(docp_ad_only) + Test.@test_throws ArgumentError CTModels.get_exa_solution_builder(docp_ad_only) + + # Construct a DOCP with only an :exa backend registered. + exa_model_builder = CTModels.ExaModelBuilder((T, x; kwargs...) -> :exa_model) + exa_solution_builder = CTModels.ExaSolutionBuilder(s -> s) + exa_bundle = CTModels.OCPBackendBuilders(exa_model_builder, exa_solution_builder) + + docp_exa_only = CTModels.DiscretizedOptimalControlProblem( + ocp, (:exa => exa_bundle,) + ) + + Test.@test_throws ArgumentError CTModels.get_adnlp_model_builder(docp_exa_only) + Test.@test_throws ArgumentError CTModels.get_adnlp_solution_builder(docp_exa_only) + end + + Test.@testset "different OCP types" verbose=VERBOSE showtiming=SHOWTIMING begin + # Test that DOCP works with different concrete OCP types + # Create DOCPs with different OCP types + simple_ocp = SimpleOCPDiscretized(5) + complex_ocp = ComplexOCPDiscretized(10, 3, ["bound1", "bound2"]) + + adnlp_builder = CTModels.ADNLPModelBuilder(x -> :model) + exa_builder = CTModels.ExaModelBuilder((T, x; kwargs...) -> :model) + adnlp_sol_builder = CTModels.ADNLPSolutionBuilder(s -> s) + exa_sol_builder = CTModels.ExaSolutionBuilder(s -> s) + + docp_simple = CTModels.DiscretizedOptimalControlProblem( + simple_ocp, adnlp_builder, exa_builder, adnlp_sol_builder, exa_sol_builder + ) + + docp_complex = CTModels.DiscretizedOptimalControlProblem( + complex_ocp, adnlp_builder, exa_builder, adnlp_sol_builder, exa_sol_builder + ) + + # Verify both work correctly + Test.@test CTModels.ocp_model(docp_simple).dim == 5 + Test.@test CTModels.ocp_model(docp_complex).state_dim == 10 + Test.@test CTModels.ocp_model(docp_complex).control_dim == 3 + Test.@test length(CTModels.ocp_model(docp_complex).constraints) == 2 + end +end \ No newline at end of file diff --git a/test/nlp/test_model_api.jl b/test/nlp/test_model_api.jl new file mode 100644 index 00000000..63aa9534 --- /dev/null +++ b/test/nlp/test_model_api.jl @@ -0,0 +1,180 @@ +# Unit tests for the generic optimization model API (model and solution builders). +struct DummyProblemAPI <: CTModels.AbstractOptimizationProblem end + +struct DummyStatsAPI <: SolverCore.AbstractExecutionStats end + +struct DummySolutionAPI <: CTModels.AbstractSolution end + +struct FakeBackendAPI <: CTModels.AbstractOptimizationModeler + model_calls::Base.RefValue{Int} + solution_calls::Base.RefValue{Int} +end + +function (b::FakeBackendAPI)( + prob::CTModels.AbstractOptimizationProblem, initial_guess +)::NLPModels.AbstractNLPModel + b.model_calls[] += 1 + # Use a simple real ADNLPModel here so that we respect the declared + # return type ::NLPModels.AbstractNLPModel without defining custom + # subtypes of NLPModels internals. + f(z) = sum(z .^ 2) + return ADNLPModels.ADNLPModel(f, initial_guess) +end + +function (b::FakeBackendAPI)( + prob::CTModels.AbstractOptimizationProblem, + nlp_solution::SolverCore.AbstractExecutionStats, +) + b.solution_calls[] += 1 + return DummySolutionAPI() +end + +struct DummyOCPForModelAPI <: CTModels.AbstractModel end + +function make_dummy_docp_for_model_api() + ocp = DummyOCPForModelAPI() + adnlp_builder = CTModels.ADNLPModelBuilder( + (x; kwargs...) -> begin + f(z) = sum(z .^ 2) + # We deliberately ignore the extra keyword arguments such as + # show_time, backend, and AD backend options here. For this + # unit test we only need a valid ADNLPModel instance. + return ADNLPModels.ADNLPModel(f, x) + end + ) + exa_builder = CTModels.ExaModelBuilder((T, x; kwargs...) -> :exa_model_dummy) + adnlp_solution_builder = CTModels.ADNLPSolutionBuilder(s -> s) + exa_solution_builder = CTModels.ExaSolutionBuilder(s -> s) + return CTModels.DiscretizedOptimalControlProblem( + ocp, adnlp_builder, exa_builder, adnlp_solution_builder, exa_solution_builder + ) +end + +function test_model_api() + + # ======================================================================== + # Problems + # ======================================================================== + ros = Rosenbrock() + elec = Elec() + maxd = Max1MinusX2() + + # ------------------------------------------------------------------ + # Unit tests for build_model delegation + # ------------------------------------------------------------------ + Test.@testset "build_model delegation" verbose=VERBOSE showtiming=SHOWTIMING begin + prob = DummyProblemAPI() + x0 = [1.0, 2.0] + model_calls = Ref(0) + solution_calls = Ref(0) + backend = FakeBackendAPI(model_calls, solution_calls) + + nlp = CTModels.build_model(prob, x0, backend) + Test.@test nlp isa NLPModels.AbstractNLPModel + Test.@test model_calls[] == 1 + Test.@test solution_calls[] == 0 + end + + # ------------------------------------------------------------------ + # Unit tests for nlp_model(DiscretizedOptimalControlProblem, ...) + # ------------------------------------------------------------------ + Test.@testset "nlp_model(DiscretizedOptimalControlProblem, ...)" verbose=VERBOSE showtiming=SHOWTIMING begin + docp = make_dummy_docp_for_model_api() + x0 = [1.0, 2.0] + modeler = CTModels.ADNLPModeler() + + nlp = CTModels.nlp_model(docp, x0, modeler) + Test.@test nlp isa NLPModels.AbstractNLPModel + end + + # ------------------------------------------------------------------ + # Unit tests for build_solution(prob, stats, backend) delegation + # ------------------------------------------------------------------ + # Here we verify that build_solution(prob, nlp_solution, backend) + # calls the backend's (prob, nlp_solution) method and returns whatever + # the backend returns (here a DummySolutionAPI instance). + + Test.@testset "build_solution(prob, stats, backend)" verbose=VERBOSE showtiming=SHOWTIMING begin + prob = DummyProblemAPI() + stats = DummyStatsAPI() + model_calls = Ref(0) + solution_calls = Ref(0) + backend = FakeBackendAPI(model_calls, solution_calls) + + sol = CTModels.build_solution(prob, stats, backend) + Test.@test sol isa DummySolutionAPI + Test.@test model_calls[] == 0 + Test.@test solution_calls[] == 1 + end + + # ------------------------------------------------------------------ + # Unit tests for ocp_solution(DiscretizedOptimalControlProblem, ...) + # ------------------------------------------------------------------ + Test.@testset "ocp_solution(DiscretizedOptimalControlProblem, ...)" verbose=VERBOSE showtiming=SHOWTIMING begin + docp = make_dummy_docp_for_model_api() + stats = DummyStatsAPI() + model_calls = Ref(0) + solution_calls = Ref(0) + backend = FakeBackendAPI(model_calls, solution_calls) + + sol = CTModels.ocp_solution(docp, stats, backend) + Test.@test sol isa DummySolutionAPI + Test.@test model_calls[] == 0 + Test.@test solution_calls[] == 1 + end + + # ------------------------------------------------------------------ + # Integration-style tests for build_model on real problems + # ------------------------------------------------------------------ + Test.@testset "build_model on Rosenbrock, Elec, and Max1MinusX2" verbose=VERBOSE showtiming=SHOWTIMING begin + Test.@testset "Rosenbrock" verbose=VERBOSE showtiming=SHOWTIMING begin + modeler_ad = CTModels.ADNLPModeler() + nlp_ad = CTModels.build_model(ros.prob, ros.init, modeler_ad) + Test.@test nlp_ad isa ADNLPModels.ADNLPModel + Test.@test nlp_ad.meta.x0 == ros.init + Test.@test NLPModels.obj(nlp_ad, nlp_ad.meta.x0) == + rosenbrock_objective(ros.init) + Test.@test NLPModels.cons(nlp_ad, nlp_ad.meta.x0)[1] == + rosenbrock_constraint(ros.init) + Test.@test nlp_ad.meta.minimize == rosenbrock_is_minimize() + + modeler_exa = CTModels.ExaModeler() + nlp_exa = CTModels.build_model(ros.prob, ros.init, modeler_exa) + Test.@test nlp_exa isa ExaModels.ExaModel + end + + Test.@testset "Elec" verbose=VERBOSE showtiming=SHOWTIMING begin + modeler_ad = CTModels.ADNLPModeler() + nlp_ad = CTModels.build_model(elec.prob, elec.init, modeler_ad) + Test.@test nlp_ad isa ADNLPModels.ADNLPModel + Test.@test nlp_ad.meta.x0 == vcat(elec.init.x, elec.init.y, elec.init.z) + Test.@test NLPModels.obj(nlp_ad, nlp_ad.meta.x0) == + elec_objective(elec.init.x, elec.init.y, elec.init.z) + Test.@test NLPModels.cons(nlp_ad, nlp_ad.meta.x0) == + elec_constraint(elec.init.x, elec.init.y, elec.init.z) + Test.@test nlp_ad.meta.minimize == elec_is_minimize() + + BaseType = Float64 + modeler_exa = CTModels.ExaModeler(; base_type=BaseType) + nlp_exa = CTModels.build_model(elec.prob, elec.init, modeler_exa) + Test.@test nlp_exa isa ExaModels.ExaModel{BaseType} + end + + Test.@testset "Max1MinusX2" verbose=VERBOSE showtiming=SHOWTIMING begin + modeler_ad = CTModels.ADNLPModeler() + nlp_ad = CTModels.build_model(maxd.prob, maxd.init, modeler_ad) + Test.@test nlp_ad isa ADNLPModels.ADNLPModel + Test.@test nlp_ad.meta.x0 == maxd.init + Test.@test NLPModels.obj(nlp_ad, nlp_ad.meta.x0) == + max1minusx2_objective(maxd.init) + Test.@test NLPModels.cons(nlp_ad, nlp_ad.meta.x0)[1] == + max1minusx2_constraint(maxd.init) + Test.@test nlp_ad.meta.minimize == max1minusx2_is_minimize() + + BaseType = Float64 + modeler_exa = CTModels.ExaModeler(; base_type=BaseType) + nlp_exa = CTModels.build_model(maxd.prob, maxd.init, modeler_exa) + Test.@test nlp_exa isa ExaModels.ExaModel{BaseType} + end + end +end \ No newline at end of file diff --git a/test/nlp/test_nlp_backends.jl b/test/nlp/test_nlp_backends.jl new file mode 100644 index 00000000..450e5233 --- /dev/null +++ b/test/nlp/test_nlp_backends.jl @@ -0,0 +1,437 @@ +# Unit tests for NLP backends (ADNLPModels and ExaModels) used by CTModels problems. +struct CM_DummyBackendStats <: SolverCore.AbstractExecutionStats end + +struct CM_DummyModelerMissing <: CTModels.AbstractOptimizationModeler end + +function test_nlp_backends() + + # ======================================================================== + # Problems + # ======================================================================== + ros = Rosenbrock() + elec = Elec() + maxd = Max1MinusX2() + + # ------------------------------------------------------------------ + # Low-level defaults for ADNLPModeler / ExaModeler + # ------------------------------------------------------------------ + Test.@testset "raw defaults" verbose=VERBOSE showtiming=SHOWTIMING begin + # ADNLPModels defaults + Test.@test CTModels.__adnlp_model_show_time() isa Bool + Test.@test CTModels.__adnlp_model_backend() isa Symbol + + Test.@test CTModels.__adnlp_model_show_time() == false + Test.@test CTModels.__adnlp_model_backend() == :optimized + + # ExaModels defaults + Test.@test CTModels.__exa_model_base_type() isa DataType + Test.@test CTModels.__exa_model_backend() isa Union{Nothing,Symbol} + + Test.@test CTModels.__exa_model_base_type() === Float64 + Test.@test CTModels.__exa_model_backend() === nothing + end + + # ------------------------------------------------------------------ + # ADNLPModels backends (direct calls to ADNLPModeler) + # ------------------------------------------------------------------ + # These tests exercise the call + # (modeler::ADNLPModeler)(prob, initial_guess) + # directly, without going through the generic model API. We verify + # that the resulting ADNLPModel has the correct initial point, + # objective, constraints, and that the AD backends are configured as + # expected when using the manual backend path. + Test.@testset "ADNLPModels – Rosenbrock (direct call)" verbose=VERBOSE showtiming=SHOWTIMING begin + modeler = CTModels.ADNLPModeler() + nlp_adnlp = modeler(ros.prob, ros.init) + Test.@test nlp_adnlp isa ADNLPModels.ADNLPModel + Test.@test nlp_adnlp.meta.x0 == ros.init + Test.@test NLPModels.obj(nlp_adnlp, nlp_adnlp.meta.x0) == + rosenbrock_objective(ros.init) + Test.@test NLPModels.cons(nlp_adnlp, nlp_adnlp.meta.x0)[1] == + rosenbrock_constraint(ros.init) + Test.@test nlp_adnlp.meta.minimize == rosenbrock_is_minimize() + end + + # Different CTModels problem (Elec), + # still calling the backend directly. + Test.@testset "ADNLPModels – Elec (direct call)" verbose=VERBOSE showtiming=SHOWTIMING begin + modeler = CTModels.ADNLPModeler() + nlp_adnlp = modeler(elec.prob, elec.init) + Test.@test nlp_adnlp isa ADNLPModels.ADNLPModel + Test.@test nlp_adnlp.meta.x0 == vcat(elec.init.x, elec.init.y, elec.init.z) + Test.@test NLPModels.obj(nlp_adnlp, nlp_adnlp.meta.x0) == + elec_objective(elec.init.x, elec.init.y, elec.init.z) + Test.@test NLPModels.cons(nlp_adnlp, nlp_adnlp.meta.x0) == + elec_constraint(elec.init.x, elec.init.y, elec.init.z) + Test.@test nlp_adnlp.meta.minimize == elec_is_minimize() + end + + # 1D maximization problem: Max1MinusX2 + Test.@testset "ADNLPModels – Max1MinusX2 (direct call)" verbose=VERBOSE showtiming=SHOWTIMING begin + modeler = CTModels.ADNLPModeler() + nlp_adnlp = modeler(maxd.prob, maxd.init) + Test.@test nlp_adnlp isa ADNLPModels.ADNLPModel + Test.@test nlp_adnlp.meta.x0 == maxd.init + Test.@test NLPModels.obj(nlp_adnlp, nlp_adnlp.meta.x0) == + max1minusx2_objective(maxd.init) + Test.@test NLPModels.cons(nlp_adnlp, nlp_adnlp.meta.x0)[1] == + max1minusx2_constraint(maxd.init) + Test.@test nlp_adnlp.meta.minimize == max1minusx2_is_minimize() + end + + # For a problem without specialized get_* methods, ADNLPModeler + # should surface the generic NotImplemented error from get_adnlp_model_builder + # even when called directly. + Test.@testset "ADNLPModels – DummyProblem (NotImplemented, direct call)" verbose=VERBOSE showtiming=SHOWTIMING begin + modeler = CTModels.ADNLPModeler() + Test.@test_throws CTBase.NotImplemented modeler(DummyProblem(), ros.init) + end + + # ------------------------------------------------------------------ + # ExaModels backends (direct calls to ExaModeler, CPU) + # ------------------------------------------------------------------ + # These tests exercise the call + # (modeler::ExaModeler)(prob, initial_guess) + # directly, using a concrete BaseType (Float32). + Test.@testset "ExaModels (CPU) – Rosenbrock (BaseType=Float32, direct call)" verbose=VERBOSE showtiming=SHOWTIMING begin + BaseType = Float32 + modeler = CTModels.ExaModeler(; base_type=BaseType) + nlp_exa_cpu = modeler(ros.prob, ros.init) + Test.@test nlp_exa_cpu isa ExaModels.ExaModel{BaseType} + Test.@test nlp_exa_cpu.meta.x0 == BaseType.(ros.init) + Test.@test eltype(nlp_exa_cpu.meta.x0) == BaseType + Test.@test NLPModels.obj(nlp_exa_cpu, nlp_exa_cpu.meta.x0) == + rosenbrock_objective(BaseType.(ros.init)) + Test.@test NLPModels.cons(nlp_exa_cpu, nlp_exa_cpu.meta.x0)[1] == + rosenbrock_constraint(BaseType.(ros.init)) + Test.@test nlp_exa_cpu.meta.minimize == rosenbrock_is_minimize() + end + + # Same ExaModels backend but on the Elec problem, with direct backend call. + Test.@testset "ExaModels (CPU) – Elec (BaseType=Float32, direct call)" begin + BaseType = Float32 + modeler = CTModels.ExaModeler(; base_type=BaseType) + nlp_exa_cpu = modeler(elec.prob, elec.init) + Test.@test nlp_exa_cpu isa ExaModels.ExaModel{BaseType} + Test.@test nlp_exa_cpu.meta.x0 == + BaseType.(vcat(elec.init.x, elec.init.y, elec.init.z)) + Test.@test eltype(nlp_exa_cpu.meta.x0) == BaseType + Test.@test NLPModels.obj(nlp_exa_cpu, nlp_exa_cpu.meta.x0) == elec_objective( + BaseType.(elec.init.x), BaseType.(elec.init.y), BaseType.(elec.init.z) + ) + Test.@test NLPModels.cons(nlp_exa_cpu, nlp_exa_cpu.meta.x0) == elec_constraint( + BaseType.(elec.init.x), BaseType.(elec.init.y), BaseType.(elec.init.z) + ) + Test.@test nlp_exa_cpu.meta.minimize == elec_is_minimize() + end + + Test.@testset "ExaModels (CPU) – Max1MinusX2 (BaseType=Float32, direct call)" verbose=VERBOSE showtiming=SHOWTIMING begin + BaseType = Float32 + modeler = CTModels.ExaModeler(; base_type=BaseType) + nlp_exa_cpu = modeler(maxd.prob, maxd.init) + Test.@test nlp_exa_cpu isa ExaModels.ExaModel{BaseType} + Test.@test nlp_exa_cpu.meta.x0 == BaseType.(maxd.init) + Test.@test eltype(nlp_exa_cpu.meta.x0) == BaseType + Test.@test NLPModels.obj(nlp_exa_cpu, nlp_exa_cpu.meta.x0) == + max1minusx2_objective(BaseType.(maxd.init)) + Test.@test NLPModels.cons(nlp_exa_cpu, nlp_exa_cpu.meta.x0)[1] == + max1minusx2_constraint(BaseType.(maxd.init)) + Test.@test nlp_exa_cpu.meta.minimize == max1minusx2_is_minimize() + end + + # For a problem without specialized get_* methods, ExaModeler + # should surface the generic NotImplemented error from get_exa_model_builder + # even when called directly. + Test.@testset "ExaModels (CPU) – DummyProblem (NotImplemented, direct call)" verbose=VERBOSE showtiming=SHOWTIMING begin + modeler = CTModels.ExaModeler() + Test.@test_throws CTBase.NotImplemented modeler(DummyProblem(), ros.init) + end + + # ------------------------------------------------------------------ + # Constructor-level tests for ADNLPModeler and ExaModeler + # ------------------------------------------------------------------ + # These tests now focus on the options_values / options_sources + # NamedTuples exposed via _options / _option_sources. + + Test.@testset "ADNLPModeler constructor" verbose=VERBOSE showtiming=SHOWTIMING begin + # Default constructor should use the values from ctmodels/default.jl + backend_default = CTModels.ADNLPModeler() + vals_default = CTModels._options_values(backend_default) + srcs_default = CTModels._option_sources(backend_default) + + Test.@test vals_default.show_time == CTModels.__adnlp_model_show_time() + Test.@test vals_default.backend == CTModels.__adnlp_model_backend() + Test.@test all(srcs_default[k] == :ct_default for k in propertynames(srcs_default)) + + # Custom backend and extra kwargs should be stored with provenance + backend_manual = CTModels.ADNLPModeler(; backend=:toto, foo=1) + vals_manual = CTModels._options_values(backend_manual) + srcs_manual = CTModels._option_sources(backend_manual) + + Test.@test vals_manual.backend == :toto + Test.@test srcs_manual.backend == :user + Test.@test vals_manual.foo == 1 + Test.@test srcs_manual.foo == :user + end + + Test.@testset "ExaModeler constructor" verbose=VERBOSE showtiming=SHOWTIMING begin + # Default constructor should use backend from ctmodels/default.jl + exa_default = CTModels.ExaModeler() + vals_default = CTModels._options_values(exa_default) + srcs_default = CTModels._option_sources(exa_default) + + Test.@test vals_default.backend === CTModels.__exa_model_backend() + Test.@test srcs_default.backend == :ct_default + + # Custom base_type and kwargs: base_type is reflected in the modeler type, + # while remaining options and their provenance are tracked as usual. + exa_custom = CTModels.ExaModeler(; base_type=Float32) + vals_custom = CTModels._options_values(exa_custom) + srcs_custom = CTModels._option_sources(exa_custom) + + Test.@test exa_custom isa CTModels.ExaModeler{Float32} + Test.@test vals_custom.backend === CTModels.__exa_model_backend() + Test.@test srcs_custom.backend == :ct_default + + # Unknown options should now be rejected for ExaModeler (strict_keys=true). + err = nothing + try + CTModels.ExaModeler(; base_type=Float32, foo=2) + catch e + err = e + end + Test.@test err isa CTBase.IncorrectArgument + buf = sprint(showerror, err) + Test.@test occursin("Unknown option foo", buf) + Test.@test occursin("show_options(ExaModeler)", buf) + end + + # ------------------------------------------------------------------ + # Options metadata and validation helpers for ADNLPModeler/ExaModeler + # ------------------------------------------------------------------ + + Test.@testset "ADNLPModeler options metadata and validation" verbose=VERBOSE showtiming=SHOWTIMING begin + keys_ad = CTModels.options_keys(CTModels.ADNLPModeler) + Test.@test :show_time in keys_ad + Test.@test :backend in keys_ad + + ad_backend = CTModels.ADNLPModeler() + ad_type_from_instance = typeof(ad_backend) + + keys_ad_inst = CTModels.options_keys(ad_type_from_instance) + Test.@test Set(keys_ad_inst) == Set(keys_ad) + + Test.@test CTModels.option_type(:show_time, CTModels.ADNLPModeler) == Bool + Test.@test CTModels.option_type(:backend, CTModels.ADNLPModeler) == Symbol + + Test.@test CTModels.option_type(:show_time, ad_type_from_instance) == Bool + Test.@test CTModels.option_type(:backend, ad_type_from_instance) == Symbol + + desc_backend = CTModels.option_description(:backend, CTModels.ADNLPModeler) + Test.@test desc_backend isa AbstractString + Test.@test !isempty(desc_backend) + + desc_backend_inst = CTModels.option_description(:backend, ad_type_from_instance) + Test.@test desc_backend_inst isa AbstractString + Test.@test !isempty(desc_backend_inst) + + # Invalid type for a known option should trigger a CTBase.IncorrectArgument + Test.@test_throws CTBase.IncorrectArgument CTModels.ADNLPModeler(; show_time="yes") + end + + Test.@testset "ExaModeler options metadata and validation" verbose=VERBOSE showtiming=SHOWTIMING begin + keys_exa = CTModels.options_keys(CTModels.ExaModeler) + Test.@test :base_type in keys_exa + Test.@test :backend in keys_exa + Test.@test :minimize in keys_exa + + exa_backend = CTModels.ExaModeler() + exa_type_from_instance = typeof(exa_backend) + + keys_exa_inst = CTModels.options_keys(exa_type_from_instance) + Test.@test Set(keys_exa_inst) == Set(keys_exa) + + Test.@test CTModels.option_type(:base_type, CTModels.ExaModeler) <: + Type{<:AbstractFloat} + Test.@test CTModels.option_type(:minimize, CTModels.ExaModeler) == Bool + + Test.@test CTModels.option_type(:base_type, exa_type_from_instance) <: + Type{<:AbstractFloat} + Test.@test CTModels.option_type(:minimize, exa_type_from_instance) == Bool + + # Invalid type for a known option should trigger a CTBase.IncorrectArgument + Test.@test_throws CTBase.IncorrectArgument CTModels.ExaModeler(; minimize=1) + end + + Test.@testset "ExaModeler unknown option suggestions" verbose=VERBOSE showtiming=SHOWTIMING begin + err = nothing + try + CTModels._validate_option_kwargs( + (minimise=true,), CTModels.ExaModeler; strict_keys=true + ) + catch e + err = e + end + Test.@test err isa CTBase.IncorrectArgument + buf = sprint(showerror, err) + Test.@test occursin("Unknown option minimise", buf) + Test.@test occursin("minimize", buf) + Test.@test occursin("show_options(ExaModeler)", buf) + end + + Test.@testset "default_options and option_default" verbose=VERBOSE showtiming=SHOWTIMING begin + # ADNLPModeler defaults should be consistent between helpers and metadata. + opts_ad = CTModels.default_options(CTModels.ADNLPModeler) + Test.@test opts_ad.show_time == CTModels.__adnlp_model_show_time() + Test.@test opts_ad.backend == CTModels.__adnlp_model_backend() + + ad_backend = CTModels.ADNLPModeler() + ad_type_from_instance = typeof(ad_backend) + + opts_ad_inst = CTModels.default_options(ad_type_from_instance) + Test.@test opts_ad_inst == opts_ad + + Test.@test CTModels.option_default(:show_time, CTModels.ADNLPModeler) == + CTModels.__adnlp_model_show_time() + Test.@test CTModels.option_default(:backend, CTModels.ADNLPModeler) == + CTModels.__adnlp_model_backend() + + Test.@test CTModels.option_default(:show_time, ad_type_from_instance) == + CTModels.__adnlp_model_show_time() + Test.@test CTModels.option_default(:backend, ad_type_from_instance) == + CTModels.__adnlp_model_backend() + + # ExaModeler defaults: base_type and backend have defaults, minimize has none. + opts_exa = CTModels.default_options(CTModels.ExaModeler) + Test.@test opts_exa.base_type === CTModels.__exa_model_base_type() + Test.@test opts_exa.backend === CTModels.__exa_model_backend() + Test.@test :minimize ∉ propertynames(opts_exa) + + exa_backend = CTModels.ExaModeler() + exa_type_from_instance = typeof(exa_backend) + + opts_exa_inst = CTModels.default_options(exa_type_from_instance) + Test.@test opts_exa_inst == opts_exa + + Test.@test CTModels.option_default(:base_type, CTModels.ExaModeler) === + CTModels.__exa_model_base_type() + Test.@test CTModels.option_default(:backend, CTModels.ExaModeler) === + CTModels.__exa_model_backend() + Test.@test CTModels.option_default(:minimize, CTModels.ExaModeler) === missing + + Test.@test CTModels.option_default(:base_type, exa_type_from_instance) === + CTModels.__exa_model_base_type() + Test.@test CTModels.option_default(:backend, exa_type_from_instance) === + CTModels.__exa_model_backend() + Test.@test CTModels.option_default(:minimize, exa_type_from_instance) === missing + end + + Test.@testset "modeler symbols and registry" verbose=VERBOSE showtiming=SHOWTIMING begin + # get_symbol on types and instances + Test.@test CTModels.get_symbol(CTModels.ADNLPModeler) == :adnlp + Test.@test CTModels.get_symbol(CTModels.ExaModeler) == :exa + Test.@test CTModels.get_symbol(CTModels.ADNLPModeler()) == :adnlp + Test.@test CTModels.get_symbol(CTModels.ExaModeler()) == :exa + + # tool_package_name on types and instances + Test.@test CTModels.tool_package_name(CTModels.ADNLPModeler) == "ADNLPModels" + Test.@test CTModels.tool_package_name(CTModels.ExaModeler) == "ExaModels" + Test.@test CTModels.tool_package_name(CTModels.ADNLPModeler()) == "ADNLPModels" + Test.@test CTModels.tool_package_name(CTModels.ExaModeler()) == "ExaModels" + + regs = CTModels.registered_modeler_types() + Test.@test CTModels.ADNLPModeler in regs + Test.@test CTModels.ExaModeler in regs + + syms = CTModels.modeler_symbols() + Test.@test :adnlp in syms + Test.@test :exa in syms + + # build_modeler_from_symbol should construct proper concrete modelers. + m_ad = CTModels.build_modeler_from_symbol(:adnlp; backend=:manual) + Test.@test m_ad isa CTModels.ADNLPModeler + vals_ad = CTModels._options_values(m_ad) + Test.@test vals_ad.backend == :manual + + m_exa = CTModels.build_modeler_from_symbol(:exa; base_type=Float32) + Test.@test m_exa isa CTModels.ExaModeler{Float32} + end + + Test.@testset "build_modeler_from_symbol unknown symbol" verbose=VERBOSE showtiming=SHOWTIMING begin + err = nothing + try + CTModels.build_modeler_from_symbol(:foo) + catch e + err = e + end + Test.@test err isa CTBase.IncorrectArgument + + buf = sprint(showerror, err) + Test.@test occursin("Unknown NLP model symbol", buf) + Test.@test occursin("foo", buf) + # The message should list the supported symbols from modeler_symbols(). + for sym in CTModels.modeler_symbols() + Test.@test occursin(string(sym), buf) + end + end + + Test.@testset "tool_package_name default implementation" verbose=VERBOSE showtiming=SHOWTIMING begin + # For types without specialization, tool_package_name should return missing. + dummy = CM_DummyModelerMissing() + Test.@test CTModels.tool_package_name(CM_DummyModelerMissing) === missing + Test.@test CTModels.tool_package_name(dummy) === missing + end + + # ------------------------------------------------------------------ + # Solution-building via ADNLPModeler/ExaModeler(prob, nlp_solution) + # ------------------------------------------------------------------ + # For OptimizationProblem (defined in test/problems/problems_definition.jl), + # get_adnlp_solution_builder and get_exa_solution_builder return custom + # solution builders (ADNLPSolutionBuilder, ExaSolutionBuilder) that are + # callable on the nlp_solution and simply return it unchanged. Here we + # verify that the backends correctly route through those builders. + + Test.@testset "ADNLPModeler solution building" verbose=VERBOSE showtiming=SHOWTIMING begin + # Build an OptimizationProblem with dummy builders (unused in this test) + dummy_ad_builder = CTModels.ADNLPModelBuilder(x -> error("unused")) + function dummy_exa_builder_f(::Type{T}, x; kwargs...) where {T} + error("unused") + end + dummy_exa_builder = CTModels.ExaModelBuilder(dummy_exa_builder_f) + prob = OptimizationProblem( + dummy_ad_builder, + dummy_exa_builder, + ADNLPSolutionBuilder(), + ExaSolutionBuilder(), + ) + + stats = CM_DummyBackendStats() + modeler = CTModels.ADNLPModeler() + # Should call get_adnlp_solution_builder(prob) and then + # builder(stats), which is implemented in problems_definition.jl + # to return stats unchanged. + result = modeler(prob, stats) + Test.@test result === stats + end + + Test.@testset "ExaModeler solution building" verbose=VERBOSE showtiming=SHOWTIMING begin + dummy_ad_builder = CTModels.ADNLPModelBuilder(x -> error("unused")) + function dummy_exa_builder_f2(::Type{T}, x; kwargs...) where {T} + error("unused") + end + dummy_exa_builder = CTModels.ExaModelBuilder(dummy_exa_builder_f2) + prob = OptimizationProblem( + dummy_ad_builder, + dummy_exa_builder, + ADNLPSolutionBuilder(), + ExaSolutionBuilder(), + ) + + stats = CM_DummyBackendStats() + modeler = CTModels.ExaModeler() + # Should call get_exa_solution_builder(prob) and then + # builder(stats), which returns stats. + result = modeler(prob, stats) + Test.@test result === stats + end +end \ No newline at end of file diff --git a/test/nlp/test_options_schema.jl b/test/nlp/test_options_schema.jl new file mode 100644 index 00000000..d5a0b77d --- /dev/null +++ b/test/nlp/test_options_schema.jl @@ -0,0 +1,232 @@ +# Unit tests for generic options schema utilities (OptionSpec and helpers). + +# Dummy tool types for exercising the generic API +struct CM_DummyToolNoSpecs <: CTModels.AbstractOCPTool end + +struct CM_DummyToolWithSpecs <: CTModels.AbstractOCPTool + options_values + options_sources +end + +CTModels._option_specs(::Type{CM_DummyToolNoSpecs}) = missing + +function CTModels._option_specs(::Type{CM_DummyToolWithSpecs}) + ( + max_iter=CTModels.OptionSpec(; + type=Int, default=100, description="Max iterations" + ), + tol=CTModels.OptionSpec(; type=Float64, default=1e-6, description="Tolerance"), + verbose=CTModels.OptionSpec(; type=Bool, default=missing, description=missing), + ) +end + +function test_options_schema() + + # ======================================================================== + # METADATA ACCESSORS (options_keys, is_an_option_key, option_* helpers) + # ======================================================================== + + Test.@testset "metadata accessors" verbose=VERBOSE showtiming=SHOWTIMING begin + # No specs: options_keys / is_an_option_key / option_* should return missing + Test.@test CTModels.options_keys(CM_DummyToolNoSpecs) === missing + Test.@test CTModels.is_an_option_key(:foo, CM_DummyToolNoSpecs) === missing + Test.@test CTModels.option_type(:foo, CM_DummyToolNoSpecs) === missing + Test.@test CTModels.option_description(:foo, CM_DummyToolNoSpecs) === missing + Test.@test CTModels.option_default(:foo, CM_DummyToolNoSpecs) === missing + Test.@test CTModels.default_options(CM_DummyToolNoSpecs) == NamedTuple() + + # With specs + keys = CTModels.options_keys(CM_DummyToolWithSpecs) + Test.@test Set(keys) == Set((:max_iter, :tol, :verbose)) + + Test.@test CTModels.is_an_option_key(:max_iter, CM_DummyToolWithSpecs) + Test.@test !CTModels.is_an_option_key(:foo, CM_DummyToolWithSpecs) + + Test.@test CTModels.option_type(:max_iter, CM_DummyToolWithSpecs) == Int + Test.@test CTModels.option_type(:tol, CM_DummyToolWithSpecs) == Float64 + Test.@test CTModels.option_type(:foo, CM_DummyToolWithSpecs) === missing + + Test.@test CTModels.option_description(:max_iter, CM_DummyToolWithSpecs) isa + AbstractString + Test.@test CTModels.option_description(:verbose, CM_DummyToolWithSpecs) === missing + + Test.@test CTModels.option_default(:max_iter, CM_DummyToolWithSpecs) == 100 + Test.@test CTModels.option_default(:tol, CM_DummyToolWithSpecs) == 1e-6 + Test.@test CTModels.option_default(:verbose, CM_DummyToolWithSpecs) === missing + + # default_options should include only non-missing defaults + defs = CTModels.default_options(CM_DummyToolWithSpecs) + Test.@test Set(propertynames(defs)) == Set((:max_iter, :tol)) + Test.@test defs.max_iter == 100 + Test.@test defs.tol == 1e-6 + + # Instance-based accessors should behave like the type-based ones + vals_inst, srcs_inst = CTModels._build_ocp_tool_options(CM_DummyToolWithSpecs) + tool_inst = CM_DummyToolWithSpecs(vals_inst, srcs_inst) + + keys_from_type = CTModels.options_keys(CM_DummyToolWithSpecs) + keys_from_inst = CTModels.options_keys(tool_inst) + Test.@test Set(keys_from_inst) == Set(keys_from_type) + + defs_from_type = CTModels.default_options(CM_DummyToolWithSpecs) + defs_from_inst = CTModels.default_options(tool_inst) + Test.@test defs_from_inst == defs_from_type + + Test.@test CTModels.option_default(:max_iter, tool_inst) == 100 + Test.@test CTModels.option_default(:tol, tool_inst) == 1e-6 + Test.@test CTModels.option_default(:verbose, tool_inst) === missing + end + + # ======================================================================== + # _filter_options + # ======================================================================== + + Test.@testset "_filter_options" verbose=VERBOSE showtiming=SHOWTIMING begin + nt = (a=1, b=2, c=3) + filtered = CTModels._filter_options(nt, (:b,)) + Test.@test Set(propertynames(filtered)) == Set((:a, :c)) + Test.@test filtered.a == 1 + Test.@test filtered.c == 3 + end + + # ======================================================================== + # _string_distance and _suggest_option_keys + # ======================================================================== + + Test.@testset "suggestions" verbose=VERBOSE showtiming=SHOWTIMING begin + # A simple sanity check on the distance function + d_exact = CTModels._string_distance("max_iter", "max_iter") + d_close = CTModels._string_distance("max_iter", "mx_iter") + d_far = CTModels._string_distance("max_iter", "tol") + Test.@test d_exact == 0 + Test.@test d_close < d_far + + # Suggestions should prioritize the closest known key + sugg = CTModels._suggest_option_keys(:mx_iter, CM_DummyToolWithSpecs) + Test.@test length(sugg) >= 1 + Test.@test sugg[1] == :max_iter + end + + # ======================================================================== + # get_option_value / get_option_source / get_option_default + # ======================================================================== + + Test.@testset "get_option_*" verbose=VERBOSE showtiming=SHOWTIMING begin + # Build values/sources using the generic constructor + vals, srcs = CTModels._build_ocp_tool_options(CM_DummyToolWithSpecs; tol=1e-4) + tool = CM_DummyToolWithSpecs(vals, srcs) + + # Known options with and without user override + Test.@test CTModels.get_option_value(tool, :max_iter) == 100 + Test.@test CTModels.get_option_source(tool, :max_iter) == :ct_default + Test.@test CTModels.get_option_default(tool, :max_iter) == 100 + + Test.@test CTModels.get_option_value(tool, :tol) == 1e-4 + Test.@test CTModels.get_option_source(tool, :tol) == :user + Test.@test CTModels.get_option_default(tool, :tol) == 1e-6 + + # Known option declared but with no default and not set by the user + err_no_val = nothing + try + CTModels.get_option_value(tool, :verbose) + catch e + err_no_val = e + end + Test.@test err_no_val isa CTBase.IncorrectArgument + buf_no_val = sprint(showerror, err_no_val) + # Basic sanity: error message should be non-empty + Test.@test !isempty(buf_no_val) + + # Unknown option key should trigger an IncorrectArgument with suggestions + err_unknown = nothing + try + CTModels.get_option_value(tool, :mx_iter) + catch e + err_unknown = e + end + Test.@test err_unknown isa CTBase.IncorrectArgument + buf_unknown = sprint(showerror, err_unknown) + Test.@test occursin("Unknown option mx_iter", buf_unknown) + Test.@test occursin("max_iter", buf_unknown) + Test.@test occursin("show_options(CM_DummyToolWithSpecs)", buf_unknown) + end + + # ======================================================================== + # _show_options + # ======================================================================== + + Test.@testset "_show_options" verbose=VERBOSE showtiming=SHOWTIMING begin + # Just ensure that calling _show_options on both dummy tools does not throw, + # while silencing the printed output. + redirect_stdout(devnull) do + CTModels.show_options(CM_DummyToolNoSpecs) + CTModels.show_options(CM_DummyToolWithSpecs) + end + Test.@test true + end + + # ======================================================================== + # _validate_option_kwargs + # ======================================================================== + + Test.@testset "_validate_option_kwargs" verbose=VERBOSE showtiming=SHOWTIMING begin + # No specs: nothing should be validated or rejected + CTModels._validate_option_kwargs((foo=1,), CM_DummyToolNoSpecs; strict_keys=false) + + # Known keys with correct types + CTModels._validate_option_kwargs( + (max_iter=200, tol=1e-5), CM_DummyToolWithSpecs; strict_keys=false + ) + + # Unknown key with strict_keys = false should be accepted + CTModels._validate_option_kwargs( + (foo=1,), CM_DummyToolWithSpecs; strict_keys=false + ) + + # Unknown key with strict_keys = true should error with suggestions + err_unknown = nothing + try + CTModels._validate_option_kwargs( + (mx_iter=10,), CM_DummyToolWithSpecs; strict_keys=true + ) + catch e + err_unknown = e + end + Test.@test err_unknown isa CTBase.IncorrectArgument + buf = sprint(showerror, err_unknown) + Test.@test occursin("Unknown option mx_iter", buf) + Test.@test occursin("max_iter", buf) + Test.@test occursin("show_options(CM_DummyToolWithSpecs)", buf) + + # Wrong type for a known option should error + err_type = nothing + try + CTModels._validate_option_kwargs( + (tol="1e-6",), CM_DummyToolWithSpecs; strict_keys=false + ) + catch e + err_type = e + end + Test.@test err_type isa CTBase.IncorrectArgument + buf_type = sprint(showerror, err_type) + Test.@test occursin("Invalid type for option tol", buf_type) + end + + # ======================================================================== + # _build_ocp_tool_options + # ======================================================================== + + Test.@testset "_build_ocp_tool_options" verbose=VERBOSE showtiming=SHOWTIMING begin + # With specs: defaults merged with user overrides and provenance tracked + vals, srcs = CTModels._build_ocp_tool_options(CM_DummyToolWithSpecs; tol=1e-4) + Test.@test vals.max_iter == 100 + Test.@test vals.tol == 1e-4 + Test.@test srcs.max_iter == :ct_default + Test.@test srcs.tol == :user + + # Without specs: user kwargs should pass through unchanged and be marked as :user + vals2, srcs2 = CTModels._build_ocp_tool_options(CM_DummyToolNoSpecs; foo=1, bar=2) + Test.@test vals2 == (foo=1, bar=2) + Test.@test srcs2 == (foo=:user, bar=:user) + end +end \ No newline at end of file diff --git a/test/nlp/test_problem_core.jl b/test/nlp/test_problem_core.jl new file mode 100644 index 00000000..e627811e --- /dev/null +++ b/test/nlp/test_problem_core.jl @@ -0,0 +1,103 @@ +# Unit tests for CTModels problem-specific core builders (e.g. Rosenbrock). +function test_problem_core() + + # ======================================================================== + # Problems + # ======================================================================== + ros = Rosenbrock() + + # Tests for problem-specific model builders provided by CTModels problems + # (here the Rosenbrock problem exposes its own build_adnlp_model/build_exa_model). + Test.@testset "ADNLPModels – Rosenbrock (specific builder)" verbose=VERBOSE showtiming=SHOWTIMING begin + nlp_adnlp = ros.prob.build_adnlp_model(ros.init; show_time=false) + Test.@test nlp_adnlp isa ADNLPModels.ADNLPModel + Test.@test nlp_adnlp.meta.x0 == ros.init + Test.@test NLPModels.obj(nlp_adnlp, nlp_adnlp.meta.x0) == + rosenbrock_objective(ros.init) + Test.@test NLPModels.cons(nlp_adnlp, nlp_adnlp.meta.x0)[1] == + rosenbrock_constraint(ros.init) + Test.@test nlp_adnlp.meta.minimize == rosenbrock_is_minimize() + end + + Test.@testset "ExaModels (CPU) – Rosenbrock (specific builder, BaseType=Float32)" verbose=VERBOSE showtiming=SHOWTIMING begin + BaseType = Float32 + nlp_exa_cpu = ros.prob.build_exa_model(BaseType, ros.init) + Test.@test nlp_exa_cpu isa ExaModels.ExaModel{BaseType} + Test.@test nlp_exa_cpu.meta.x0 == BaseType.(ros.init) + Test.@test eltype(nlp_exa_cpu.meta.x0) == BaseType + Test.@test NLPModels.obj(nlp_exa_cpu, nlp_exa_cpu.meta.x0) == + rosenbrock_objective(BaseType.(ros.init)) + Test.@test NLPModels.cons(nlp_exa_cpu, nlp_exa_cpu.meta.x0)[1] == + rosenbrock_constraint(BaseType.(ros.init)) + Test.@test nlp_exa_cpu.meta.minimize == rosenbrock_is_minimize() + end + + # Tests for the generic ADNLPModelBuilder wrapper (higher-order function + # that delegates to an arbitrary callable). Here we build a simple + # ADNLPModel to respect the return type annotation ::ADNLPModels.ADNLPModel + # and we verify that the inner builder is called exactly once with the + # expected initial guess, and that keyword arguments are forwarded. + Test.@testset "ADNLPModelBuilder wrapper" verbose=VERBOSE showtiming=SHOWTIMING begin + calls = Ref(0) + last_x = Ref{Any}(nothing) + function local_ad_builder(x; kwargs...) + calls[] += 1 + last_x[] = x + f(z) = sum(z .^ 2) + return ADNLPModels.ADNLPModel(f, x) + end + + builder = CTModels.ADNLPModelBuilder(local_ad_builder) + x0 = ros.init + nlp = builder(x0) # no extra kwargs to keep ADNLPModel signature simple + + Test.@test nlp isa ADNLPModels.ADNLPModel + Test.@test calls[] == 1 + Test.@test last_x[] == x0 + + # Keyword arguments should be forwarded to the inner builder. + kw_calls = Ref(0) + seen_kwargs = Ref{Any}(nothing) + function local_ad_builder_kwargs(x; a=0, b=0) + kw_calls[] += 1 + seen_kwargs[] = (x, a, b) + f(z) = sum(z .^ 2) + return ADNLPModels.ADNLPModel(f, x) + end + + builder_kwargs = CTModels.ADNLPModelBuilder(local_ad_builder_kwargs) + x1 = ros.init + _ = builder_kwargs(x1; a=1, b=2) + + Test.@test kw_calls[] == 1 + Test.@test seen_kwargs[] == (x1, 1, 2) + end + + # Tests for the generic ExaModelBuilder wrapper. Constructing a full + # ExaModels.ExaModel instance in isolation is non-trivial, and the + # call operator is annotated to return ::ExaModels.ExaModel. To avoid + # fragile tests that depend on ExaModels internals, we limit ourselves + # to checking that the wrapped callable is correctly stored inside + # ExaModelBuilder. + Test.@testset "ExaModelBuilder wrapper" verbose=VERBOSE showtiming=SHOWTIMING begin + function local_exa_builder(::Type{BaseType}, x; foo=1) where {BaseType} + return (:exa_builder_called, BaseType, x, foo) + end + + builder = CTModels.ExaModelBuilder(local_exa_builder) + + Test.@test builder.f === local_exa_builder + Test.@test builder isa CTModels.ExaModelBuilder{typeof(local_exa_builder)} + end + + # Tests for the generic "NotImplemented" behaviour of the get_* functions + # when called on a problem type that has no specialized implementation. + Test.@testset "generic get_* NotImplemented" verbose=VERBOSE showtiming=SHOWTIMING begin + dummy = DummyProblem() + + Test.@test_throws CTBase.NotImplemented CTModels.get_adnlp_model_builder(dummy) + Test.@test_throws CTBase.NotImplemented CTModels.get_exa_model_builder(dummy) + Test.@test_throws CTBase.NotImplemented CTModels.get_adnlp_solution_builder(dummy) + Test.@test_throws CTBase.NotImplemented CTModels.get_exa_solution_builder(dummy) + end +end \ No newline at end of file diff --git a/test/test_constraints.jl b/test/ocp/test_constraints.jl similarity index 100% rename from test/test_constraints.jl rename to test/ocp/test_constraints.jl diff --git a/test/test_control.jl b/test/ocp/test_control.jl similarity index 88% rename from test/test_control.jl rename to test/ocp/test_control.jl index 62c8a2cf..f9a7d775 100644 --- a/test/test_control.jl +++ b/test/ocp/test_control.jl @@ -1,13 +1,8 @@ function test_control() # - @test isconcretetype(CTModels.ControlModel) # ControlModel - control = CTModels.ControlModel("u", ["u₁", "u₂"]) - @test CTModels.dimension(control) == 2 - @test CTModels.name(control) == "u" - @test CTModels.components(control) == ["u₁", "u₂"] # some checks ocp = CTModels.PreModel() diff --git a/test/ocp/test_definition.jl b/test/ocp/test_definition.jl new file mode 100644 index 00000000..43073559 --- /dev/null +++ b/test/ocp/test_definition.jl @@ -0,0 +1,53 @@ +function test_definition() + # TODO: add tests for src/ocp/definition.jl. + + # ======================================================================== + # Unit tests – setters/getters on PreModel and Model + # ======================================================================== + + Test.@testset "definition! and definition on PreModel" verbose=VERBOSE showtiming=SHOWTIMING begin + pre = CTModels.PreModel() + expr = :(x = 1) + + CTModels.definition!(pre, expr) + + Test.@test CTModels.definition(pre) === expr + end + + # ======================================================================== + # Integration-style tests – definition propagated through build + # ======================================================================== + + Test.@testset "definition carried to Model after build" verbose=VERBOSE showtiming=SHOWTIMING begin + pre = CTModels.PreModel() + + # Minimal consistent problem using the high-level API + CTModels.time!(pre; t0=0.0, tf=1.0) + CTModels.state!(pre, 1) + CTModels.control!(pre, 1) + CTModels.variable!(pre, 0) + + dyn!(r, t, x, u, v) = r .= 0 + CTModels.dynamics!(pre, dyn!) + + mayer(x0, xf, v) = 0.0 + lagrange(t, x, u, v) = 0.0 + CTModels.objective!(pre, :min; mayer=mayer, lagrange=lagrange) + + expr = quote + t ∈ [0, 1], time + x ∈ R, state + u ∈ R, control + ẋ(t) == u(t) + ∫(0.5u(t)^2) → min + end + + CTModels.definition!(pre, expr) + CTModels.time_dependence!(pre; autonomous=false) + + model = CTModels.build(pre) + + Test.@test CTModels.definition(model) === expr + end +end + diff --git a/test/ocp/test_dual_model.jl b/test/ocp/test_dual_model.jl new file mode 100644 index 00000000..423fecb7 --- /dev/null +++ b/test/ocp/test_dual_model.jl @@ -0,0 +1,30 @@ +function test_dual_model() + # TODO: add tests for src/ocp/dual_model.jl. + + # ======================================================================== + # Unit tests – low-level DualModel accessors + # ======================================================================== + + Test.@testset "DualModel constraint dual accessors" verbose=VERBOSE showtiming=SHOWTIMING begin + pc = t -> [1.0, 2.0] + bc = [3.0, 4.0] + sc_lb = t -> [0.0] + sc_ub = t -> [1.0] + cc_lb = t -> [0.0] + cc_ub = t -> [1.0] + vc_lb = [5.0] + vc_ub = [6.0] + + dual = CTModels.DualModel(pc, bc, sc_lb, sc_ub, cc_lb, cc_ub, vc_lb, vc_ub) + + Test.@test CTModels.path_constraints_dual(dual) === pc + Test.@test CTModels.boundary_constraints_dual(dual) === bc + Test.@test CTModels.state_constraints_lb_dual(dual) === sc_lb + Test.@test CTModels.state_constraints_ub_dual(dual) === sc_ub + Test.@test CTModels.control_constraints_lb_dual(dual) === cc_lb + Test.@test CTModels.control_constraints_ub_dual(dual) === cc_ub + Test.@test CTModels.variable_constraints_lb_dual(dual) === vc_lb + Test.@test CTModels.variable_constraints_ub_dual(dual) === vc_ub + end +end + diff --git a/test/test_dynamics.jl b/test/ocp/test_dynamics.jl similarity index 100% rename from test/test_dynamics.jl rename to test/ocp/test_dynamics.jl diff --git a/test/test_model.jl b/test/ocp/test_model.jl similarity index 97% rename from test/test_model.jl rename to test/ocp/test_model.jl index f84d8a73..9d8827d3 100644 --- a/test/test_model.jl +++ b/test/ocp/test_model.jl @@ -162,8 +162,9 @@ function test_model() @test CTModels.constraint(model, :control_scalar_2)[4] == 17 @test CTModels.constraint(model, :variable_scalar_2)[4] == 18 - # print the premodel - display(pre_ocp) + # print the premodel (captured, no terminal output) + io = IOBuffer() + show(io, MIME"text/plain"(), pre_ocp) # -------------------------------------------------------------------------- # # Just for printing @@ -177,7 +178,8 @@ function test_model() CTModels.objective!(pre_ocp, :min; mayer=mayer, lagrange=lagrange) CTModels.definition!(pre_ocp, quote end) CTModels.time_dependence!(pre_ocp; autonomous=false) - display(pre_ocp) + io = IOBuffer() + show(io, MIME"text/plain"(), pre_ocp) # pre_ocp = CTModels.PreModel() @@ -189,5 +191,6 @@ function test_model() CTModels.objective!(pre_ocp, :min; mayer=mayer, lagrange=lagrange) CTModels.definition!(pre_ocp, quote end) CTModels.time_dependence!(pre_ocp; autonomous=true) - display(pre_ocp) + io = IOBuffer() + show(io, MIME"text/plain"(), pre_ocp) end diff --git a/test/test_objective.jl b/test/ocp/test_objective.jl similarity index 100% rename from test/test_objective.jl rename to test/ocp/test_objective.jl diff --git a/test/test_ocp.jl b/test/ocp/test_ocp.jl similarity index 97% rename from test/test_ocp.jl rename to test/ocp/test_ocp.jl index 6b7fc902..d2bfd1c3 100644 --- a/test/test_ocp.jl +++ b/test/ocp/test_ocp.jl @@ -102,8 +102,9 @@ function test_ocp() build_examodel, ) - # print - display(ocp) + # print (captured, no terminal output) + io = IOBuffer() + show(io, MIME"text/plain"(), ocp) # tests on times @test CTModels.initial_time(ocp, [0.0, 10.0]) == 0.0 @@ -309,8 +310,9 @@ function test_ocp() build_examodel, ) - # print - display(ocp) + # print (captured, no terminal output) + io = IOBuffer() + show(io, MIME"text/plain"(), ocp) # tests on objective @test CTModels.objective(ocp) == objective @@ -368,7 +370,8 @@ function test_ocp() definition, build_examodel, ) - display(ocp) + io = IOBuffer() + show(io, MIME"text/plain"(), ocp) # times = CTModels.TimesModel( @@ -393,5 +396,6 @@ function test_ocp() definition, build_examodel, ) - display(ocp) + io = IOBuffer() + show(io, MIME"text/plain"(), ocp) end diff --git a/test/ocp/test_print.jl b/test/ocp/test_print.jl new file mode 100644 index 00000000..bd56e882 --- /dev/null +++ b/test/ocp/test_print.jl @@ -0,0 +1,81 @@ +function test_print() + # TODO: add tests for src/ocp/print.jl. + + # ======================================================================== + # Unit/integration tests – printing PreModel + # ======================================================================== + + Test.@testset "show(PreModel) prints abstract and mathematical definitions" verbose=VERBOSE showtiming=SHOWTIMING begin + pre = CTModels.PreModel() + + # Minimal consistent problem + CTModels.time!(pre; t0=0.0, tf=1.0) + CTModels.state!(pre, 1, "x", ["x"]) + CTModels.control!(pre, 1, "u", ["u"]) + CTModels.variable!(pre, 0) + + dyn!(r, t, x, u, v) = r .= 0 + CTModels.dynamics!(pre, dyn!) + + mayer(x0, xf, v) = 0.0 + lagrange(t, x, u, v) = 0.0 + CTModels.objective!(pre, :min; mayer=mayer, lagrange=lagrange) + + def_expr = quote + t ∈ [0, 1], time + x ∈ R, state + u ∈ R, control + ẋ(t) == u(t) + ∫(0.5u(t)^2) → min + end + CTModels.definition!(pre, def_expr) + CTModels.time_dependence!(pre; autonomous=false) + + io = IOBuffer() + show(io, MIME"text/plain"(), pre) + s = String(take!(io)) + + Test.@test occursin("Abstract definition:", s) + Test.@test occursin("optimal control problem is of the form:", s) + end + + # ======================================================================== + # Integration tests – printing Model + # ======================================================================== + + Test.@testset "show(Model) prints abstract and mathematical definitions" verbose=VERBOSE showtiming=SHOWTIMING begin + pre = CTModels.PreModel() + + CTModels.time!(pre; t0=0.0, tf=1.0) + CTModels.state!(pre, 1, "x", ["x"]) + CTModels.control!(pre, 1, "u", ["u"]) + CTModels.variable!(pre, 0) + + dyn!(r, t, x, u, v) = r .= 0 + CTModels.dynamics!(pre, dyn!) + + mayer(x0, xf, v) = 0.0 + lagrange(t, x, u, v) = 0.0 + CTModels.objective!(pre, :min; mayer=mayer, lagrange=lagrange) + + def_expr = quote + t ∈ [0, 1], time + x ∈ R, state + u ∈ R, control + ẋ(t) == u(t) + ∫(0.5u(t)^2) → min + end + CTModels.definition!(pre, def_expr) + CTModels.time_dependence!(pre; autonomous=false) + + model = CTModels.build(pre) + + io = IOBuffer() + show(io, MIME"text/plain"(), model) + s = String(take!(io)) + + Test.@test occursin("Abstract definition:", s) + Test.@test occursin("optimal control problem is of the form:", s) + end +end + diff --git a/test/test_solution.jl b/test/ocp/test_solution.jl similarity index 100% rename from test/test_solution.jl rename to test/ocp/test_solution.jl diff --git a/test/test_state.jl b/test/ocp/test_state.jl similarity index 89% rename from test/test_state.jl rename to test/ocp/test_state.jl index da68a3c5..1ec931f9 100644 --- a/test/test_state.jl +++ b/test/ocp/test_state.jl @@ -1,13 +1,8 @@ function test_state() # - @test isconcretetype(CTModels.StateModel) # StateModel - state = CTModels.StateModel("y", ["u", "v"]) - @test CTModels.dimension(state) == 2 - @test CTModels.name(state) == "y" - @test CTModels.components(state) == ["u", "v"] # some checks ocp = CTModels.PreModel() diff --git a/test/ocp/test_time_dependence.jl b/test/ocp/test_time_dependence.jl new file mode 100644 index 00000000..04b5d0b8 --- /dev/null +++ b/test/ocp/test_time_dependence.jl @@ -0,0 +1,54 @@ +function test_time_dependence() + # TODO: add tests for src/ocp/time_dependence.jl. + + # ======================================================================== + # Unit tests – time_dependence! and is_autonomous + # ======================================================================== + + Test.@testset "time_dependence! basic behavior" verbose=VERBOSE showtiming=SHOWTIMING begin + ocp = CTModels.PreModel() + + # Initially not set + Test.@test !CTModels.__is_autonomous_set(ocp) + + # Set once + CTModels.time_dependence!(ocp; autonomous=true) + Test.@test CTModels.__is_autonomous_set(ocp) + Test.@test CTModels.is_autonomous(ocp) === true + + # Second call must fail + Test.@test_throws CTBase.UnauthorizedCall CTModels.time_dependence!(ocp; autonomous=false) + end + + # ======================================================================== + # Integration-style tests – fake OCPs with different time dependence + # ======================================================================== + + Test.@testset "fake OCP time dependence flag" verbose=VERBOSE showtiming=SHOWTIMING begin + function build_premodel_with_time_dependence(flag::Bool) + ocp = CTModels.PreModel() + CTModels.time!(ocp; t0=0.0, tf=1.0) + CTModels.state!(ocp, 1) + CTModels.control!(ocp, 1) + CTModels.variable!(ocp, 0) + + dyn!(r, t, x, u, v) = r .= 0 + CTModels.dynamics!(ocp, dyn!) + + mayer(x0, xf, v) = 0.0 + lagrange(t, x, u, v) = 0.0 + CTModels.objective!(ocp, :min; mayer=mayer, lagrange=lagrange) + + CTModels.definition!(ocp, quote end) + CTModels.time_dependence!(ocp; autonomous=flag) + return ocp + end + + pre_autonomous = build_premodel_with_time_dependence(true) + pre_nonautonomous = build_premodel_with_time_dependence(false) + + Test.@test CTModels.is_autonomous(pre_autonomous) === true + Test.@test CTModels.is_autonomous(pre_nonautonomous) === false + end +end + diff --git a/test/test_times.jl b/test/ocp/test_times.jl similarity index 68% rename from test/test_times.jl rename to test/ocp/test_times.jl index bd38f1f3..77a88850 100644 --- a/test/test_times.jl +++ b/test/ocp/test_times.jl @@ -1,3 +1,10 @@ +struct FakeTimeVector{T} <: AbstractVector{T} + data::Vector{T} +end + +Base.length(v::FakeTimeVector) = length(v.data) +Base.getindex(v::FakeTimeVector{T}, i::Int) where {T} = v.data[i] + function test_times() # @@ -87,6 +94,40 @@ function test_times() CTModels.variable!(ocp, 2) @test_throws CTBase.IncorrectArgument CTModels.time!(ocp, t0=0.0, ind0=1) @test_throws CTBase.IncorrectArgument CTModels.time!(ocp, tf=10.0, indf=1) - @test_throws CTBase.IncorrectArgument CTModels.time!(ocp, t0=0.0, tf=10.0, ind0=1) @test_throws CTBase.IncorrectArgument CTModels.time!(ocp, t0=0.0, tf=10.0, indf=1) + + Test.@testset "times: FreeTimeModel with FakeTimeVector" verbose=VERBOSE showtiming=SHOWTIMING begin + ft = CTModels.FreeTimeModel(2, "s") + v_ok = FakeTimeVector([1.0, 3.0]) + @test CTModels.time(ft, v_ok) == 3.0 + + v_short = FakeTimeVector([1.0]) + @test_throws CTBase.IncorrectArgument CTModels.time(ft, v_short) + end + + Test.@testset "times: TimesModel names and flags" verbose=VERBOSE showtiming=SHOWTIMING begin + t0 = CTModels.FixedTimeModel(0.0, "t0") + tf = CTModels.FixedTimeModel(1.0, "tf") + times = CTModels.TimesModel(t0, tf, "t") + + @test CTModels.time_name(times) == "t" + @test CTModels.initial_time_name(times) == "t0" + @test CTModels.final_time_name(times) == "tf" + + @test CTModels.has_fixed_initial_time(times) + @test !CTModels.has_free_initial_time(times) + @test CTModels.has_fixed_final_time(times) + @test !CTModels.has_free_final_time(times) + + tf2 = CTModels.FixedTimeModel(2.0, "tf2") + t0_free = CTModels.FreeTimeModel(1, "v1") + times_free = CTModels.TimesModel(t0_free, tf2, "t") + v = [2.5] + + @test CTModels.initial_time(times_free, v) == 2.5 + @test !CTModels.has_fixed_initial_time(times_free) + @test CTModels.has_free_initial_time(times_free) + @test CTModels.has_fixed_final_time(times_free) + @test !CTModels.has_free_final_time(times_free) + end end diff --git a/test/test_variable.jl b/test/ocp/test_variable.jl similarity index 89% rename from test/test_variable.jl rename to test/ocp/test_variable.jl index 2ffa86f1..1ccfd505 100644 --- a/test/test_variable.jl +++ b/test/ocp/test_variable.jl @@ -1,13 +1,8 @@ function test_variable() # - @test isconcretetype(CTModels.VariableModel) # VariableModel - variable = CTModels.VariableModel("v", ["v₁", "v₂"]) - @test CTModels.dimension(variable) == 2 - @test CTModels.name(variable) == "v" - @test CTModels.components(variable) == ["v₁", "v₂"] # some checks ocp = CTModels.PreModel() diff --git a/test/plot/test_plot.jl b/test/plot/test_plot.jl new file mode 100644 index 00000000..6c5aa9a3 --- /dev/null +++ b/test/plot/test_plot.jl @@ -0,0 +1,525 @@ +using Plots + +struct FakeModelDoPlot{N} <: CTModels.AbstractModel end + +struct FakeSolutionDoPlot{N} <: CTModels.AbstractSolution + ocp::FakeModelDoPlot{N} + pcd +end + +CTModels.dim_path_constraints_nl(::FakeModelDoPlot{N}) where {N} = N +CTModels.model(sol::FakeSolutionDoPlot{N}) where {N} = sol.ocp +CTModels.path_constraints_dual(sol::FakeSolutionDoPlot) = sol.pcd +CTModels.state_dimension(::FakeSolutionDoPlot) = 2 +CTModels.control_dimension(::FakeSolutionDoPlot) = 1 + +function test_plot() + + # Resolve the plotting extension module to access internal helpers. + plots_ext = Base.get_extension(CTModels, :CTModelsPlots) + + # ======================================================================== + # Unit tests – helper logic (no plotting side effects) + # ======================================================================== + + Test.@testset "plot helpers: clean" verbose=VERBOSE showtiming=SHOWTIMING begin + description = (:states, :controls, :costates, :constraint, :cons, :duals, :state) + cleaned = plots_ext.clean(description) + Test.@test Set(cleaned) == Set((:state, :control, :costate, :path, :dual)) + end + + Test.@testset "plot helpers: do_plot" verbose=VERBOSE showtiming=SHOWTIMING begin + ocp, sol, pre_ocp = solution_example() + ocp_pc, sol_pc = solution_example_dual() + + # All descriptions enabled with non-:none styles + desc = (:state, :costate, :control, :path, :dual) + (ps, pc, pu, pp, pd) = plots_ext.do_plot( + sol, + desc...; + state_style=NamedTuple(), + control_style=NamedTuple(), + costate_style=NamedTuple(), + path_style=NamedTuple(), + dual_style=NamedTuple(), + ) + Test.@test ps + Test.@test pc + Test.@test pu + Test.@test pp + Test.@test !pd + + (ps2, pc2, pu2, pp2, pd2) = plots_ext.do_plot( + sol_pc, + desc...; + state_style=NamedTuple(), + control_style=NamedTuple(), + costate_style=NamedTuple(), + path_style=NamedTuple(), + dual_style=NamedTuple(), + ) + Test.@test ps2 + Test.@test pc2 + Test.@test pu2 + Test.@test pp2 + Test.@test pd2 + + # Styles set to :none disable corresponding components + (ps3, pc3, pu3, pp3, pd3) = plots_ext.do_plot( + sol, + :state, + :control, + :path, + :dual; + state_style=:none, + control_style=:none, + costate_style=:none, + path_style=:none, + dual_style=:none, + ) + Test.@test !ps3 + Test.@test !pu3 + Test.@test !pp3 + Test.@test !pd3 + + # Fakes: explicit combinations of path constraints and duals + desc2 = (:state, :costate, :control, :path, :dual) + + # no path constraints, no duals + fake1 = FakeSolutionDoPlot(FakeModelDoPlot{0}(), nothing) + (_, _, _, fp1, fd1) = plots_ext.do_plot( + fake1, + desc2...; + state_style=NamedTuple(), + control_style=NamedTuple(), + costate_style=NamedTuple(), + path_style=NamedTuple(), + dual_style=NamedTuple(), + ) + Test.@test !fp1 + Test.@test !fd1 + + # path constraints present, no duals + fake2 = FakeSolutionDoPlot(FakeModelDoPlot{2}(), nothing) + (_, _, _, fp2, fd2) = plots_ext.do_plot( + fake2, + desc2...; + state_style=NamedTuple(), + control_style=NamedTuple(), + costate_style=NamedTuple(), + path_style=NamedTuple(), + dual_style=NamedTuple(), + ) + Test.@test fp2 + Test.@test !fd2 + + # path constraints present, duals present + fake3 = FakeSolutionDoPlot(FakeModelDoPlot{3}(), (1.0,)) + (_, _, _, fp3, fd3) = plots_ext.do_plot( + fake3, + desc2...; + state_style=NamedTuple(), + control_style=NamedTuple(), + costate_style=NamedTuple(), + path_style=NamedTuple(), + dual_style=NamedTuple(), + ) + Test.@test fp3 + Test.@test fd3 + end + + Test.@testset "plot defaults: scalar helpers" verbose=VERBOSE showtiming=SHOWTIMING begin + Test.@test plots_ext.__plot_layout() == :split + Test.@test plots_ext.__control_layout() == :components + Test.@test plots_ext.__time_normalization() == :default + Test.@test plots_ext.__plot_style() == NamedTuple() + Test.@test plots_ext.__plot_label_suffix() == "" + end + + Test.@testset "plot defaults: __size_plot – layout=:group" verbose=VERBOSE showtiming=SHOWTIMING begin + fake = FakeSolutionDoPlot(FakeModelDoPlot{0}(), nothing) + desc = (:state, :costate, :control) + + sz_components = plots_ext.__size_plot( + fake, + CTModels.model(fake), + :components, + :group, + desc...; + state_style=NamedTuple(), + control_style=NamedTuple(), + costate_style=NamedTuple(), + path_style=NamedTuple(), + dual_style=NamedTuple(), + ) + Test.@test sz_components == (600, 280) + + sz_all = plots_ext.__size_plot( + fake, + CTModels.model(fake), + :all, + :group, + desc...; + state_style=NamedTuple(), + control_style=NamedTuple(), + costate_style=NamedTuple(), + path_style=NamedTuple(), + dual_style=NamedTuple(), + ) + Test.@test sz_all == (600, 420) + end + + Test.@testset "plot defaults: __size_plot – layout=:split" verbose=VERBOSE showtiming=SHOWTIMING begin + # Only state → 2 lines + fake_state = FakeSolutionDoPlot(FakeModelDoPlot{0}(), nothing) + sz_state = plots_ext.__size_plot( + fake_state, + CTModels.model(fake_state), + :components, + :split, + :state; + state_style=NamedTuple(), + control_style=:none, + costate_style=:none, + path_style=:none, + dual_style=:none, + ) + Test.@test sz_state == (600, 420) + + # Only control norm → 1 line + fake_control = FakeSolutionDoPlot(FakeModelDoPlot{0}(), nothing) + sz_control = plots_ext.__size_plot( + fake_control, + CTModels.model(fake_control), + :norm, + :split, + :control; + state_style=:none, + control_style=NamedTuple(), + costate_style=:none, + path_style=:none, + dual_style=:none, + ) + Test.@test sz_control == (600, 280) + + # State + control + path constraints (nc = 2) → nb_lines > 2 + fake_full = FakeSolutionDoPlot(FakeModelDoPlot{2}(), nothing) + sz_full = plots_ext.__size_plot( + fake_full, + CTModels.model(fake_full), + :components, + :split, + :state, + :control, + :path; + state_style=NamedTuple(), + control_style=NamedTuple(), + costate_style=:none, + path_style=NamedTuple(), + dual_style=:none, + ) + Test.@test sz_full == (600, 140 * 5) # 2 (state) + 1 (control) + 2 (path) + + # Invalid control keyword should throw + Test.@test_throws CTBase.IncorrectArgument plots_ext.__size_plot( + fake_state, + CTModels.model(fake_state), + :wrong_choice, + :split, + :state; + state_style=NamedTuple(), + control_style=NamedTuple(), + costate_style=:none, + path_style=:none, + dual_style=:none, + ) + end + + Test.@testset "plot tree: __plot_tree" verbose=VERBOSE showtiming=SHOWTIMING begin + # Single leaf → one subplot + leaf = plots_ext.PlotLeaf() + p_leaf = plots_ext.__plot_tree(leaf, 0) + Test.@test p_leaf isa Plots.Plot + Test.@test length(p_leaf.subplots) == 1 + + # Row layout with three leaves → three subplots + leaves_row = [plots_ext.PlotLeaf() for _ in 1:3] + node_row = plots_ext.PlotNode(:row, leaves_row) + p_row = plots_ext.__plot_tree(node_row) + Test.@test p_row isa Plots.Plot + Test.@test length(p_row.subplots) == 3 + + # Column layout with EmptyPlot filtered out → one subplot + children_col = [plots_ext.EmptyPlot(), plots_ext.PlotLeaf(), plots_ext.EmptyPlot()] + node_col = plots_ext.PlotNode(:column, children_col) + p_col = plots_ext.__plot_tree(node_col) + Test.@test p_col isa Plots.Plot + Test.@test length(p_col.subplots) == 1 + + # Nested nodes: a row of two columns (one with 2 leaves, one with 1) + col1 = plots_ext.PlotNode(:column, [plots_ext.PlotLeaf(), plots_ext.PlotLeaf()]) + col2 = plots_ext.PlotNode(:column, [plots_ext.PlotLeaf()]) + root = plots_ext.PlotNode(:row, [col1, col2]) + p_nested = plots_ext.__plot_tree(root) + Test.@test p_nested isa Plots.Plot + # At the top level we have at least two column blocks + Test.@test length(p_nested.subplots) ≥ 2 + end + + Test.@testset "plot helpers: do_decorate" verbose=VERBOSE showtiming=SHOWTIMING begin + ocp, sol, pre_ocp = solution_example() + + # No model → nothing is decorated regardless of styles + (dt, dsb, dcb, dpb) = plots_ext.do_decorate( + model=nothing, + time_style=NamedTuple(), + state_bounds_style=NamedTuple(), + control_bounds_style=NamedTuple(), + path_bounds_style=NamedTuple(), + ) + Test.@test !dt + Test.@test !dsb + Test.@test !dcb + Test.@test !dpb + + # With model and non-:none styles → all decorations active + (dt2, dsb2, dcb2, dpb2) = plots_ext.do_decorate( + model=ocp, + time_style=NamedTuple(), + state_bounds_style=NamedTuple(), + control_bounds_style=NamedTuple(), + path_bounds_style=NamedTuple(), + ) + Test.@test dt2 + Test.@test dsb2 + Test.@test dcb2 + Test.@test dpb2 + + # Individual :none styles disable specific decorations + (dt3, dsb3, dcb3, dpb3) = plots_ext.do_decorate( + model=ocp, + time_style=:none, + state_bounds_style=NamedTuple(), + control_bounds_style=:none, + path_bounds_style=NamedTuple(), + ) + Test.@test !dt3 + Test.@test dsb3 + Test.@test !dcb3 + Test.@test dpb3 + end + + Test.@testset "plot helpers: __keep_series_attributes" verbose=VERBOSE showtiming=SHOWTIMING begin + attrs = plots_ext.__keep_series_attributes( + color=:red, + linestyle=:dash, + foo=1, + ) + keys = [kv[1] for kv in attrs] + + # Unknown attributes should be filtered out + Test.@test :foo ∉ keys + + # All returned keys must be known Plots series attributes + series_attrs = Plots.attributes(:Series) + for k in keys + Test.@test k ∈ series_attrs + end + end + + # ======================================================================== + # Integration tests – solution_example (no path constraints) + # ======================================================================== + + ocp, sol, pre_ocp = solution_example() + + Test.@testset "plot(sol) – time keyword" verbose=VERBOSE showtiming=SHOWTIMING begin + Test.@test plot(sol; time=:default) isa Plots.Plot + Test.@test plot(sol; time=:normalize) isa Plots.Plot + Test.@test plot(sol; time=:normalise) isa Plots.Plot + Test.@test_throws CTBase.IncorrectArgument plot(sol; time=:wrong_choice) + end + + Test.@testset "plot(sol) – layout and control options" verbose=VERBOSE showtiming=SHOWTIMING begin + # group layout + Test.@test plot(sol; layout=:group, control=:components) isa Plots.Plot + Test.@test plot(sol; layout=:group, control=:norm) isa Plots.Plot + Test.@test plot(sol; layout=:group, control=:all) isa Plots.Plot + Test.@test_throws CTBase.IncorrectArgument plot( + sol; layout=:group, control=:wrong_choice + ) + + # split layout + Test.@test plot(sol; layout=:split, control=:components) isa Plots.Plot + Test.@test plot(sol; layout=:split, control=:norm) isa Plots.Plot + Test.@test plot(sol; layout=:split, control=:all) isa Plots.Plot + Test.@test_throws CTBase.IncorrectArgument plot( + sol; layout=:split, control=:wrong_choice + ) + + # layout only + Test.@test plot(sol; layout=:split) isa Plots.Plot + Test.@test plot(sol; layout=:group) isa Plots.Plot + Test.@test_throws CTBase.IncorrectArgument plot(sol; layout=:wrong_choice) + end + + Test.@testset "plot!(...) – reuse of plots and time keyword" verbose=VERBOSE showtiming=SHOWTIMING begin + # Start from plot(sol, time=...) + plt = plot(sol; time=:default) + Test.@test plot!(plt, sol; time=:default) isa Plots.Plot + Test.@test plot!(plt, sol; time=:normalize) isa Plots.Plot + Test.@test plot!(plt, sol; time=:normalise) isa Plots.Plot + Test.@test_throws CTBase.IncorrectArgument plot!( + plt, sol; time=:wrong_choice + ) + + # plot!(sol, ...) variants with implicit current plot + plot(sol; time=:default) + Test.@test plot!(sol; time=:default) isa Plots.Plot + Test.@test plot!(sol; time=:normalize) isa Plots.Plot + Test.@test plot!(sol; time=:normalise) isa Plots.Plot + Test.@test_throws CTBase.IncorrectArgument plot!(sol; time=:wrong_choice) + + # Start from an empty plot() + plt2 = plot() + Test.@test plot!(plt2, sol; time=:default) isa Plots.Plot + Test.@test plot!(plt2, sol; time=:normalize) isa Plots.Plot + Test.@test plot!(plt2, sol; time=:normalise) isa Plots.Plot + Test.@test_throws CTBase.IncorrectArgument plot!( + plt2, sol; time=:wrong_choice + ) + end + + Test.@testset "plot!(...) – layout and control options" verbose=VERBOSE showtiming=SHOWTIMING begin + # group layout + plt = plot(sol; layout=:group, control=:components) + Test.@test plot!(plt, sol; layout=:group, control=:components) isa Plots.Plot + Test.@test plot!(plt, sol; layout=:group, control=:norm) isa Plots.Plot + + plt = plot(sol; layout=:group, control=:norm) + Test.@test plot!(plt, sol; layout=:group, control=:components) isa Plots.Plot + Test.@test plot!(plt, sol; layout=:group, control=:norm) isa Plots.Plot + + plt = plot(sol; layout=:group, control=:all) + Test.@test plot!(plt, sol; layout=:group, control=:all) isa Plots.Plot + Test.@test_throws CTBase.IncorrectArgument plot!( + plt, sol; layout=:group, control=:wrong_choice + ) + + # split layout + plt = plot(sol; layout=:split, control=:components) + Test.@test plot!(plt, sol; layout=:split, control=:components) isa Plots.Plot + Test.@test plot!(plt, sol; layout=:split, control=:norm) isa Plots.Plot + + plt = plot(sol; layout=:split, control=:norm) + Test.@test plot!(plt, sol; layout=:split, control=:components) isa Plots.Plot + Test.@test plot!(plt, sol; layout=:split, control=:norm) isa Plots.Plot + + plt = plot(sol; layout=:split, control=:all) + Test.@test plot!(plt, sol; layout=:split, control=:all) isa Plots.Plot + Test.@test_throws CTBase.IncorrectArgument plot!( + plt, sol; layout=:split, control=:wrong_choice + ) + + # layout only + plt = plot(sol; layout=:split) + Test.@test plot!(plt, sol; layout=:split) isa Plots.Plot + + plt = plot(sol; layout=:group) + Test.@test plot!(plt, sol; layout=:group) isa Plots.Plot + Test.@test_throws CTBase.IncorrectArgument plot!( + plt, sol; layout=:wrong_choice + ) + end + + Test.@testset "display(sol) – side effect" verbose=VERBOSE showtiming=SHOWTIMING begin + Test.@test display(sol) isa Nothing + end + + # ======================================================================== + # Integration tests – solution_example_dual (with duals) + # ======================================================================== + + ocp_pc, sol_pc = solution_example_dual() + + Test.@testset "plot(sol with path constraints) – time and layout" verbose=VERBOSE showtiming=SHOWTIMING begin + # time keyword + Test.@test plot(sol_pc; time=:default) isa Plots.Plot + Test.@test plot(sol_pc; time=:normalize) isa Plots.Plot + Test.@test plot(sol_pc; time=:normalise) isa Plots.Plot + Test.@test_throws CTBase.IncorrectArgument plot( + sol_pc; time=:wrong_choice + ) + + # layout/control + Test.@test plot(sol_pc; layout=:group, control=:components) isa Plots.Plot + Test.@test plot(sol_pc; layout=:group, control=:norm) isa Plots.Plot + Test.@test plot(sol_pc; layout=:group, control=:all) isa Plots.Plot + Test.@test_throws CTBase.IncorrectArgument plot( + sol_pc; layout=:group, control=:wrong_choice + ) + + Test.@test plot(sol_pc; layout=:split, control=:components) isa Plots.Plot + Test.@test plot(sol_pc; layout=:split, control=:norm) isa Plots.Plot + Test.@test plot(sol_pc; layout=:split, control=:all) isa Plots.Plot + Test.@test_throws CTBase.IncorrectArgument plot( + sol_pc; layout=:split, control=:wrong_choice + ) + + Test.@test plot(sol_pc; layout=:split) isa Plots.Plot + Test.@test plot(sol_pc; layout=:group) isa Plots.Plot + Test.@test_throws CTBase.IncorrectArgument plot( + sol_pc; layout=:wrong_choice + ) + end + + Test.@testset "plot!(sol with path constraints) – layout and time" verbose=VERBOSE showtiming=SHOWTIMING begin + # time keyword + plt = plot(sol_pc; time=:default) + Test.@test plot!(plt, sol_pc; time=:default) isa Plots.Plot + Test.@test plot!(plt, sol_pc; time=:normalize) isa Plots.Plot + Test.@test plot!(plt, sol_pc; time=:normalise) isa Plots.Plot + Test.@test_throws CTBase.IncorrectArgument plot!( + plt, sol_pc; time=:wrong_choice + ) + + # layout/control + plt = plot(sol_pc; layout=:group, control=:components) + Test.@test plot!(plt, sol_pc; layout=:group, control=:components) isa Plots.Plot + Test.@test plot!(plt, sol_pc; layout=:group, control=:norm) isa Plots.Plot + + plt = plot(sol_pc; layout=:group, control=:norm) + Test.@test plot!(plt, sol_pc; layout=:group, control=:components) isa Plots.Plot + Test.@test plot!(plt, sol_pc; layout=:group, control=:norm) isa Plots.Plot + + plt = plot(sol_pc; layout=:group, control=:all) + Test.@test plot!(plt, sol_pc; layout=:group, control=:all) isa Plots.Plot + Test.@test_throws CTBase.IncorrectArgument plot!( + plt, sol_pc; layout=:group, control=:wrong_choice + ) + + plt = plot(sol_pc; layout=:split, control=:components) + Test.@test plot!(plt, sol_pc; layout=:split, control=:components) isa Plots.Plot + Test.@test plot!(plt, sol_pc; layout=:split, control=:norm) isa Plots.Plot + + plt = plot(sol_pc; layout=:split, control=:norm) + Test.@test plot!(plt, sol_pc; layout=:split, control=:components) isa Plots.Plot + Test.@test plot!(plt, sol_pc; layout=:split, control=:norm) isa Plots.Plot + + plt = plot(sol_pc; layout=:split, control=:all) + Test.@test plot!(plt, sol_pc; layout=:split, control=:all) isa Plots.Plot + Test.@test_throws CTBase.IncorrectArgument plot!( + plt, sol_pc; layout=:split, control=:wrong_choice + ) + + plt = plot(sol_pc; layout=:split) + Test.@test plot!(plt, sol_pc; layout=:split) isa Plots.Plot + + plt = plot(sol_pc; layout=:group) + Test.@test plot!(plt, sol_pc; layout=:group) isa Plots.Plot + Test.@test_throws CTBase.IncorrectArgument plot!( + plt, sol_pc; layout=:wrong_choice + ) + end +end + diff --git a/test/problems/beam.jl b/test/problems/beam.jl new file mode 100644 index 00000000..d35fb42c --- /dev/null +++ b/test/problems/beam.jl @@ -0,0 +1,65 @@ +# Beam optimal control problem definition used by tests and examples. +# +# Returns a NamedTuple with fields: +# - ocp :: the CTParser-defined optimal control problem +# - obj :: reference optimal objective value (Ipopt / MadNLP, Collocation) +# - name :: a short problem name +# - init :: NamedTuple of components for CTSolvers.initial_guess +function Beam() + + pre_ocp = CTModels.PreModel() + + CTModels.variable!(pre_ocp, 0) + + CTModels.time!(pre_ocp; t0=0.0, tf=1.0) + + CTModels.state!(pre_ocp, 2) + + CTModels.control!(pre_ocp, 1) + + dynamics!(r, t, x, u, v) = begin + r[1] = x[2] + r[2] = u[1] + return nothing + end + CTModels.dynamics!(pre_ocp, dynamics!) + + lagrange(t, x, u, v) = u[1]^2 + CTModels.objective!(pre_ocp, :min; lagrange=lagrange) + + f_boundary(r, x0, xf, v) = begin + r[1] = x0[1] - 0.0 + r[2] = x0[2] - 1.0 + r[3] = xf[1] - 0.0 + r[4] = xf[2] + 1.0 + return nothing + end + CTModels.constraint!(pre_ocp, :boundary; f=f_boundary, lb=zeros(4), ub=zeros(4), label=:beam_boundary) + + CTModels.constraint!(pre_ocp, :state; rg=1:1, lb=[0.0], ub=[0.1], label=:beam_state_x1) + CTModels.constraint!(pre_ocp, :control; rg=1:1, lb=[-10.0], ub=[10.0], label=:beam_control_u) + + definition = quote + t ∈ [0, 1], time + x ∈ R², state + u ∈ R, control + + x(0) == [0, 1] + x(1) == [0, -1] + 0 ≤ x₁(t) ≤ 0.1 + -10 ≤ u(t) ≤ 10 + + ẋ(t) == [x₂(t), u(t)] + + ∫(u(t)^2) → min + end + CTModels.definition!(pre_ocp, definition) + + CTModels.time_dependence!(pre_ocp; autonomous=true) + + ocp = CTModels.build(pre_ocp) + + init = (state=[0.05, 0.1], control=0.1) + + return (ocp=ocp, obj=8.898598, name="beam", init=init) +end \ No newline at end of file diff --git a/test/problems/elec.jl b/test/problems/elec.jl new file mode 100644 index 00000000..a26e93a3 --- /dev/null +++ b/test/problems/elec.jl @@ -0,0 +1,94 @@ +# Elec benchmark problem definition used by CTSolvers tests. +using Random + +function elec_objective(x, y, z, i, j) + 1.0 / sqrt((x[i] - x[j])^2 + (y[i] - y[j])^2 + (z[i] - z[j])^2) +end +elec_constraint(x, y, z, i) = x[i]^2 + y[i]^2 + z[i]^2 - 1.0 +function elec_objective(x, y, z) + np = length(x) + obj = 0.0 + for i in 1:(np - 1) + for j in (i + 1):np + obj += elec_objective(x, y, z, i, j) + end + end + return obj +end +function elec_constraint(x, y, z) + np = length(x) + return [elec_constraint(x, y, z, i) for i in 1:np] +end +elec_is_minimize() = true + +function Elec(; np::Int=5, seed::Int=2713) + # Set the starting point to a quasi-uniform distribution of electrons on a unit sphere + Random.seed!(seed) + + # Objective: minimize Coulomb potential + function F(vars) + x = vars[1:np] + y = vars[(np + 1):2np] + z = vars[(2np + 1):end] + return elec_objective(x, y, z) + end + + # Constraints: unit-ball constraint for each electron + function c(vars) + x = vars[1:np] + y = vars[(np + 1):2np] + z = vars[(2np + 1):end] + return elec_constraint(x, y, z) + end + + lcon = zeros(np) + ucon = zeros(np) + minimize = elec_is_minimize() + + # Define ADNLPModels builder + function build_adnlp_model(guess::NamedTuple; kwargs...)::ADNLPModels.ADNLPModel + # Convert tuple to flat vector for ADNLPModels + guess_vec = vcat(guess.x, guess.y, guess.z) + return ADNLPModels.ADNLPModel( + F, guess_vec, c, lcon, ucon; minimize=minimize, kwargs... + ) + end + + # Define ExaModels builder + function build_exa_model( + ::Type{BaseType}, guess::NamedTuple; kwargs... + )::ExaModels.ExaModel where {BaseType<:AbstractFloat} + m = ExaModels.ExaCore(BaseType; minimize=minimize, kwargs...) + + x = ExaModels.variable(m, 1:np; start=guess.x) + y = ExaModels.variable(m, 1:np; start=guess.y) + z = ExaModels.variable(m, 1:np; start=guess.z) + + # Coulomb potential objective + itr = [(i, j) for i in 1:(np - 1) for j in (i + 1):np] + ExaModels.objective(m, sum(elec_objective(x, y, z, i, j) for (i, j) in itr)) + + # Unit-ball constraints + ExaModels.constraint(m, elec_constraint(x, y, z, i) for i in 1:np) + + return ExaModels.ExaModel(m) + end + + prob = OptimizationProblem( + CTModels.ADNLPModelBuilder(build_adnlp_model), + CTModels.ExaModelBuilder(build_exa_model), + ADNLPSolutionBuilder(), + ExaSolutionBuilder(), + ) + + theta = (2π) .* rand(np) + phi = π .* rand(np) + x_init = [cos(theta[i]) * sin(phi[i]) for i in 1:np] + y_init = [sin(theta[i]) * sin(phi[i]) for i in 1:np] + z_init = [cos(phi[i]) for i in 1:np] + init = (x=x_init, y=y_init, z=z_init) + + sol = missing + + return (prob=prob, init=init, sol=sol) +end \ No newline at end of file diff --git a/test/problems/max1minusx2.jl b/test/problems/max1minusx2.jl new file mode 100644 index 00000000..62fc5ab1 --- /dev/null +++ b/test/problems/max1minusx2.jl @@ -0,0 +1,54 @@ +# Simple 1D maximization problem: max f(x) = 1 - x^2 + +function max1minusx2_objective(x) + return 1.0 - x[1]^2 +end + +function max1minusx2_constraint(x) + return x[1] +end + +function max1minusx2_is_minimize() + return false +end + +function Max1MinusX2() + # define common functions + F(x) = max1minusx2_objective(x) + c(x) = max1minusx2_constraint(x) # unconstrained problem are not working with MadNCL + lcon = [-5.0] + ucon = [5.0] + minimize = max1minusx2_is_minimize() + + # ADNLPModels builder: simple equality-constrained problem + function build_adnlp_model( + initial_guess::AbstractVector; kwargs... + )::ADNLPModels.ADNLPModel + return ADNLPModels.ADNLPModel( + F, initial_guess, c, lcon, ucon; minimize=minimize, kwargs... + ) + end + + # ExaModels builder: same equality constraint + function build_exa_model( + ::Type{BaseType}, initial_guess::AbstractVector; kwargs... + )::ExaModels.ExaModel where {BaseType<:AbstractFloat} + m = ExaModels.ExaCore(BaseType; minimize=minimize, kwargs...) + x = ExaModels.variable(m, length(initial_guess); start=initial_guess) + ExaModels.objective(m, F(x)) + ExaModels.constraint(m, c(x); lcon=lcon, ucon=ucon) + return ExaModels.ExaModel(m) + end + + prob = OptimizationProblem( + CTModels.ADNLPModelBuilder(build_adnlp_model), + CTModels.ExaModelBuilder(build_exa_model), + ADNLPSolutionBuilder(), + ExaSolutionBuilder(), + ) + + init = [2.0] + sol = [0.0] + + return (prob=prob, init=init, sol=sol) +end \ No newline at end of file diff --git a/test/problems/problems_definition.jl b/test/problems/problems_definition.jl new file mode 100644 index 00000000..cd687dca --- /dev/null +++ b/test/problems/problems_definition.jl @@ -0,0 +1,40 @@ +# Helper optimization problem and solution-builder types used by benchmark test problems. +# Helper types +abstract type AbstractNLPSolutionBuilder <: CTModels.AbstractSolutionBuilder end +struct ADNLPSolutionBuilder <: AbstractNLPSolutionBuilder end +struct ExaSolutionBuilder <: AbstractNLPSolutionBuilder end + +# +struct OptimizationProblem <: CTModels.AbstractOptimizationProblem + build_adnlp_model::CTModels.ADNLPModelBuilder + build_exa_model::CTModels.ExaModelBuilder + adnlp_solution_builder::ADNLPSolutionBuilder + exa_solution_builder::ExaSolutionBuilder +end + +function CTModels.get_adnlp_model_builder(prob::OptimizationProblem) + return prob.build_adnlp_model +end + +function CTModels.get_exa_model_builder(prob::OptimizationProblem) + return prob.build_exa_model +end + +function (builder::ADNLPSolutionBuilder)(nlp_solution::SolverCore.AbstractExecutionStats) + return nlp_solution +end + +function (builder::ExaSolutionBuilder)(nlp_solution::SolverCore.AbstractExecutionStats) + return nlp_solution +end + +function CTModels.get_adnlp_solution_builder(prob::OptimizationProblem) + return prob.adnlp_solution_builder +end + +function CTModels.get_exa_solution_builder(prob::OptimizationProblem) + return prob.exa_solution_builder +end + +# +struct DummyProblem <: CTModels.AbstractOptimizationProblem end \ No newline at end of file diff --git a/test/problems/rosenbrock.jl b/test/problems/rosenbrock.jl new file mode 100644 index 00000000..a6005a72 --- /dev/null +++ b/test/problems/rosenbrock.jl @@ -0,0 +1,50 @@ +# Rosenbrock benchmark problem definition used by CTSolvers tests. +function rosenbrock_objective(x) + return (x[1] - 1.0)^2 + 100*(x[2] - x[1]^2)^2 +end +function rosenbrock_constraint(x) + return x[1] +end +function rosenbrock_is_minimize() + return true +end + +function Rosenbrock() + # define common functions + F(x) = rosenbrock_objective(x) + c(x) = rosenbrock_constraint(x) + lcon = [-Inf] + ucon = [10.0] + minimize = rosenbrock_is_minimize() + + # define ADNLPModels builder + function build_adnlp_model( + initial_guess::AbstractVector; kwargs... + )::ADNLPModels.ADNLPModel + return ADNLPModels.ADNLPModel( + F, initial_guess, c, lcon, ucon; minimize=minimize, kwargs... + ) + end + + # define ExaModels builder + function build_exa_model( + ::Type{BaseType}, initial_guess::AbstractVector; kwargs... + )::ExaModels.ExaModel where {BaseType<:AbstractFloat} + m = ExaModels.ExaCore(BaseType; minimize=minimize, kwargs...) + x = ExaModels.variable(m, length(initial_guess); start=initial_guess) + ExaModels.objective(m, F(x)) + ExaModels.constraint(m, c(x); lcon=lcon, ucon=ucon) + return ExaModels.ExaModel(m) + end + + prob = OptimizationProblem( + CTModels.ADNLPModelBuilder(build_adnlp_model), + CTModels.ExaModelBuilder(build_exa_model), + ADNLPSolutionBuilder(), + ExaSolutionBuilder(), + ) + init = [-1.2; 1.0] + sol = [1.0; 1.0] + + return (prob=prob, init=init, sol=sol) +end \ No newline at end of file diff --git a/test/solution_example.jl b/test/problems/solution_example.jl similarity index 100% rename from test/solution_example.jl rename to test/problems/solution_example.jl diff --git a/test/problems/solution_example_dual.jl b/test/problems/solution_example_dual.jl new file mode 100644 index 00000000..eb5a1aec --- /dev/null +++ b/test/problems/solution_example_dual.jl @@ -0,0 +1,139 @@ +function solution_example_dual() + t0 = 0 + tf = 1 + x0 = -1 + + # the model (explicit CTModels.PreModel construction) + function OCP(t0, tf, x0) + + pre_ocp = CTModels.PreModel() + + # No variables + CTModels.variable!(pre_ocp, 0) + + # Time, state, control + CTModels.time!(pre_ocp; t0=t0, tf=tf) + CTModels.state!(pre_ocp, 1) + CTModels.control!(pre_ocp, 1) + + # Dynamics: ẋ(t) == u(t) + dynamics!(r, t, x, u, v) = begin + r[1] = u[1] + return nothing + end + CTModels.dynamics!(pre_ocp, dynamics!) + + # Objective: ∫(-u(t)) → min + lagrange(t, x, u, v) = -u[1] + CTModels.objective!(pre_ocp, :min; lagrange=lagrange) + + # Boundary constraint: x(t0) == x0 (label: initial_con) + f_initial(r, x0_state, xf, v) = begin + r[1] = x0_state[1] - x0 + return nothing + end + CTModels.constraint!( + pre_ocp, + :boundary; + f=f_initial, + lb=[0.0], + ub=[0.0], + label=:initial_con, + ) + + # Control box constraint: 0 ≤ u(t) ≤ +Inf (label: u_con) + CTModels.constraint!( + pre_ocp, + :control; + rg=1:1, + lb=[0.0], + ub=[Inf], + label=:u_con, + ) + + # Path constraint: -Inf ≤ x(t) + u(t) ≤ 0 + f_path1(r, t, x, u, v) = begin + r[1] = x[1] + u[1] + return nothing + end + CTModels.constraint!( + pre_ocp, + :path; + f=f_path1, + lb=[-Inf], + ub=[0.0], + ) + + # Path constraint: [-3, 1] ≤ [x(t)+1, u(t)+1] ≤ [1, 2.5] (label: 2) + f_path2(r, t, x, u, v) = begin + r[1] = x[1] + 1 + r[2] = u[1] + 1 + return nothing + end + CTModels.constraint!( + pre_ocp, + :path; + f=f_path2, + lb=[-3.0, 1.0], + ub=[1.0, 2.5], + label=:con2, + ) + + # Keep a DSL-style definition expression for printing only + definition = quote + t ∈ [t0, tf], time + x ∈ R, state + u ∈ R, control + x(t0) == x0, (initial_con) + 0 ≤ u(t) ≤ +Inf, (u_con) + -Inf ≤ x(t) + u(t) ≤ 0 + [-3, 1] ≤ [x(t) + 1, u(t) + 1] ≤ [1, 2.5], (2) + ẋ(t) == u(t) + ∫(-u(t)) → min + end + CTModels.definition!(pre_ocp, definition) + + # Non-autonomous (matches the original DSL semantics) + CTModels.time_dependence!(pre_ocp; autonomous=false) + + ocp = CTModels.build(pre_ocp) + return ocp + end + + # the solution + function SOL(ocp, t0, tf) + x(t) = -exp(-t) + p(t) = exp(t-1) - 1 + u(t) = -x(t) + objective = exp(-1) - 1 + v = Float64[] + + # + path_constraints_dual(t) = [-(p(t)+1), 0, t] + + # + times = range(t0, tf, 201) + sol = CTModels.build_solution( + ocp, + Vector{Float64}(times), + x, + u, + v, + p; + objective=objective, + iterations=-1, + constraints_violation=0.0, + message="", + status=:optimal, + successful=true, + path_constraints_dual=path_constraints_dual, + ) + + return sol + end + + ocp = OCP(t0, tf, x0) + sol = SOL(ocp, t0, tf) + + return ocp, sol +end diff --git a/test/runtests.jl b/test/runtests.jl index 96490220..220728ad 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,57 +2,215 @@ using Test using Aqua using CTBase using CTModels +using ADNLPModels +using SolverCore +using NLPModels +using ExaModels +using OrderedCollections: OrderedDict -# -include("solution_example.jl") +# Tests parameters +const VERBOSE = true +const SHOWTIMING = true # -@testset verbose = true showtiming = true "CTModels tests" begin - for name in ( - :ext_exceptions, - :aqua, - :times, - :control, - :state, - :variable, - :dynamics, - :objective, - :constraints, - :model, - :ocp, - :init, - :utils, - :solution, +include(joinpath("problems", "solution_example.jl")) +include(joinpath("problems", "problems_definition.jl")) +include(joinpath("problems", "rosenbrock.jl")) +include(joinpath("problems", "max1minusx2.jl")) +include(joinpath("problems", "elec.jl")) +include(joinpath("problems", "beam.jl")) +include(joinpath("problems", "solution_example_dual.jl")) + +# ---------------------------------------------------------------------------# +# Test selection infrastructure (aligned with CTSolvers) +# ---------------------------------------------------------------------------# + +function default_tests() + return OrderedDict( + # Extension exceptions, before any extensions are triggered + :notrigger => OrderedDict( + :ext_exceptions => true, + ), + + # Meta / quality tests + :meta => OrderedDict( + :aqua => true, + :CTModels => true, + ), + + # Tests in test/ocp + :ocp => OrderedDict( + :times => true, + :time_dependence => true, + :state => true, + :control => true, + :variable => true, + :dynamics => true, + :objective => true, + :constraints => true, + :definition => true, + :model => true, + :ocp => true, + :dual_model => true, + :print => true, + :solution => true, + ), + + # Core utilities and type-level tests in test/core + :core => OrderedDict( + :utils => true, + :default => true, + :types => true, + :ocp_components => true, + :ocp_model_types => true, + :ocp_solution_types => true, + :nlp_types => true, + :initial_guess_types => true, + ), + + # Tests in test/nlp + :nlp => OrderedDict( + :problem_core => true, + :options_schema => true, + :nlp_backends => true, + :discretized_ocp => true, + :model_api => true, + ), + + # Tests in test/init + :init => OrderedDict( + :initial_guess => true, + ), + + # IO-related tests in test/io + :io => OrderedDict( + :export_import => true, + ), + + # Plot-related tests in test/plot + :plot => OrderedDict( + :plot => true, + ), ) - @testset "$(name)" begin - test_name = Symbol(:test_, name) - println("testing: ", string(name)) - include("$(test_name).jl") - @eval $test_name() +end + +const TEST_SELECTIONS = isempty(ARGS) ? Symbol[] : Symbol.(ARGS) + +const TEST_GROUP_INFO = Dict( + :notrigger => (title="Extension exceptions", subdir="io"), + :meta => (title="Meta / quality", subdir="meta"), + :ocp => (title="OCP continuous-time layer", subdir="ocp"), + :core => (title="Core utilities and types", subdir="core"), + :nlp => (title="NLP / backends / discretized OCP", subdir="nlp"), + :init => (title="Initial guess", subdir="init"), + :io => (title="IO / export / import", subdir="io"), + :plot => (title="Plotting", subdir="plot"), +) + +function selected_tests() + tests = default_tests() + sels = TEST_SELECTIONS + + # No selection: default configuration + if isempty(sels) + return tests + end + + # Single :all selection: enable everything + if length(sels) == 1 && sels[1] == :all + for (_, group_tests) in tests + for k in keys(group_tests) + group_tests[k] = true + end end + return tests end -end -# test with CTDirect and CTParser: must be commented if new version of CTModels, that is breaking + # Otherwise start with everything disabled + for (_, group_tests) in tests + for k in keys(group_tests) + group_tests[k] = false + end + end -using CTDirect -using NLPModelsIpopt -using ADNLPModels -import CTParser: CTParser, @def + # Apply each selector + for sel in sels + # :all mixed with others -> just enable everything and stop + if sel == :all + for (_, group_tests) in tests + for k in keys(group_tests) + group_tests[k] = true + end + end + break + end -# -include("solution_example_path_constraints.jl") + # sel = group key (e.g. :meta, :ocp, :nlp, :io, :plot, ...) + if haskey(tests, sel) + for k in keys(tests[sel]) + tests[sel][k] = true + end + continue + end -@testset verbose = true showtiming = true "CTModels tests" begin - for name in ( - :plot, - # :export_import, - ) - @testset "$(name)" begin - test_name = Symbol(:test_, name) - println("testing: ", string(name)) - include("$(test_name).jl") - @eval $test_name() + # sel = leaf key (e.g. :times, :nlp_backends, :plot, ...) + for (_, group_tests) in tests + if haskey(group_tests, sel) + group_tests[sel] = true + break + end + end + end + + return tests +end + +const SELECTED_TESTS = selected_tests() + +function run_test_group(group::Symbol, tests::OrderedDict{Symbol,Bool}) + any(values(tests)) || return nothing + info = TEST_GROUP_INFO[group] + title = info.title + subdir = info.subdir + println("========== $(title) tests ==========") + @testset "$(title)" verbose=VERBOSE showtiming=SHOWTIMING begin + for (name, enabled) in tests + enabled || continue + @testset "$(name)" verbose=VERBOSE showtiming=SHOWTIMING begin + test_name = Symbol(:test_, name) + println("testing: ", string(name)) + include(joinpath(subdir, string(test_name, ".jl"))) + @eval $test_name() + end end end + println("✓ $(title) tests passed\n") +end + +for (group, tests) in SELECTED_TESTS + run_test_group(group, tests) end + +# test with CTDirect and CTParser: must be commented if new version of CTModels, that is breaking + +# using CTDirect +# using NLPModelsIpopt +# using ADNLPModels +# import CTParser: CTParser, @def + +# # +# include(joinpath("problems", "solution_example_dual.jl")) + +# @testset verbose=VERBOSE showtiming=SHOWTIMING "CTModels tests" begin +# for name in ( +# :plot, +# # :export_import, +# ) +# @testset "$(name)" begin +# test_name = Symbol(:test_, name) +# println("testing: ", string(name)) +# include(_testfile_path(name)) +# @eval $test_name() +# end +# end +# end diff --git a/test/solution_example_path_constraints.jl b/test/solution_example_path_constraints.jl deleted file mode 100644 index 7ae71cdc..00000000 --- a/test/solution_example_path_constraints.jl +++ /dev/null @@ -1,59 +0,0 @@ -function solution_example_path_constraints() - t0 = 0 - tf = 1 - x0 = -1 - - # the model - function OCP(t0, tf, x0) - @def ocp begin - t ∈ [t0, tf], time - x ∈ R, state - u ∈ R, control - x(t0) == x0, (initial_con) - 0 ≤ u(t) ≤ +Inf, (u_con) - -Inf ≤ x(t) + u(t) ≤ 0 - [-3, 1] ≤ [x(t) + 1, u(t) + 1] ≤ [1, 2.5], (2) - ẋ(t) == u(t) - ∫(-u(t)) → min - end; - - return ocp - end - - # the solution - function SOL(ocp, t0, tf) - x(t) = -exp(-t) - p(t) = exp(t-1) - 1 - u(t) = -x(t) - objective = exp(-1) - 1 - v = Float64[] - - # - path_constraints_dual(t) = [-(p(t)+1), 0, t] - - # - times = range(t0, tf, 201) - sol = CTModels.build_solution( - ocp, - Vector{Float64}(times), - x, - u, - v, - p; - objective=objective, - iterations=-1, - constraints_violation=0.0, - message="", - status=:optimal, - successful=true, - path_constraints_dual=path_constraints_dual, - ) - - return sol - end - - ocp = OCP(t0, tf, x0) - sol = SOL(ocp, t0, tf) - - return ocp, sol -end diff --git a/test/test_export_import.jl b/test/test_export_import.jl deleted file mode 100644 index ec8cc5b6..00000000 --- a/test/test_export_import.jl +++ /dev/null @@ -1,70 +0,0 @@ -using JLD2 -using JSON3 - -function remove_if_exists(filename::String) - isfile(filename) && rm(filename) -end - -function test_export_import() - - # - ocp, sol = solution_example() - - # JSON - CTModels.export_ocp_solution(sol; filename="solution_test", format=:JSON) - sol_reloaded = CTModels.import_ocp_solution(ocp; filename="solution_test", format=:JSON) - @test sol.objective ≈ sol_reloaded.objective atol=1e-8 - @test CTModels.objective(sol) ≈ CTModels.objective(sol_reloaded) atol=1e-8 - - # JLD - CTModels.export_ocp_solution(sol; filename="solution_test") # default is :JLD) - sol_reloaded = CTModels.import_ocp_solution(ocp; filename="solution_test", format=:JLD) - @test sol.objective ≈ sol_reloaded.objective atol=1e-8 - @test CTModels.objective(sol) ≈ CTModels.objective(sol_reloaded) atol=1e-8 - - # - ocp, sol = solution_example(; fun=true) - - # JSON - CTModels.export_ocp_solution(sol; filename="solution_test", format=:JSON) - sol_reloaded = CTModels.import_ocp_solution(ocp; filename="solution_test", format=:JSON) - @test sol.objective ≈ sol_reloaded.objective atol=1e-8 - @test CTModels.objective(sol) ≈ CTModels.objective(sol_reloaded) atol=1e-8 - - # JLD - CTModels.export_ocp_solution(sol; filename="solution_test", format=:JLD) - sol_reloaded = CTModels.import_ocp_solution(ocp; filename="solution_test", format=:JLD) - @test sol.objective ≈ sol_reloaded.objective atol=1e-8 - @test CTModels.objective(sol) ≈ CTModels.objective(sol_reloaded) atol=1e-8 - - # -------------------------------------------------------------------------------------- - # Other problem - ocp = @def begin - t ∈ [0, 1], time - x ∈ R², state - u ∈ R, control - x₂(t) ≤ 1.2 - x(0) == [-1, 0] - x(1) == [0, 0] - ẋ(t) == [x₂(t), u(t)] - ∫(0.5u(t)^2) → min - end; - - sol = CTDirect.solve(ocp) - - # JLD - CTModels.export_ocp_solution(sol; filename="solution_test") - sol_reloaded = CTModels.import_ocp_solution(ocp; filename="solution_test") - @test sol.objective ≈ sol_reloaded.objective atol=1e-8 - @test CTModels.objective(sol) ≈ CTModels.objective(sol_reloaded) atol=1e-8 - - # JSON - CTModels.export_ocp_solution(sol; filename="solution_test", format=:JSON) - sol_reloaded = CTModels.import_ocp_solution(ocp; filename="solution_test", format=:JSON) - @test sol.objective ≈ sol_reloaded.objective atol=1e-8 - @test CTModels.objective(sol) ≈ CTModels.objective(sol_reloaded) atol=1e-8 - - # Cleanup - remove_if_exists("solution_test.jld2") - remove_if_exists("solution_test.json") -end diff --git a/test/test_init.jl b/test/test_init.jl deleted file mode 100644 index 7cfe5ad9..00000000 --- a/test/test_init.jl +++ /dev/null @@ -1,84 +0,0 @@ -function test_init() - - # test checkDim function - @test_throws Exception CTModels.checkDim(1, 2) - - # test isaVectVect function - # Return true if argument is a vector of vectors - @test CTModels.isaVectVect([[1, 2], [3, 4]]) - @test !CTModels.isaVectVect([1, 2, 3, 4]) - - # test formatData function - # Convert matrix to vector of vectors (could be expanded) - @test CTModels.formatData([[1, 2], [3, 4]]) == [[1, 2], [3, 4]] - @test CTModels.formatData([1, 2, 3, 4]) == [1, 2, 3, 4] - - # test formatTimeGrid function - # Convert matrix time-grid to vector - @test CTModels.formatTimeGrid([1, 2, 3, 4]) == [1, 2, 3, 4] - @test CTModels.formatTimeGrid(nothing) === nothing - @test CTModels.formatTimeGrid([[1, 2]; [3, 4]]) == [1, 2, 3, 4] - @test CTModels.formatTimeGrid([[1, 2], [3, 4]]) == [[1, 2], [3, 4]] - - # test buildFunctionalInit function - # Build functional initialization: default case - @test CTModels.buildFunctionalInit(nothing, range(0, 1, 11), 2)(0) === nothing - - # Build functional initialization: function case - @test CTModels.buildFunctionalInit(t -> [t, t^2], range(0, 1, 11), 2)(0) == [0, 0] - @test_throws Exception CTModels.buildFunctionalInit(t -> [t, t^2], range(0, 1, 11), 1)( - 0 - ) - - # test buildFunctionalInit function: general interpolation case - # Build functional initialization: general interpolation case - @test CTModels.buildFunctionalInit( - [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10], range(0, 1, 11), 1 - )( - 0 - ) == 0 - - # construction of Init - - # constant initial guess - x_const = [0.5, 0.2] - u_const = 0.5 - v_const = 0.15 - - # functional initial guess - x_func = t -> [t^2, sqrt(t)] - u_func = t -> (cos(10 * t) + 1) * 0.5 - - # interpolated initial guess - x_vec = [[0, 0], [1, 2], [5, -1]] - x_matrix = [0 0; 1 2; 5 -1] - u_vec = [0, 0.3, 0.1] - t_vec = [0, 0.1, 0.2] - - init = (state=x_const,) - @test CTModels.Init(init; state_dim=2, control_dim=1, variable_dim=1) isa CTModels.Init - - init = (state=x_const, control=u_const) - @test CTModels.Init(init; state_dim=2, control_dim=1, variable_dim=1) isa CTModels.Init - - init = (state=x_const, control=u_const, variable=v_const) - @test CTModels.Init(init; state_dim=2, control_dim=1, variable_dim=1) isa CTModels.Init - - init = (state=x_func,) - @test CTModels.Init(init; state_dim=2, control_dim=1, variable_dim=1) isa CTModels.Init - - init = (state=x_func, control=u_func) - @test CTModels.Init(init; state_dim=2, control_dim=1, variable_dim=1) isa CTModels.Init - - init = (state=x_func, control=u_func, variable=v_const) - @test CTModels.Init(init; state_dim=2, control_dim=1, variable_dim=1) isa CTModels.Init - - init = (time=t_vec, state=x_vec) - @test CTModels.Init(init; state_dim=2, control_dim=1, variable_dim=1) isa CTModels.Init - - init = (time=t_vec, state=x_vec, control=u_vec) - @test CTModels.Init(init; state_dim=2, control_dim=1, variable_dim=1) isa CTModels.Init - - init = (time=t_vec, state=x_matrix, control=u_vec) - @test CTModels.Init(init; state_dim=2, control_dim=1, variable_dim=1) isa CTModels.Init -end diff --git a/test/test_plot.jl b/test/test_plot.jl deleted file mode 100644 index 1f711b17..00000000 --- a/test/test_plot.jl +++ /dev/null @@ -1,238 +0,0 @@ -using Plots - -function test_plot() - - # - ocp, sol, pre_ocp = solution_example() - - # - @test plot(sol; time=:default) isa Plots.Plot - @test plot(sol; time=:normalize) isa Plots.Plot - @test plot(sol; time=:normalise) isa Plots.Plot - @test_throws CTBase.IncorrectArgument plot(sol; time=:wrong_choice) - - # - @test plot(sol; layout=:group, control=:components) isa Plots.Plot - @test plot(sol; layout=:group, control=:norm) isa Plots.Plot - @test plot(sol; layout=:group, control=:all) isa Plots.Plot - @test_throws CTBase.IncorrectArgument plot(sol; layout=:group, control=:wrong_choice) - - # - @test plot(sol; layout=:split, control=:components) isa Plots.Plot - @test plot(sol; layout=:split, control=:norm) isa Plots.Plot - @test plot(sol; layout=:split, control=:all) isa Plots.Plot - @test_throws CTBase.IncorrectArgument plot(sol; layout=:split, control=:wrong_choice) - - # - @test plot(sol; layout=:split) isa Plots.Plot - @test plot(sol; layout=:group) isa Plots.Plot - @test_throws CTBase.IncorrectArgument plot(sol; layout=:wrong_choice) - - # - plt = plot(sol; time=:default) - @test plot!(plt, sol; time=:default) isa Plots.Plot - @test plot!(plt, sol; time=:normalize) isa Plots.Plot - @test plot!(plt, sol; time=:normalise) isa Plots.Plot - @test_throws CTBase.IncorrectArgument plot!(plt, sol; time=:wrong_choice) - - plot(sol; time=:default) - @test plot!(sol; time=:default) isa Plots.Plot - @test plot!(sol; time=:normalize) isa Plots.Plot - @test plot!(sol; time=:normalise) isa Plots.Plot - @test_throws CTBase.IncorrectArgument plot!(sol; time=:wrong_choice) - - plt = plot() - @test plot!(plt, sol; time=:default) isa Plots.Plot - @test plot!(plt, sol; time=:normalize) isa Plots.Plot - @test plot!(plt, sol; time=:normalise) isa Plots.Plot - @test_throws CTBase.IncorrectArgument plot!(plt, sol; time=:wrong_choice) - - plot() - @test plot!(sol; time=:default) isa Plots.Plot - @test plot!(sol; time=:normalize) isa Plots.Plot - @test plot!(sol; time=:normalise) isa Plots.Plot - @test_throws CTBase.IncorrectArgument plot!(sol; time=:wrong_choice) - - # - plt = plot(sol; layout=:group, control=:components) - @test plot!(plt, sol; layout=:group, control=:components) isa Plots.Plot - @test plot!(plt, sol; layout=:group, control=:norm) isa Plots.Plot - - plt = plot(sol; layout=:group, control=:norm) - @test plot!(plt, sol; layout=:group, control=:components) isa Plots.Plot - @test plot!(plt, sol; layout=:group, control=:norm) isa Plots.Plot - - plt = plot(sol; layout=:group, control=:all) - @test plot!(plt, sol; layout=:group, control=:all) isa Plots.Plot - - @test_throws CTBase.IncorrectArgument plot!( - plt, sol; layout=:group, control=:wrong_choice - ) - - # - plt = plot(sol; layout=:split, control=:components) - @test plot!(plt, sol; layout=:split, control=:components) isa Plots.Plot - @test plot!(plt, sol; layout=:split, control=:norm) isa Plots.Plot - - plt = plot(sol; layout=:split, control=:norm) - @test plot!(plt, sol; layout=:split, control=:components) isa Plots.Plot - @test plot!(plt, sol; layout=:split, control=:norm) isa Plots.Plot - - plt = plot(sol; layout=:split, control=:all) - @test plot!(plt, sol; layout=:split, control=:all) isa Plots.Plot - - @test_throws CTBase.IncorrectArgument plot!( - plt, sol; layout=:split, control=:wrong_choice - ) - - # - plt = plot(sol; layout=:split) - @test plot!(plt, sol; layout=:split) isa Plots.Plot - - plt = plot(sol; layout=:group) - @test plot!(plt, sol; layout=:group) isa Plots.Plot - - @test_throws CTBase.IncorrectArgument plot!(plt, sol; layout=:wrong_choice) - - # - @test display(sol) isa Nothing - - # -------------------------------------------------------- - # -------------------------------------------------------- - # other example with path constraints - ocp, sol = solution_example_path_constraints() - - # - @test plot(sol; time=:default) isa Plots.Plot - @test plot(sol; time=:normalize) isa Plots.Plot - @test plot(sol; time=:normalise) isa Plots.Plot - @test_throws CTBase.IncorrectArgument plot(sol; time=:wrong_choice) - - @test plot(sol; time=:default) isa Plots.Plot - @test plot(sol; time=:normalize) isa Plots.Plot - @test plot(sol; time=:normalise) isa Plots.Plot - @test_throws CTBase.IncorrectArgument plot(sol; time=:wrong_choice) - - # - @test plot(sol; layout=:group, control=:components) isa Plots.Plot - @test plot(sol; layout=:group, control=:norm) isa Plots.Plot - @test plot(sol; layout=:group, control=:all) isa Plots.Plot - @test_throws CTBase.IncorrectArgument plot(sol; layout=:group, control=:wrong_choice) - - @test plot(sol; layout=:group, control=:components) isa Plots.Plot - @test plot(sol; layout=:group, control=:norm) isa Plots.Plot - @test plot(sol; layout=:group, control=:all) isa Plots.Plot - @test_throws CTBase.IncorrectArgument plot(sol; layout=:group, control=:wrong_choice) - - # - @test plot(sol; layout=:split, control=:components) isa Plots.Plot - @test plot(sol; layout=:split, control=:norm) isa Plots.Plot - @test plot(sol; layout=:split, control=:all) isa Plots.Plot - @test_throws CTBase.IncorrectArgument plot(sol; layout=:split, control=:wrong_choice) - - @test plot(sol; layout=:split, control=:components) isa Plots.Plot - @test plot(sol; layout=:split, control=:norm) isa Plots.Plot - @test plot(sol; layout=:split, control=:all) isa Plots.Plot - @test_throws CTBase.IncorrectArgument plot(sol; layout=:split, control=:wrong_choice) - - # - @test plot(sol; layout=:split) isa Plots.Plot - @test plot(sol; layout=:group) isa Plots.Plot - @test_throws CTBase.IncorrectArgument plot(sol; layout=:wrong_choice) - - @test plot(sol; layout=:split) isa Plots.Plot - @test plot(sol; layout=:group) isa Plots.Plot - @test_throws CTBase.IncorrectArgument plot(sol; layout=:wrong_choice) - - # - plt = plot(sol; time=:default) - @test plot!(plt, sol; time=:default) isa Plots.Plot - @test plot!(plt, sol; time=:normalize) isa Plots.Plot - @test plot!(plt, sol; time=:normalise) isa Plots.Plot - @test_throws CTBase.IncorrectArgument plot!(plt, sol; time=:wrong_choice) - - plot(sol; time=:default) - @test plot!(sol; time=:default) isa Plots.Plot - @test plot!(sol; time=:normalize) isa Plots.Plot - @test plot!(sol; time=:normalise) isa Plots.Plot - @test_throws CTBase.IncorrectArgument plot!(sol; time=:wrong_choice) - - # - plt = plot(sol; layout=:group, control=:components) - @test plot!(plt, sol; layout=:group, control=:components) isa Plots.Plot - @test plot!(plt, sol; layout=:group, control=:norm) isa Plots.Plot - - plt = plot(sol; layout=:group, control=:norm) - @test plot!(plt, sol; layout=:group, control=:components) isa Plots.Plot - @test plot!(plt, sol; layout=:group, control=:norm) isa Plots.Plot - - plt = plot(sol; layout=:group, control=:all) - @test plot!(plt, sol; layout=:group, control=:all) isa Plots.Plot - - @test_throws CTBase.IncorrectArgument plot!( - plt, sol; layout=:group, control=:wrong_choice - ) - - plt = plot(sol; layout=:group, control=:components) - @test plot!(plt, sol; layout=:group, control=:components) isa Plots.Plot - @test plot!(plt, sol; layout=:group, control=:norm) isa Plots.Plot - - plt = plot(sol; layout=:group, control=:norm) - @test plot!(plt, sol; layout=:group, control=:components) isa Plots.Plot - @test plot!(plt, sol; layout=:group, control=:norm) isa Plots.Plot - - plt = plot(sol; layout=:group, control=:all) - @test plot!(plt, sol; layout=:group, control=:all) isa Plots.Plot - - @test_throws CTBase.IncorrectArgument plot!( - plt, sol; layout=:group, control=:wrong_choice - ) - - # - plt = plot(sol; layout=:split, control=:components) - @test plot!(plt, sol; layout=:split, control=:components) isa Plots.Plot - @test plot!(plt, sol; layout=:split, control=:norm) isa Plots.Plot - - plt = plot(sol; layout=:split, control=:norm) - @test plot!(plt, sol; layout=:split, control=:components) isa Plots.Plot - @test plot!(plt, sol; layout=:split, control=:norm) isa Plots.Plot - - plt = plot(sol; layout=:split, control=:all) - @test plot!(plt, sol; layout=:split, control=:all) isa Plots.Plot - - @test_throws CTBase.IncorrectArgument plot!( - plt, sol; layout=:split, control=:wrong_choice - ) - - plt = plot(sol; layout=:split, control=:components) - @test plot!(plt, sol; layout=:split, control=:components) isa Plots.Plot - @test plot!(plt, sol; layout=:split, control=:norm) isa Plots.Plot - - plt = plot(sol; layout=:split, control=:norm) - @test plot!(plt, sol; layout=:split, control=:components) isa Plots.Plot - @test plot!(plt, sol; layout=:split, control=:norm) isa Plots.Plot - - plt = plot(sol; layout=:split, control=:all) - @test plot!(plt, sol; layout=:split, control=:all) isa Plots.Plot - - @test_throws CTBase.IncorrectArgument plot!( - plt, sol; layout=:split, control=:wrong_choice - ) - - # - plt = plot(sol; layout=:split) - @test plot!(plt, sol; layout=:split) isa Plots.Plot - - plt = plot(sol; layout=:group) - @test plot!(plt, sol; layout=:group) isa Plots.Plot - - @test_throws CTBase.IncorrectArgument plot!(plt, sol; layout=:wrong_choice) - - plt = plot(sol; layout=:split) - @test plot!(plt, sol; layout=:split) isa Plots.Plot - - plt = plot(sol; layout=:group) - @test plot!(plt, sol; layout=:group) isa Plots.Plot - - @test_throws CTBase.IncorrectArgument plot!(plt, sol; layout=:wrong_choice) -end