Skip to content

Commit 8d28fc0

Browse files
committed
migration
2 parents eed71d8 + 65d31be commit 8d28fc0

10 files changed

Lines changed: 2705 additions & 98 deletions

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,4 @@ reports/
3131

3232
# claude
3333
CLAUDE.local.md
34+
.claude/

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "CTParser"
22
uuid = "32681960-a1b1-40db-9bff-a1ca817385d1"
3-
version = "0.8.1"
3+
version = "0.8.2-beta"
44
authors = ["Jean-Baptiste Caillau <jean-baptiste.caillau@univ-cotedazur.fr>"]
55

66
[deps]

src/onepass.jl

Lines changed: 94 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
# onepass
22
# todo:
3+
# - add tests to compare runs of exa with scalar vs vec dynamics
4+
#
35
# - as_range / as_vector for rg / lb, ub could be done here in p_constraint when calling constraint!
46
# - cannot call solve if problem not fully defined (dynamics not defined...)
57
# - doc: explain projections wrt to t0, tf, t; (...x1...x2...)(t) -> ...gensym1...gensym2... (most internal first)
@@ -118,7 +120,10 @@ $(TYPEDEF)
118120
lnum::Int = 0
119121
line::String = ""
120122
dt::Symbol = __symgen(:dt)
123+
is_global_dyn::Bool = false
124+
is_coord_dyn::Bool = false
121125
dyn_coords::Vector{Int64} = Int64[]
126+
dyn_con::Union{Symbol,Nothing} = nothing
122127
l_v::Symbol = __symgen(:l_v)
123128
u_v::Symbol = __symgen(:u_v)
124129
box_v::Expr = :(LineNumberNode(0, "box constraints: variable"))
@@ -543,10 +548,10 @@ function p_state_exa!(p, p_ocp, x, n, xx; components_names=nothing)
543548
start=init[2],
544549
))
545550
code = __wrap(code, p.lnum, p.line)
546-
dyn_con = Symbol(:dyn_con, x) # name for the constraints associated with the dynamics
551+
p.dyn_con = __symgen(:dyn_con) # name for the constraints associated with the dynamics
547552
code = quote
548553
$x = $code
549-
$dyn_con = Vector{$pref.Constraint}(undef, $n) # affectation must be done outside try ... catch (otherwise declaration known only to try local scope)
554+
$(p.dyn_con) = Vector{$pref.Constraint}(undef, $n) # affectation must be done outside try ... catch (otherwise declaration known only to try local scope)
550555
end
551556
return code
552557
end
@@ -691,24 +696,33 @@ function p_constraint_exa!(p, p_ocp, e1, e2, e3, c_type, label)
691696
isnothing(e3) && (e3 = :(Inf * ones(length($e1))))
692697
code = @match c_type begin
693698
:boundary || :variable_fun => begin
694-
code = :(length($e1) == length($e3) == 1 || throw("this constraint must be scalar")) # (vs. __throw) since raised at runtime
695699
x0 = __symgen(:x0)
696700
xf = __symgen(:xf)
697701
k = __symgen(:k)
702+
l = __symgen(:l)
698703
e2 = replace_call(e2, p.x, p.t0, x0)
699704
e2 = replace_call(e2, p.x, p.tf, xf)
700705
e2 = subs2(e2, x0, p.x, 0)
701706
e2 = subs(e2, x0, :([$(p.x)[$k, 0] for $k 1:$(p.dim_x)]))
702707
e2 = subs2(e2, xf, p.x, :grid_size)
703708
e2 = subs(e2, xf, :([$(p.x)[$k, grid_size] for $k 1:$(p.dim_x)]))
704-
concat(code, :($pref.constraint($p_ocp, $e2; lcon=($e1[1]), ucon=($e3[1])))) # todo: e1/3[1] will be e1/3[k] when vectorised over dim
709+
quote
710+
length($e1) == length($e3) || throw("wrong bound dimension") # (vs. __throw) since raised at runtime
711+
if length($e1) == 1
712+
$pref.constraint($p_ocp, $e2; lcon=($e1[1]), ucon=($e3[1]))
713+
else
714+
for $l 1:length($e1)
715+
$pref.constraint($p_ocp, $e2[$l]; lcon=($e1[$l]), ucon=($e3[$l]))
716+
end
717+
end
718+
end
705719
end
706720
(:initial, rg) => begin
707721
if isnothing(rg)
708722
rg = :(1:($(p.dim_x))) # x(t0) implies rg == nothing but means x[1:p.dim_x](t0)
709723
e2 = subs(e2, p.x, :($(p.x)[$rg]))
710724
else
711-
rg = as_range(rg)
725+
rg = as_range(rg) # case rg = i (vs i:j or i:p:j)
712726
end
713727
code = :(
714728
length($e1) == length($e3) == length($rg) || throw("wrong bound dimension")
@@ -794,23 +808,27 @@ function p_constraint_exa!(p, p_ocp, e1, e2, e3, c_type, label)
794808
:()
795809
end
796810
:state_fun || :control_fun || :mixed => begin
797-
code = :(length($e1) == length($e3) == 1 || throw("this constraint must be scalar")) # (vs. __throw) since raised at runtime
798811
xt = __symgen(:xt)
799812
ut = __symgen(:ut)
800-
e2 = replace_call(e2, [p.x, p.u], p.t, [xt, ut])
801813
j = __symgen(:j)
802814
k = __symgen(:k)
815+
l = __symgen(:l)
816+
e2 = replace_call(e2, [p.x, p.u], p.t, [xt, ut])
803817
e2 = subs2(e2, xt, p.x, j)
804818
e2 = subs(e2, xt, :([$(p.x)[$k, $j] for $k 1:$(p.dim_x)]))
805819
e2 = subs2(e2, ut, p.u, j)
806820
e2 = subs(e2, ut, :([$(p.u)[$k, $j] for $k 1:$(p.dim_u)]))
807821
e2 = subs(e2, p.t, :($(p.t0) + $j * $(p.dt)))
808-
concat(
809-
code,
810-
:($pref.constraint(
811-
$p_ocp, $e2 for $j in 0:grid_size; lcon=($e1[1]), ucon=($e3[1])
812-
)),
813-
)
822+
quote
823+
length($e1) == length($e3) || throw("wrong bound dimension") # (vs. __throw) since raised at runtime
824+
if length($e1) == 1
825+
$pref.constraint($p_ocp, $e2 for $j in 0:grid_size; lcon=($e1[1]), ucon=($e3[1]))
826+
else
827+
for $l 1:length($e1)
828+
$pref.constraint($p_ocp, $e2[$l] for $j in 0:grid_size; lcon=($e1[$l]), ucon=($e3[$l]))
829+
end
830+
end
831+
end
814832
end
815833
_ => return __throw("bad constraint declaration", p.lnum, p.line)
816834
end
@@ -827,6 +845,9 @@ function p_dynamics!(
827845
isnothing(p.t) && return __throw("time not yet declared", p.lnum, p.line)
828846
x p.x && return __throw("wrong state $x for dynamics", p.lnum, p.line)
829847
t p.t && return __throw("wrong time $t for dynamics", p.lnum, p.line)
848+
p.is_global_dyn && return __throw("dynamics already defined", p.lnum, p.line)
849+
p.is_coord_dyn && return __throw("dynamics already partially defined", p.lnum, p.line)
850+
p.is_global_dyn = true
830851
xut = __symgen(:xut)
831852
ee = replace_call(e, [p.x, p.u], p.t, [xut, xut])
832853
has(ee, p.t) && (p.is_autonomous = false)
@@ -852,7 +873,51 @@ function p_dynamics_fun!(p, p_ocp, x, t, e)
852873
end
853874

854875
function p_dynamics_exa!(p, p_ocp, x, t, e)
855-
return __throw("dynamics must be defined coordinatewise", p.lnum, p.line) # note: scalar case is redirected before coordinatewise case
876+
pref = prefix_exa()
877+
xt = __symgen(:xt)
878+
ut = __symgen(:ut)
879+
e = replace_call(e, [p.x, p.u], p.t, [xt, ut])
880+
j1 = __symgen(:j)
881+
j2 = :($j1 + 1)
882+
j12 = :($j1 + 0.5)
883+
k = __symgen(:k)
884+
ej1 = subs2(e, xt, p.x, j1)
885+
ej1 = subs(ej1, xt, :([$(p.x)[$k, $j1] for $k 1:$(p.dim_x)]))
886+
ej1 = subs2(ej1, ut, p.u, j1)
887+
ej1 = subs(ej1, ut, :([$(p.u)[$k, $j1] for $k 1:$(p.dim_u)]))
888+
ej1 = subs(ej1, p.t, :($(p.t0) + $j1 * $(p.dt)))
889+
ej2 = subs2(e, xt, p.x, j2)
890+
ej2 = subs(ej2, xt, :([$(p.x)[$k, $j2] for $k 1:$(p.dim_x)]))
891+
ej2 = subs2(ej2, ut, p.u, j2)
892+
ej2 = subs(ej2, ut, :([$(p.u)[$k, $j2] for $k 1:$(p.dim_u)]))
893+
ej2 = subs(ej2, p.t, :($(p.t0) + $j2 * $(p.dt)))
894+
ej12 = subs2m(e, xt, p.x, j1)
895+
ej12 = subs(ej12, xt, :([(($(p.x)[$k, $j1] + $(p.x)[$k, $j1 + 1]) / 2) for $k 1:$(p.dim_x)]))
896+
ej12 = subs2(ej12, ut, p.u, j1)
897+
ej12 = subs(ej12, ut, :([$(p.u)[$k, $j1] for $k 1:$(p.dim_u)]))
898+
ej12 = subs(ej12, p.t, :($(p.t0) + $j12 * $(p.dt)))
899+
dxj = :([$(p.x)[$k, $j2] - $(p.x)[$k, $j1] for $k 1:$(p.dim_x)])
900+
i = __symgen(:i)
901+
code = quote
902+
for $i 1:$(p.dim_x)
903+
$(p.dyn_con)[$i] = if scheme == :euler # dyn_con already defined outside try catch
904+
$pref.constraint($p_ocp, $dxj[$i] - $(p.dt) * $ej1[$i] for $j1 in 0:grid_size-1)
905+
elseif scheme (:euler_implicit, :euler_b) # euler_b is deprecated
906+
$pref.constraint($p_ocp, $dxj[$i] - $(p.dt) * $ej2[$i] for $j1 in 0:grid_size-1)
907+
elseif scheme == :midpoint
908+
$pref.constraint($p_ocp, $dxj[$i] - $(p.dt) * $ej12[$i] for $j1 in 0:grid_size-1)
909+
elseif scheme (:trapeze, :trapezoidal) # trapezoidal is deprecated
910+
$pref.constraint(
911+
$p_ocp, $dxj[$i] - $(p.dt) * ($ej1[$i] + $ej2[$i]) / 2 for $j1 in 0:grid_size-1
912+
)
913+
else
914+
throw(
915+
"unknown numerical scheme: $scheme (possible choices are :euler, :euler_implicit, :midpoint, :trapeze)",
916+
) # (vs. __throw) since raised at runtime (and __wrap-ped)
917+
end
918+
end
919+
end
920+
return __wrap(code, p.lnum, p.line)
856921
end
857922

858923
function p_dynamics_coord!(
@@ -865,6 +930,8 @@ function p_dynamics_coord!(
865930
isnothing(p.t) && return __throw("time not yet declared", p.lnum, p.line)
866931
x p.x && return __throw("wrong state $x for dynamics", p.lnum, p.line)
867932
t p.t && return __throw("wrong time $t for dynamics", p.lnum, p.line)
933+
p.is_global_dyn && return __throw("dynamics already defined", p.lnum, p.line)
934+
p.is_coord_dyn = true
868935
xut = __symgen(:xut)
869936
ee = replace_call(e, [p.x, p.u], p.t, [xut, xut])
870937
has(ee, p.t) && (p.is_autonomous = false)
@@ -881,7 +948,7 @@ function p_dynamics_coord_fun!(p, p_ocp, x, i, t, e)
881948
args = [r, p.t, xt, ut, p.v]
882949
code = quote
883950
function $fun($(args...))
884-
@views $r[:] .= $e # note that i can be a full range (allowed for :fun in CTModels, not for :exa)
951+
@views $r[:] .= $e # note that i can be a full range for :fun, still todo for :exa
885952
return nothing
886953
end
887954
$pref.dynamics!($p_ocp, $i, $fun)
@@ -890,11 +957,12 @@ function p_dynamics_coord_fun!(p, p_ocp, x, i, t, e)
890957
end
891958

892959
function p_dynamics_coord_exa!(p, p_ocp, x, i, t, e)
960+
return __throw("dynamics coordinate $i should be an integer", p.lnum, p.line)
961+
end
962+
963+
function p_dynamics_coord_exa!(p, p_ocp, x, i::Integer, t, e) # debug: also also add coord = range for :exa
893964
pref = prefix_exa()
894-
i isa Integer ||
895-
return __throw("dynamics coordinate $i should be an integer", p.lnum, p.line)
896-
i p.dyn_coords &&
897-
return __throw("dynamics coordinate $i already defined", p.lnum, p.line)
965+
i p.dyn_coords && return __throw("dynamics coordinate $i already defined", p.lnum, p.line)
898966
append!(p.dyn_coords, i)
899967
xt = __symgen(:xt)
900968
ut = __symgen(:ut)
@@ -920,7 +988,7 @@ function p_dynamics_coord_exa!(p, p_ocp, x, i, t, e)
920988
ej12 = subs(ej12, p.t, :($(p.t0) + $j12 * $(p.dt)))
921989
dxij = :($(p.x)[$i, $j2] - $(p.x)[$i, $j1])
922990
code = quote
923-
if scheme == :euler
991+
$(p.dyn_con)[$i] = if scheme == :euler # dyn_con already defined outside try catch
924992
$pref.constraint($p_ocp, $dxij - $(p.dt) * $ej1 for $j1 in 0:grid_size-1)
925993
elseif scheme (:euler_implicit, :euler_b) # euler_b is deprecated
926994
$pref.constraint($p_ocp, $dxij - $(p.dt) * $ej2 for $j1 in 0:grid_size-1)
@@ -936,10 +1004,7 @@ function p_dynamics_coord_exa!(p, p_ocp, x, i, t, e)
9361004
) # (vs. __throw) since raised at runtime (and __wrap-ped)
9371005
end
9381006
end
939-
dyn_con = Symbol(:dyn_con, p.x) # named constraint to allow retrieval of the dynamics multiplier that approximates the adjoint state
940-
code = __wrap(code, p.lnum, p.line)
941-
code = :($dyn_con[$i] = $code) # affectation must be done outside try ... catch (otherwise declaration known only to try local scope)
942-
return code
1007+
return __wrap(code, p.lnum, p.line)
9431008
end
9441009

9451010
function p_lagrange!(p, p_ocp, e, type; log=false, backend=__default_parsing_backend())
@@ -1305,16 +1370,16 @@ function def_exa(e; log=false)
13051370
p = ParsingInfo()
13061371
code = parse!(p, p_ocp, e; log=log, backend=:exa)
13071372
dyn_check = quote
1308-
!isempty($(p.dyn_coords)) || throw($e_pref.ParsingError("dynamics not defined"))
1309-
sort($(p.dyn_coords)) == 1:($(p.dim_x)) ||
1310-
throw($e_pref.ParsingError("some coordinates of dynamics undefined"))
1373+
!$p.is_global_dyn && !$p.is_coord_dyn && throw($e_pref.ParsingError("dynamics not defined")) # not $(p.xxxx) as these infos are known statically
1374+
if $p.is_coord_dyn # same: also known statically
1375+
sort($(p.dyn_coords)) == 1:($(p.dim_x)) || throw($e_pref.ParsingError("some coordinates of dynamics undefined"))
1376+
end
13111377
end
13121378
default_scheme = QuoteNode(__default_scheme_exa())
13131379
default_grid_size = __default_grid_size_exa()
13141380
default_backend = __default_backend_exa()
13151381
default_init = __default_init_exa()
13161382
default_base_type = __default_base_type_exa()
1317-
dyn_con = Symbol(:dyn_con, p.x)
13181383

13191384
getter = quote
13201385
function (sol; val)
@@ -1327,7 +1392,7 @@ function def_exa(e; log=false)
13271392
elseif val == :costate
13281393
px = zeros(base_type, $(p.dim_x), grid_size)
13291394
for i in 1:($(p.dim_x))
1330-
px[i, :] = Array($pref.multipliers(sol, $dyn_con[i])) # Array to copy from GPU
1395+
px[i, :] = Array($pref.multipliers(sol, $(p.dyn_con)[i])) # Array to copy from GPU
13311396
end
13321397
px
13331398
elseif val == :state_l

0 commit comments

Comments
 (0)