Skip to content

Commit 8f7b2dc

Browse files
authored
Merge pull request #201 from control-toolbox/test-before-exa-linalg
Vectorisation bis
2 parents 5aad902 + bcd782a commit 8f7b2dc

12 files changed

Lines changed: 1456 additions & 182 deletions

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,3 +28,6 @@ Manifest.toml
2828

2929
# Local reports (analysis, status reports, previews) should not be tracked
3030
reports/
31+
32+
# claude
33+
CLAUDE.local.md

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "CTParser"
22
uuid = "32681960-a1b1-40db-9bff-a1ca817385d1"
3-
version = "0.8.0"
3+
version = "0.8.1"
44
authors = ["Jean-Baptiste Caillau <jean-baptiste.caillau@univ-cotedazur.fr>"]
55

66
[deps]

src/onepass.jl

Lines changed: 54 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -96,19 +96,6 @@ function e_prefix!(p)
9696
return nothing
9797
end
9898

99-
# Utils
100-
101-
"""
102-
$(TYPEDSIGNATURES)
103-
104-
Generate a fresh symbol by concatenating the given components and a
105-
`gensym()` suffix.
106-
107-
This is used throughout the parser to create unique internal names that
108-
do not collide with user-defined identifiers.
109-
"""
110-
__symgen(s...) = Symbol(s..., gensym())
111-
11299
"""
113100
$(TYPEDEF)
114101
@@ -191,37 +178,14 @@ case of an exception, prints the originating line number and source
191178
text before rethrowing.
192179
"""
193180
__wrap(e, n, line) = quote
194-
local ex
195181
try
196182
$e
197-
catch ex
183+
catch
198184
println("Line ", $n, ": ", $line)
199-
throw(ex)
185+
rethrow()
200186
end
201187
end
202188

203-
"""
204-
$(TYPEDSIGNATURES)
205-
206-
Return `true` if `x` represents a range.
207-
208-
This predicate is specialised for `AbstractRange` values and for
209-
expressions of the form `i:j` or `i:p:j`.
210-
"""
211-
is_range(x) = false
212-
is_range(x::T) where {T<:AbstractRange} = true
213-
is_range(x::Expr) = (x.head == :call) && (x.args[1] == :(:))
214-
215-
"""
216-
$(TYPEDSIGNATURES)
217-
218-
Return `x` itself if it is a range, or a one-element array `[x]`.
219-
220-
This is a normalisation helper used when interpreting constraint
221-
indices.
222-
"""
223-
as_range(x) = is_range(x) ? x : [x]
224-
225189
# Main code
226190

