Skip to content

Commit 7d02099

Browse files
authored
Merge pull request #205 from control-toolbox/dynamics-exa
Dynamics exa
2 parents ffc5c3b + 7f4364b commit 7d02099

9 files changed

Lines changed: 243 additions & 40 deletions

File tree

Project.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
name = "CTParser"
22
uuid = "32681960-a1b1-40db-9bff-a1ca817385d1"
3-
version = "0.8.2-beta.1"
3+
version = "0.8.2-beta.2"
44
authors = ["Jean-Baptiste Caillau <jean-baptiste.caillau@univ-cotedazur.fr>"]
55

66
[deps]
77
CTBase = "54762871-cc72-4466-b8e8-f6c8b58076cd"
88
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
9+
ExaModels = "1037b233-b668-4ce9-9b63-f9f681f55dd2"
10+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
911
MLStyle = "d8e11817-5142-5d16-987a-aa16d5891078"
1012
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
1113
Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"
@@ -14,6 +16,7 @@ Unicode = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"
1416
[compat]
1517
CTBase = "0.16, 0.17"
1618
DocStringExtensions = "0.9"
19+
ExaModels = "0.9.3"
1720
MLStyle = "0.4"
1821
OrderedCollections = "1"
1922
Parameters = "0.12"

src/CTParser.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ using Unicode
2222
# sources
2323
include("defaults.jl")
2424
include("utils.jl")
25+
include("exa_linalg.jl")
2526
include("onepass.jl")
2627
include("initial_guess.jl")
2728

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ These specialized methods properly handle Null nodes, preserving constants
1010
and avoiding unnecessary expression tree nodes.
1111
1212
# Public API (Exported)
13-
- Canonical nodes: `zero`, `one`
13+
- Canonical nodes: `zero`, `one`, `zeros`, `ones`
1414
- Basic operations: `zero`, `one`, `adjoint`, `transpose`, `*`, `+`, `-`, `sum`
1515
- Linear algebra: `dot`, `det`, `tr`, `norm`, `diag`, `diagm`
1616
"""
@@ -24,10 +24,10 @@ import Base: inv, abs, sqrt, cbrt, abs2, exp, exp2, exp10, log, log2, log10, log
2424
import Base: sin, cos, tan, csc, sec, cot, asin, acos, atan, acot
2525
import Base: sind, cosd, tand, cscd, secd, cotd, atand, acotd
2626
import Base: sinh, cosh, tanh, csch, sech, coth, asinh, acosh, atanh, acoth
27-
import Base: ^
27+
import Base: ^, zeros, ones
2828
import LinearAlgebra: dot, Adjoint, det, tr, norm, diag, diagm
2929

30-
export zero, one, adjoint, transpose, *, +, -, sum, dot, det, tr, norm, diag, diagm
30+
export zero, one, zeros, ones, adjoint, transpose, *, +, -, sum, dot, det, tr, norm, diag, diagm
3131
export inv, abs, sqrt, cbrt, abs2, exp, exp2, exp10, log, log2, log10, log1p
3232
export sin, cos, tan, csc, sec, cot, asin, acos, atan, acot
3333
export sind, cosd, tand, cscd, secd, cotd, atand, acotd
@@ -75,6 +75,26 @@ Return the canonical one AbstractNode: Null(1).
7575
"""
7676
one(::ExaModels.AbstractNode) = ExaModels.Null(1)
7777

78+
"""
79+
zeros(::Type{T}, dims...) where {T <: ExaModels.AbstractNode}
80+
81+
Create an array of AbstractNode zeros with the specified dimensions.
82+
Uses fill with the canonical zero node: Null(0).
83+
"""
84+
function zeros(::Type{T}, dims::Integer...) where {T <: ExaModels.AbstractNode}
85+
return fill(zero(T), dims...)
86+
end
87+
88+
"""
89+
ones(::Type{T}, dims...) where {T <: ExaModels.AbstractNode}
90+
91+
Create an array of AbstractNode ones with the specified dimensions.
92+
Uses fill with the canonical one node: Null(1).
93+
"""
94+
function ones(::Type{T}, dims::Integer...) where {T <: ExaModels.AbstractNode}
95+
return fill(one(T), dims...)
96+
end
97+
7898
# ============================================================================
7999
# Section 2: Scalar Operations on Null Nodes
80100
# ============================================================================

