Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 32 additions & 1 deletion src/initial_guess.jl
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ The function walks through the expression `ex` and splits it into
and substituted into subsequent statements at parse-time using `subs`;
- *initialisation specifications* of the form `lhs := rhs` or
`lhs(arg) := rhs`, which are converted into structured specification
tuples after alias expansion.
tuples after alias and cross-spec expansion.

For expressions of the form `lhs(arg) := rhs`, this function uses `has(rhs, arg)`
to determine whether `arg` appears in the right-hand side. This information
Expand All @@ -289,6 +289,19 @@ Alias substitution happens before each statement is matched, enabling
time-dependent aliases like `phi = 2π * t` and accumulated aliases like
`a = t; s = a`.

After each specification is matched, its (pattern, rhs) pair is stored
for *cross-spec substitution*: subsequent specifications can reference
previous ones in their right-hand side. Substitution is applied on the
RHS only (post-matching) to avoid corrupting the LHS. The already-substituted
RHS is stored, so transitive references are resolved automatically.

All combinations are supported:
- temporal → temporal: `q(t) := sin(t); v(t) := 1.0 + q(t)`
- constant → temporal: `c := 0.1; v(t) := c + sin(t)`
- constant → constant: `a := 1.0; b := a + 2.0`

Order matters: a spec can only reference specs that appear before it.

# Arguments

- `ex::Any`: expression or block coming from the body of `@init`.
Expand All @@ -305,6 +318,7 @@ time-dependent aliases like `phi = 2π * t` and accumulated aliases like
"""
function _collect_init_specs(ex, lnum::Int, line_str::String)
aliases = OrderedCollections.OrderedDict{Union{Symbol,Expr},Any}()
spec_subs = Tuple{Any,Any}[] # substitution list: (pattern, rhs) for cross-spec references
keys = Symbol[] # keys of the NamedTuple (q, v, x, u, tf, ...)
specs = Tuple[] # specification tuples

Expand Down Expand Up @@ -336,20 +350,37 @@ function _collect_init_specs(ex, lnum::Int, line_str::String)
:($lhs($arg) := $rhs) => begin
lhs isa Symbol || error("Unsupported left-hand side in @init: $lhs")

# Apply cross-spec substitutions on rhs only (post-matching)
for (pattern, replacement) in spec_subs
rhs = subs(rhs, pattern, replacement)
end

# Check if arg appears in rhs using has() from utils.jl
# Note: if arg is not a Symbol (e.g., after alias expansion to a literal array),
# has() will return false, which is correct for grid specifications
arg_in_rhs = (arg isa Symbol) && has(rhs, arg)

push!(keys, lhs)
push!(specs, (:temporal, arg, rhs, arg_in_rhs))

# Store for cross-spec substitution in subsequent specs
push!(spec_subs, (Expr(:call, lhs, arg), rhs))
end

# Constant / variable form: lhs := rhs
:($lhs := $rhs) => begin
lhs isa Symbol || error("Unsupported left-hand side in @init: $lhs")

# Apply cross-spec substitutions on rhs only (post-matching)
for (pattern, replacement) in spec_subs
rhs = subs(rhs, pattern, replacement)
end

push!(keys, lhs)
push!(specs, (:constant, rhs))

# Store for cross-spec substitution in subsequent specs
push!(spec_subs, (lhs, rhs))
end

# Fallback: strict mode - reject unrecognized statements
Expand Down
134 changes: 134 additions & 0 deletions test/test_initial_guess.jl
Original file line number Diff line number Diff line change
Expand Up @@ -843,4 +843,138 @@ function test_initial_guess() # debug
end
end
end

# ============================================================================
# Tests for cross-spec substitution
# ============================================================================

@testset "cross-spec: temporal → temporal (basic)" begin
ig = @init ocp_fixed begin
q(t) := sin(t)
v(t) := 1.0 + q(t)
end

@test ig isa CTModels.AbstractInitialGuess
CTModels.validate_initial_guess(ocp_fixed, ig)

xfun = CTModels.state(ig)
for τ in (0.0, 0.5, 1.0)
x = xfun(τ)
@test x[1] ≈ sin(τ)
@test x[2] ≈ 1.0 + sin(τ)
end
end

@testset "cross-spec: temporal → temporal → temporal (transitive chain)" begin
ig = @init ocp_fixed begin
q(t) := sin(t)
v(t) := 1.0 + q(t)
u(t) := t + v(t)^2
end

@test ig isa CTModels.AbstractInitialGuess
CTModels.validate_initial_guess(ocp_fixed, ig)

xfun = CTModels.state(ig)
ufun = CTModels.control(ig)
for τ in (0.0, 0.5, 1.0)
x = xfun(τ)
u = ufun(τ)
@test x[1] ≈ sin(τ)
@test x[2] ≈ 1.0 + sin(τ)
@test u ≈ τ + (1.0 + sin(τ))^2
end
end

@testset "cross-spec: constant → temporal" begin
ig = @init ocp_fixed begin
q := -1.0
v(t) := q + sin(t)
end

@test ig isa CTModels.AbstractInitialGuess
CTModels.validate_initial_guess(ocp_fixed, ig)

xfun = CTModels.state(ig)
for τ in (0.0, 0.5, 1.0)
x = xfun(τ)
@test x[2] ≈ -1.0 + sin(τ)
end
end

@testset "cross-spec: constant → constant" begin
ig = @init ocp_var2 begin
tf := 1.0
a := tf + 0.5
end

@test ig isa CTModels.AbstractInitialGuess
CTModels.validate_initial_guess(ocp_var2, ig)

v = CTModels.variable(ig)
@test v[1] ≈ 1.0
@test v[2] ≈ 1.5
end

@testset "cross-spec: mixed aliases and cross-spec refs" begin
ig = @init ocp_fixed begin
A = 2.0
q(t) := A * sin(t)
v(t) := q(t) + 1.0
end

@test ig isa CTModels.AbstractInitialGuess
CTModels.validate_initial_guess(ocp_fixed, ig)

xfun = CTModels.state(ig)
for τ in (0.0, 0.5, 1.0)
x = xfun(τ)
@test x[1] ≈ 2.0 * sin(τ)
@test x[2] ≈ 2.0 * sin(τ) + 1.0
end
end

@testset "cross-spec: custom time name s" begin
ocp_s = @def begin
s ∈ [0, 1], time
x = (q, v) ∈ R², state
u ∈ R, control
x(0) == [-1, 0]
x(1) == [0, 0]
ẋ(s) == [v(s), u(s)]
∫(0.5u(s)^2) → min
end

ig = @init ocp_s begin
q(s) := sin(s)
v(s) := 1.0 + q(s)
u(s) := s + v(s)
end

@test ig isa CTModels.AbstractInitialGuess
CTModels.validate_initial_guess(ocp_s, ig)

xfun = CTModels.state(ig)
ufun = CTModels.control(ig)
for τ in (0.0, 0.5, 1.0)
x = xfun(τ)
u = ufun(τ)
@test x[1] ≈ sin(τ)
@test x[2] ≈ 1.0 + sin(τ)
@test u ≈ τ + 1.0 + sin(τ)
end
end

@testset "cross-spec: no substitution across different args (grid spec)" begin
T = [0.0, 0.5, 1.0]
Dq = [0.0, 0.5, 1.0]

ig = @init ocp_fixed begin
q(T) := Dq
v(t) := 1.0
end

@test ig isa CTModels.AbstractInitialGuess
CTModels.validate_initial_guess(ocp_fixed, ig)
end
end
Loading