227191
"""
@@ -580,7 +544,10 @@ function p_state_exa!(p, p_ocp, x, n, xx; components_names=nothing)
580544
))
581545
code = __wrap(code, p.lnum, p.line)
582546
dyn_con = Symbol(:dyn_con, x) # name for the constraints associated with the dynamics
583-
code = :($x = $code; $dyn_con = Vector{$pref.Constraint}(undef, $n)) # affectation must be done outside try ... catch (otherwise declaration known only to try local scope)
547+
code = quote
548+
$x = $code
549+
$dyn_con = Vector{$pref.Constraint}(undef, $n) # affectation must be done outside try ... catch (otherwise declaration known only to try local scope)
550+
end
584551
return code
585552
end
586553

@@ -696,7 +663,7 @@ function p_constraint_fun!(p, p_ocp, e1, e2, e3, c_type, label)
696663
(:variable_range, rg) => :($pref.constraint!(
697664
$p_ocp, :variable; rg=($rg), lb=($e1), ub=($e3), label=($llabel)
698665
))
699-
:state_fun || control_fun || :mixed => begin # now all treated as path
666+
:state_fun || :control_fun || :mixed => begin # now all treated as path
700667
fun = __symgen(:fun)
701668
xt = __symgen(:xt)
702669
ut = __symgen(:ut)
@@ -727,17 +694,20 @@ function p_constraint_exa!(p, p_ocp, e1, e2, e3, c_type, label)
727694
code = :(length($e1) == length($e3) == 1 || throw("this constraint must be scalar")) # (vs. __throw) since raised at runtime
728695
x0 = __symgen(:x0)
729696
xf = __symgen(:xf)
697+
k = __symgen(:k)
730698
e2 = replace_call(e2, p.x, p.t0, x0)
731699
e2 = replace_call(e2, p.x, p.tf, xf)
732700
e2 = subs2(e2, x0, p.x, 0)
701+
e2 = subs(e2, x0, :([$(p.x)[$k, 0] for $k 1:$(p.dim_x)]))
733702
e2 = subs2(e2, xf, p.x, :grid_size)
734-
concat(code, :($pref.constraint($p_ocp, $e2; lcon=($e1), ucon=($e3))))
703+
e2 = subs(e2, xf, :([$(p.x)[$k, grid_size] for $k 1:$(p.dim_x)]))
704+
concat(code, :($pref.constraint($p_ocp, $e2; lcon=($e1[1]), ucon=($e3[1])))) # todo: e1/3[1] will be e1/3[k] when vectorised over dim
735705
end
736706
(:initial, rg) => begin
737707
if isnothing(rg)
738708
rg = :(1:($(p.dim_x))) # x(t0) implies rg == nothing but means x[1:p.dim_x](t0)
739709
e2 = subs(e2, p.x, :($(p.x)[$rg]))
740-
elseif !is_range(rg)
710+
else
741711
rg = as_range(rg)
742712
end
743713
code = :(
@@ -756,8 +726,8 @@ function p_constraint_exa!(p, p_ocp, e1, e2, e3, c_type, label)
756726
if isnothing(rg)
757727
rg = :(1:($(p.dim_x)))
758728
e2 = subs(e2, p.x, :($(p.x)[$rg]))
759-
elseif !is_range(rg)
760-
rg = as_range(rg)
729+
else
730+
rg = as_range(rg) # case rg = i (vs i:j or i:p:j)
761731
end
762732
code = :(
763733
length($e1) == length($e3) == length($rg) || throw("wrong bound dimension")
@@ -775,8 +745,8 @@ function p_constraint_exa!(p, p_ocp, e1, e2, e3, c_type, label)
775745
if isnothing(rg)
776746
rg = :(1:($(p.dim_v)))
777747
e2 = subs(e2, p.v, :($(p.v)[$rg]))
778-
elseif !is_range(rg)
779-
rg = as_range(rg)
748+
else
749+
rg = as_range(rg) # case rg = i (vs i:j or i:p:j)
780750
end
781751
code_box = :(
782752
length($e1) == length($e3) == length($rg) || throw("wrong bound dimension")
@@ -791,10 +761,9 @@ function p_constraint_exa!(p, p_ocp, e1, e2, e3, c_type, label)
791761
end
792762
(:state_range, rg) => begin
793763
if isnothing(rg)
794-
rg = :(1:($(p.dim_x)))
795-
e2 = subs(e2, p.x, :($(p.x)[$rg]))
796-
elseif !is_range(rg)
797-
rg = as_range(rg)
764+
rg = :(1:($(p.dim_x))) # NB. no need to update e2 (unused) here
765+
else
766+
rg = as_range(rg) # case rg = i (vs i:j or i:p:j)
798767
end
799768
code_box = :(
800769
length($e1) == length($e3) == length($rg) || throw("wrong bound dimension")
@@ -809,10 +778,9 @@ function p_constraint_exa!(p, p_ocp, e1, e2, e3, c_type, label)
809778
end
810779
(:control_range, rg) => begin
811780
if isnothing(rg)
812-
rg = :(1:($(p.dim_u)))
813-
e2 = subs(e2, p.u, :($(p.u)[$rg]))
814-
elseif !is_range(rg)
815-
rg = as_range(rg)
781+
rg = :(1:($(p.dim_u))) # NB. no need to update e2 (unused here)
782+
else
783+
rg = as_range(rg) # case rg = i (vs i:j or i:p:j)
816784
end
817785
code_box = :(
818786
length($e1) == length($e3) == length($rg) || throw("wrong bound dimension")
@@ -825,19 +793,22 @@ function p_constraint_exa!(p, p_ocp, e1, e2, e3, c_type, label)
825793
p.box_u = concat(p.box_u, code_box) # not __wrapped since contains definition of l_u/u_u
826794
:()
827795
end
828-
:state_fun || control_fun || :mixed => begin
796+
:state_fun || :control_fun || :mixed => begin
829797
code = :(length($e1) == length($e3) == 1 || throw("this constraint must be scalar")) # (vs. __throw) since raised at runtime
830798
xt = __symgen(:xt)
831799
ut = __symgen(:ut)
832800
e2 = replace_call(e2, [p.x, p.u], p.t, [xt, ut])
833801
j = __symgen(:j)
802+
k = __symgen(:k)
834803
e2 = subs2(e2, xt, p.x, j)
804+
e2 = subs(e2, xt, :([$(p.x)[$k, $j] for $k 1:$(p.dim_x)]))
835805
e2 = subs2(e2, ut, p.u, j)
806+
e2 = subs(e2, ut, :([$(p.u)[$k, $j] for $k 1:$(p.dim_u)]))
836807
e2 = subs(e2, p.t, :($(p.t0) + $j * $(p.dt)))
837808
concat(
838809
code,
839810
:($pref.constraint(
840-
$p_ocp, $e2 for $j in 0:grid_size; lcon=($e1), ucon=($e3)
811+
$p_ocp, $e2 for $j in 0:grid_size; lcon=($e1[1]), ucon=($e3[1])
841812
)),
842813
)
843814
end
@@ -931,26 +902,33 @@ function p_dynamics_coord_exa!(p, p_ocp, x, i, t, e)
931902
j1 = __symgen(:j)
932903
j2 = :($j1 + 1)
933904
j12 = :($j1 + 0.5)
905+
k = __symgen(:k)
934906
ej1 = subs2(e, xt, p.x, j1)
907+
ej1 = subs(ej1, xt, :([$(p.x)[$k, $j1] for $k 1:$(p.dim_x)]))
935908
ej1 = subs2(ej1, ut, p.u, j1)
909+
ej1 = subs(ej1, ut, :([$(p.u)[$k, $j1] for $k 1:$(p.dim_u)]))
936910
ej1 = subs(ej1, p.t, :($(p.t0) + $j1 * $(p.dt)))
937911
ej2 = subs2(e, xt, p.x, j2)
912+
ej2 = subs(ej2, xt, :([$(p.x)[$k, $j2] for $k 1:$(p.dim_x)]))
938913
ej2 = subs2(ej2, ut, p.u, j2)
914+
ej2 = subs(ej2, ut, :([$(p.u)[$k, $j2] for $k 1:$(p.dim_u)]))
939915
ej2 = subs(ej2, p.t, :($(p.t0) + $j2 * $(p.dt)))
940-
ej12 = subs5(e, xt, p.x, j1)
916+
ej12 = subs2m(e, xt, p.x, j1)
917+
ej12 = subs(ej12, xt, :([(($(p.x)[$k, $j1] + $(p.x)[$k, $j1 + 1]) / 2) for $k 1:$(p.dim_x)]))
941918
ej12 = subs2(ej12, ut, p.u, j1)
919+
ej12 = subs(ej12, ut, :([$(p.u)[$k, $j1] for $k 1:$(p.dim_u)]))
942920
ej12 = subs(ej12, p.t, :($(p.t0) + $j12 * $(p.dt)))
943921
dxij = :($(p.x)[$i, $j2] - $(p.x)[$i, $j1])
944922
code = quote
945923
if scheme == :euler
946-
$pref.constraint($p_ocp, $dxij - $(p.dt) * $ej1 for $j1 in 0:(grid_size - 1))
924+
$pref.constraint($p_ocp, $dxij - $(p.dt) * $ej1 for $j1 in 0:grid_size-1)
947925
elseif scheme (:euler_implicit, :euler_b) # euler_b is deprecated
948-
$pref.constraint($p_ocp, $dxij - $(p.dt) * $ej2 for $j1 in 0:(grid_size - 1))
926+
$pref.constraint($p_ocp, $dxij - $(p.dt) * $ej2 for $j1 in 0:grid_size-1)
949927
elseif scheme == :midpoint
950-
$pref.constraint($p_ocp, $dxij - $(p.dt) * $ej12 for $j1 in 0:(grid_size - 1))
928+
$pref.constraint($p_ocp, $dxij - $(p.dt) * $ej12 for $j1 in 0:grid_size-1)
951929
elseif scheme (:trapeze, :trapezoidal) # trapezoidal is deprecated
952930
$pref.constraint(
953-
$p_ocp, $dxij - $(p.dt) * ($ej1 + $ej2) / 2 for $j1 in 0:(grid_size - 1)
931+
$p_ocp, $dxij - $(p.dt) * ($ej1 + $ej2) / 2 for $j1 in 0:grid_size-1
954932
)
955933
else
956934
throw(
@@ -1001,22 +979,27 @@ function p_lagrange_exa!(p, p_ocp, e, type)
1001979
j1 = __symgen(:j)
1002980
j2 = :($j1 + 1)
1003981
j12 = :($j1 + 0.5)
982+
k = __symgen(:k)
1004983
ej1 = subs2(e, xt, p.x, j1)
984+
ej1 = subs(ej1, xt, :([$(p.x)[$k, $j1] for $k 1:$(p.dim_x)]))
1005985
ej1 = subs2(ej1, ut, p.u, j1)
986+
ej1 = subs(ej1, ut, :([$(p.u)[$k, $j1] for $k 1:$(p.dim_u)]))
1006987
ej1 = subs(ej1, p.t, :($(p.t0) + $j1 * $(p.dt)))
1007-
ej12 = subs5(e, xt, p.x, j1)
988+
ej12 = subs2m(e, xt, p.x, j1)
989+
ej12 = subs(ej12, xt, :([(($(p.x)[$k, $j1] + $(p.x)[$k, $j1 + 1]) / 2) for $k 1:$(p.dim_x)]))
1008990
ej12 = subs2(ej12, ut, p.u, j1)
991+
ej12 = subs(ej12, ut, :([$(p.u)[$k, $j1] for $k 1:$(p.dim_u)]))
1009992
ej12 = subs(ej12, p.t, :($(p.t0) + $j12 * $(p.dt)))
1010993
code = quote
1011994
if scheme == :euler
1012-
$pref.objective($p_ocp, $(p.dt) * $ej1 for $j1 in 0:(grid_size - 1))
995+
$pref.objective($p_ocp, $(p.dt) * $ej1 for $j1 in 0:grid_size-1)
1013996
elseif scheme (:euler_implicit, :euler_b) # euler_b is deprecated
1014997
$pref.objective($p_ocp, $(p.dt) * $ej1 for $j1 in 1:grid_size)
1015998
elseif scheme == :midpoint
1016-
$pref.objective($p_ocp, $(p.dt) * $ej12 for $j1 in 0:(grid_size - 1))
999+
$pref.objective($p_ocp, $(p.dt) * $ej12 for $j1 in 0:grid_size-1)
10171000
elseif scheme (:trapeze, :trapezoidal) # trapezoidal is deprecated
10181001
$pref.objective($p_ocp, $(p.dt) * $ej1 / 2 for $j1 in (0, grid_size))
1019-
$pref.objective($p_ocp, $(p.dt) * $ej1 for $j1 in 1:(grid_size - 1))
1002+
$pref.objective($p_ocp, $(p.dt) * $ej1 for $j1 in 1:grid_size-1)
10201003
else
10211004
throw(
10221005
"unknown numerical scheme: $scheme (possible choices are :euler, :euler_implicit, :midpoint, :trapeze)",
@@ -1062,10 +1045,13 @@ function p_mayer_exa!(p, p_ocp, e, type)
10621045
pref = prefix_exa()
10631046
x0 = __symgen(:x0)
10641047
xf = __symgen(:xf)
1048+
k = __symgen(:k)
10651049
e = replace_call(e, p.x, p.t0, x0)
10661050
e = replace_call(e, p.x, p.tf, xf)
10671051
e = subs2(e, x0, p.x, 0)
1052+
e = subs(e, x0, :([$(p.x)[$k, 0] for $k 1:$(p.dim_x)]))
10681053
e = subs2(e, xf, p.x, :grid_size)
1054+
e = subs(e, xf, :([$(p.x)[$k, grid_size] for $k 1:$(p.dim_x)]))
10691055
# now, x[i](t0) has been replaced by x[i, 0] and x[i](tf) by x[i, grid_size]
10701056
code = :($pref.objective($p_ocp, $e))
10711057
return __wrap(code, p.lnum, p.line)
@@ -1133,7 +1119,7 @@ PARSING_FUN[:lagrange] = p_lagrange_fun!
11331119
PARSING_FUN[:mayer] = p_mayer_fun!
11341120
PARSING_FUN[:bolza] = p_bolza_fun!
11351121

1136-
# Summary of available parsing subfunctions (:fun backend)
1122+
# Summary of available parsing subfunctions (:exa backend)
11371123

11381124
const PARSING_EXA = OrderedDict{Symbol,Function}()
11391125
PARSING_EXA[:pragma] = p_pragma_exa!
@@ -1295,7 +1281,7 @@ function def_fun(e; log=false)
12951281
$p_ocp = $pref.PreModel()
12961282
$code
12971283
$pref.definition!($p_ocp, $ee)
1298-
$pref.time_dependence!($p_ocp; autonomous=$p.is_autonomous)
1284+
$pref.time_dependence!($p_ocp; autonomous=$p.is_autonomous) # not $(p.xxxx) as this info is known statically
12991285
end
13001286

13011287
if is_active_backend(:exa)
@@ -1383,7 +1369,7 @@ function def_exa(e; log=false)
13831369
$(p.box_u) # lvar and uvar for control
13841370
$(p.box_v) # lvar and uvar for variable (after x and u for compatibility with CTDirect)
13851371
$p_ocp = $pref.ExaCore(
1386-
base_type; backend=backend, minimize=($p.criterion == :min)
1372+
base_type; backend=backend, minimize=($p.criterion == :min) # not $(p.xxxx) as this info is known statically
13871373
)
13881374
$code
13891375
$dyn_check

0 commit comments

Comments
 (0)