src/onepass.jl

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -709,10 +709,10 @@ function p_constraint_exa!(p, p_ocp, e1, e2, e3, c_type, label)
709709
quote
710710
length($e1) == length($e3) || throw("wrong bound dimension") # (vs. __throw) since raised at runtime
711711
if length($e1) == 1
712-
$pref.constraint($p_ocp, $e2; lcon=($e1[1]), ucon=($e3[1])) # debug: add _denull
712+
$pref.constraint($p_ocp, $e2; lcon=($e1[1]), ucon=($e3[1])) # todo: add _denull
713713
else
714714
for $l 1:length($e1)
715-
$pref.constraint($p_ocp, $e2[$l]; lcon=($e1[$l]), ucon=($e3[$l])) # debug: add _denull
715+
$pref.constraint($p_ocp, $e2[$l]; lcon=($e1[$l]), ucon=($e3[$l])) # todo: add _denull
716716
end
717717
end
718718
end
@@ -822,10 +822,10 @@ function p_constraint_exa!(p, p_ocp, e1, e2, e3, c_type, label)
822822
quote
823823
length($e1) == length($e3) || throw("wrong bound dimension") # (vs. __throw) since raised at runtime
824824
if length($e1) == 1
825-
$pref.constraint($p_ocp, $e2 for $j in 0:grid_size; lcon=($e1[1]), ucon=($e3[1])) # debug: add _denull
825+
$pref.constraint($p_ocp, $e2 for $j in 0:grid_size; lcon=($e1[1]), ucon=($e3[1])) # todo: add _denull
826826
else
827827
for $l 1:length($e1)
828-
$pref.constraint($p_ocp, $e2[$l] for $j in 0:grid_size; lcon=($e1[$l]), ucon=($e3[$l])) # debug: add _denull
828+
$pref.constraint($p_ocp, $e2[$l] for $j in 0:grid_size; lcon=($e1[$l]), ucon=($e3[$l])) # todo: add _denull
829829
end
830830
end
831831
end
@@ -901,14 +901,14 @@ function p_dynamics_exa!(p, p_ocp, x, t, e)
901901
code = quote
902902
for $i 1:$(p.dim_x)
903903
$(p.dyn_con)[$i] = if scheme == :euler # dyn_con already defined outside try catch
904-
$pref.constraint($p_ocp, $dxj[$i] - $(p.dt) * $ej1[$i] for $j1 in 0:grid_size-1) # debug: add _denull
904+
$pref.constraint($p_ocp, $dxj[$i] - $(p.dt) * $ej1[$i] for $j1 in 0:grid_size-1) # todo: add _denull
905905
elseif scheme (:euler_implicit, :euler_b) # euler_b is deprecated
906-
$pref.constraint($p_ocp, $dxj[$i] - $(p.dt) * $ej2[$i] for $j1 in 0:grid_size-1) # debug: add _denull
906+
$pref.constraint($p_ocp, $dxj[$i] - $(p.dt) * $ej2[$i] for $j1 in 0:grid_size-1) # todo: add _denull
907907
elseif scheme == :midpoint
908-
$pref.constraint($p_ocp, $dxj[$i] - $(p.dt) * $ej12[$i] for $j1 in 0:grid_size-1) # debug: add _denull
908+
$pref.constraint($p_ocp, $dxj[$i] - $(p.dt) * $ej12[$i] for $j1 in 0:grid_size-1) # todo: add _denull
909909
elseif scheme (:trapeze, :trapezoidal) # trapezoidal is deprecated
910910
$pref.constraint(
911-
$p_ocp, $dxj[$i] - $(p.dt) * ($ej1[$i] + $ej2[$i]) / 2 for $j1 in 0:grid_size-1 # debug: add _denull
911+
$p_ocp, $dxj[$i] - $(p.dt) * ($ej1[$i] + $ej2[$i]) / 2 for $j1 in 0:grid_size-1 # todo: add _denull
912912
)
913913
else
914914
throw(
@@ -960,7 +960,7 @@ function p_dynamics_coord_exa!(p, p_ocp, x, i, t, e)
960960
return __throw("dynamics coordinate $i should be an integer", p.lnum, p.line)
961961
end
962962

963-
function p_dynamics_coord_exa!(p, p_ocp, x, i::Integer, t, e) # debug: also also add coord = range for :exa
963+
function p_dynamics_coord_exa!(p, p_ocp, x, i::Integer, t, e) # todo: also also add coord = range for :exa
964964
pref = prefix_exa()
965965
i p.dyn_coords && return __throw("dynamics coordinate $i already defined", p.lnum, p.line)
966966
append!(p.dyn_coords, i)
@@ -989,14 +989,14 @@ function p_dynamics_coord_exa!(p, p_ocp, x, i::Integer, t, e) # debug: also also
989989
dxij = :($(p.x)[$i, $j2] - $(p.x)[$i, $j1])
990990
code = quote
991991
$(p.dyn_con)[$i] = if scheme == :euler # dyn_con already defined outside try catch
992-
$pref.constraint($p_ocp, $dxij - $(p.dt) * $ej1 for $j1 in 0:grid_size-1) # debug: add _denull
992+
$pref.constraint($p_ocp, $dxij - $(p.dt) * $ej1 for $j1 in 0:grid_size-1) # todo: add _denull
993993
elseif scheme (:euler_implicit, :euler_b) # euler_b is deprecated
994-
$pref.constraint($p_ocp, $dxij - $(p.dt) * $ej2 for $j1 in 0:grid_size-1) # debug: add _denull
994+
$pref.constraint($p_ocp, $dxij - $(p.dt) * $ej2 for $j1 in 0:grid_size-1) # todo: add _denull
995995
elseif scheme == :midpoint
996-
$pref.constraint($p_ocp, $dxij - $(p.dt) * $ej12 for $j1 in 0:grid_size-1) # debug: add _denull
996+
$pref.constraint($p_ocp, $dxij - $(p.dt) * $ej12 for $j1 in 0:grid_size-1) # todo: add _denull
997997
elseif scheme (:trapeze, :trapezoidal) # trapezoidal is deprecated
998998
$pref.constraint(
999-
$p_ocp, $dxij - $(p.dt) * ($ej1 + $ej2) / 2 for $j1 in 0:grid_size-1 # debug: add _denull
999+
$p_ocp, $dxij - $(p.dt) * ($ej1 + $ej2) / 2 for $j1 in 0:grid_size-1 # todo: add _denull
10001000
)
10011001
else
10021002
throw(
@@ -1057,14 +1057,14 @@ function p_lagrange_exa!(p, p_ocp, e, type)
10571057
ej12 = subs(ej12, p.t, :($(p.t0) + $j12 * $(p.dt)))
10581058
code = quote
10591059
if scheme == :euler
1060-
$pref.objective($p_ocp, $(p.dt) * $ej1 for $j1 in 0:grid_size-1) # debug: add _denull
1060+
$pref.objective($p_ocp, $(p.dt) * $ej1 for $j1 in 0:grid_size-1) # todo: add _denull
10611061
elseif scheme (:euler_implicit, :euler_b) # euler_b is deprecated
1062-
$pref.objective($p_ocp, $(p.dt) * $ej1 for $j1 in 1:grid_size) # debug: add _denull
1062+
$pref.objective($p_ocp, $(p.dt) * $ej1 for $j1 in 1:grid_size) # todo: add _denull
10631063
elseif scheme == :midpoint
1064-
$pref.objective($p_ocp, $(p.dt) * $ej12 for $j1 in 0:grid_size-1) # debug: add _denull
1064+
$pref.objective($p_ocp, $(p.dt) * $ej12 for $j1 in 0:grid_size-1) # todo: add _denull
10651065
elseif scheme (:trapeze, :trapezoidal) # trapezoidal is deprecated
1066-
$pref.objective($p_ocp, $(p.dt) * $ej1 / 2 for $j1 in (0, grid_size)) # debug: add _denull
1067-
$pref.objective($p_ocp, $(p.dt) * $ej1 for $j1 in 1:grid_size-1) # debug: add _denull
1066+
$pref.objective($p_ocp, $(p.dt) * $ej1 / 2 for $j1 in (0, grid_size)) # todo: add _denull
1067+
$pref.objective($p_ocp, $(p.dt) * $ej1 for $j1 in 1:grid_size-1) # todo: add _denull
10681068
else
10691069
throw(
10701070
"unknown numerical scheme: $scheme (possible choices are :euler, :euler_implicit, :midpoint, :trapeze)",
@@ -1118,7 +1118,7 @@ function p_mayer_exa!(p, p_ocp, e, type)
11181118
e = subs2(e, xf, p.x, :grid_size)
11191119
e = subs(e, xf, :([$(p.x)[$k, grid_size] for $k 1:$(p.dim_x)]))
11201120
# now, x[i](t0) has been replaced by x[i, 0] and x[i](tf) by x[i, grid_size]
1121-
code = :($pref.objective($p_ocp, $e)) # debug: add _denull
1121+
code = :($pref.objective($p_ocp, $e)) # todo: add _denull
11221122
return __wrap(code, p.lnum, p.line)
11231123
end
11241124

@@ -1436,8 +1436,8 @@ function def_exa(e; log=false)
14361436
$p_ocp = $pref.ExaCore(
14371437
base_type; backend=backend, minimize=($p.criterion == :min) # not $(p.xxxx) as this info is known statically
14381438
)
1439-
# _denull(e) = e # debug
1440-
# _denull(e::$pref.Null{T}) where {T<:Real} = e.value # debug: Null must be imported by OC.jl
1439+
# _denull(e) = e # todo
1440+
# _denull(e::$pref.Null{T}) where {T<:Real} = e.value # todo: Null must be imported by OC.jl
14411441
$code
14421442
$dyn_check
14431443
return $pref.ExaModel($p_ocp), $getter

test/runtests.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,6 @@ using Interpolations
5454
using NLPModels
5555
using LinearAlgebra
5656

57-
include("exa_linalg.jl")
58-
5957
macro ignore(e)
6058
return :()
6159
end

test/test_aqua.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,14 @@ function test_aqua()
22
@testset "Aqua.jl" begin
33
Aqua.test_all(
44
CTParser;
5-
ambiguities=false,
5+
ambiguities=true, # also tests submodules
66
#stale_deps=(ignore=[:MLStyle],),
77
deps_compat=(ignore=[:LinearAlgebra, :Unicode],),
88
piracies=true,
99
)
10-
# do not warn about ambiguities in dependencies
11-
Aqua.test_ambiguities(CTParser)
10+
# Test ExaLinAlg submodule for type piracy
11+
#@testset "ExaLinAlg piracy" begin
12+
# Aqua.test_piracies(CTParser.ExaLinAlg)
13+
#end
1214
end
1315
end

test/test_dynamics_exa.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
# test vector-form dynamics for exa backend
22
# Tests for p_dynamics_exa! which allows defining all state dynamics in one expression: ∂(x)(t) == [e1, e2, ...]
33

4-
using .ExaLinAlg # Load ExaLinAlg module for linear algebra operations on ExaModels.AbstractNode arrays
5-
64
activate_backend(:exa)
75

86
# Mock up of CTDirect.discretise for tests

test/test_exa_linalg.jl

Lines changed: 81 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@
22
# Pure unit tests for ExaModels linear algebra extensions
33
# No dependencies on CTParser - only ExaModels and LinearAlgebra
44

5-
using .ExaLinAlg
6-
75
# Helper to check if a Null node represents zero
86
is_null_zero(x::ExaModels.Null) = iszero(x.value)
97
is_null_zero(x::ExaModels.AbstractNode) = false
@@ -678,6 +676,87 @@ function test_exa_linalg()
678676
@test isone(o.value)
679677
@test o.value == 1 # Canonical one is Null(1)
680678
end
679+
680+
@testset "zeros and ones array creation" begin
681+
# Test zeros with different dimensions
682+
z1 = zeros(ExaModels.AbstractNode, 3)
683+
@test length(z1) == 3
684+
@test z1 isa Vector{<:ExaModels.AbstractNode}
685+
@test all(is_null_zero.(z1))
686+
@test all(x -> x isa ExaModels.Null, z1)
687+
688+
z2 = zeros(ExaModels.AbstractNode, 2, 3)
689+
@test size(z2) == (2, 3)
690+
@test z2 isa Matrix{<:ExaModels.AbstractNode}
691+
@test all(is_null_zero.(z2))
692+
693+
z3 = zeros(ExaModels.AbstractNode, 2, 2, 2)
694+
@test size(z3) == (2, 2, 2)
695+
@test z3 isa Array{<:ExaModels.AbstractNode, 3}
696+
@test all(is_null_zero.(z3))
697+
698+
# Test ones with different dimensions
699+
o1 = ones(ExaModels.AbstractNode, 3)
700+
@test length(o1) == 3
701+
@test o1 isa Vector{<:ExaModels.AbstractNode}
702+
@test all(x -> x isa ExaModels.Null && isone(x.value), o1)
703+
704+
o2 = ones(ExaModels.AbstractNode, 2, 3)
705+
@test size(o2) == (2, 3)
706+
@test o2 isa Matrix{<:ExaModels.AbstractNode}
707+
@test all(x -> x isa ExaModels.Null && isone(x.value), o2)
708+
709+
o3 = ones(ExaModels.AbstractNode, 2, 2, 2)
710+
@test size(o3) == (2, 2, 2)
711+
@test o3 isa Array{<:ExaModels.AbstractNode, 3}
712+
@test all(x -> x isa ExaModels.Null && isone(x.value), o3)
713+
714+
# Test with specific ExaModels types (Variable <: AbstractNode)
715+
# Note: Variable is an alias/subtype, zeros/ones with AbstractNode works for all subtypes
716+
z_var = zeros(ExaModels.AbstractNode, 3)
717+
@test length(z_var) == 3
718+
@test eltype(z_var) <: ExaModels.AbstractNode
719+
@test all(is_null_zero.(z_var))
720+
721+
o_var = ones(ExaModels.AbstractNode, 2, 2)
722+
@test size(o_var) == (2, 2)
723+
@test eltype(o_var) <: ExaModels.AbstractNode
724+
@test all(x -> x isa ExaModels.Null && isone(x.value), o_var)
725+
end
726+
727+
@testset "zeros and ones in operations" begin
728+
x, y, z, w = create_nodes()
729+
730+
# Test operations with zeros array
731+
z_vec = zeros(ExaModels.AbstractNode, 3)
732+
vec_nodes = [x, y, z]
733+
734+
# Addition with zeros
735+
result1 = vec_nodes + z_vec
736+
@test result1[1].value == x.value
737+
@test result1[2].value == y.value
738+
@test result1[3].value == z.value
739+
740+
# Multiplication with zeros (all zeros)
741+
result2 = [ExaModels.Null(2), ExaModels.Null(3), ExaModels.Null(4)] .* z_vec
742+
@test all(is_null_zero.(result2))
743+
744+
# Test operations with ones array
745+
o_vec = ones(ExaModels.AbstractNode, 3)
746+
747+
# Multiplication with ones
748+
result3 = vec_nodes .* o_vec
749+
@test result3[1].value == x.value
750+
@test result3[2].value == y.value
751+
@test result3[3].value == z.value
752+
753+
# Matrix-vector with identity-like structure
754+
I_like = diagm(ones(ExaModels.AbstractNode, 2))
755+
vec2 = [x, y]
756+
result4 = I_like * vec2
757+
@test result4[1].value == x.value
758+
@test result4[2].value == y.value
759+
end
681760
end
682761

683762
@testset "Zero multiplication optimizations" begin

0 commit comments

Comments
 (0)