11# onepass
22# todo:
3+ # - add tests to compare runs of exa with scalar vs vec dynamics
4+ #
35# - as_range / as_vector for rg / lb, ub could be done here in p_constraint when calling constraint!
46# - cannot call solve if problem not fully defined (dynamics not defined...)
57# - doc: explain projections wrt to t0, tf, t; (...x1...x2...)(t) -> ...gensym1...gensym2... (most internal first)
@@ -118,7 +120,10 @@ $(TYPEDEF)
118120 lnum:: Int = 0
119121 line:: String = " "
120122 dt:: Symbol = __symgen (:dt )
123+ is_global_dyn:: Bool = false
124+ is_coord_dyn:: Bool = false
121125 dyn_coords:: Vector{Int64} = Int64[]
126+ dyn_con:: Union{Symbol,Nothing} = nothing
122127 l_v:: Symbol = __symgen (:l_v )
123128 u_v:: Symbol = __symgen (:u_v )
124129 box_v:: Expr = :(LineNumberNode (0 , " box constraints: variable" ))
@@ -543,10 +548,10 @@ function p_state_exa!(p, p_ocp, x, n, xx; components_names=nothing)
543548 start= init[2 ],
544549 ))
545550 code = __wrap (code, p. lnum, p. line)
546- dyn_con = Symbol (:dyn_con , x ) # name for the constraints associated with the dynamics
551+ p . dyn_con = __symgen (:dyn_con ) # name for the constraints associated with the dynamics
547552 code = quote
548553 $ x = $ code
549- $ dyn_con = Vector {$pref.Constraint} (undef, $ n) # affectation must be done outside try ... catch (otherwise declaration known only to try local scope)
554+ $ (p . dyn_con) = Vector {$pref.Constraint} (undef, $ n) # affectation must be done outside try ... catch (otherwise declaration known only to try local scope)
550555 end
551556 return code
552557end
@@ -691,24 +696,33 @@ function p_constraint_exa!(p, p_ocp, e1, e2, e3, c_type, label)
691696 isnothing (e3) && (e3 = :(Inf * ones (length ($ e1))))
692697 code = @match c_type begin
693698 :boundary || :variable_fun => begin
694- code = :(length ($ e1) == length ($ e3) == 1 || throw (" this constraint must be scalar" )) # (vs. __throw) since raised at runtime
695699 x0 = __symgen (:x0 )
696700 xf = __symgen (:xf )
697701 k = __symgen (:k )
702+ l = __symgen (:l )
698703 e2 = replace_call (e2, p. x, p. t0, x0)
699704 e2 = replace_call (e2, p. x, p. tf, xf)
700705 e2 = subs2 (e2, x0, p. x, 0 )
701706 e2 = subs (e2, x0, :([$ (p. x)[$ k, 0 ] for $ k ∈ 1 : $ (p. dim_x)]))
702707 e2 = subs2 (e2, xf, p. x, :grid_size )
703708 e2 = subs (e2, xf, :([$ (p. x)[$ k, grid_size] for $ k ∈ 1 : $ (p. dim_x)]))
704- concat (code, :($ pref. constraint ($ p_ocp, $ e2; lcon= ($ e1[1 ]), ucon= ($ e3[1 ])))) # todo: e1/3[1] will be e1/3[k] when vectorised over dim
709+ quote
710+ length ($ e1) == length ($ e3) || throw (" wrong bound dimension" ) # (vs. __throw) since raised at runtime
711+ if length ($ e1) == 1
712+ $ pref. constraint ($ p_ocp, $ e2; lcon= ($ e1[1 ]), ucon= ($ e3[1 ]))
713+ else
714+ for $ l ∈ 1 : length ($ e1)
715+ $ pref. constraint ($ p_ocp, $ e2[$ l]; lcon= ($ e1[$ l]), ucon= ($ e3[$ l]))
716+ end
717+ end
718+ end
705719 end
706720 (:initial , rg) => begin
707721 if isnothing (rg)
708722 rg = :(1 : ($ (p. dim_x))) # x(t0) implies rg == nothing but means x[1:p.dim_x](t0)
709723 e2 = subs (e2, p. x, :($ (p. x)[$ rg]))
710724 else
711- rg = as_range (rg)
725+ rg = as_range (rg) # case rg = i (vs i:j or i:p:j)
712726 end
713727 code = :(
714728 length ($ e1) == length ($ e3) == length ($ rg) || throw (" wrong bound dimension" )
@@ -794,23 +808,27 @@ function p_constraint_exa!(p, p_ocp, e1, e2, e3, c_type, label)
794808 :()
795809 end
796810 :state_fun || :control_fun || :mixed => begin
797- code = :(length ($ e1) == length ($ e3) == 1 || throw (" this constraint must be scalar" )) # (vs. __throw) since raised at runtime
798811 xt = __symgen (:xt )
799812 ut = __symgen (:ut )
800- e2 = replace_call (e2, [p. x, p. u], p. t, [xt, ut])
801813 j = __symgen (:j )
802814 k = __symgen (:k )
815+ l = __symgen (:l )
816+ e2 = replace_call (e2, [p. x, p. u], p. t, [xt, ut])
803817 e2 = subs2 (e2, xt, p. x, j)
804818 e2 = subs (e2, xt, :([$ (p. x)[$ k, $ j] for $ k ∈ 1 : $ (p. dim_x)]))
805819 e2 = subs2 (e2, ut, p. u, j)
806820 e2 = subs (e2, ut, :([$ (p. u)[$ k, $ j] for $ k ∈ 1 : $ (p. dim_u)]))
807821 e2 = subs (e2, p. t, :($ (p. t0) + $ j * $ (p. dt)))
808- concat (
809- code,
810- :($ pref. constraint (
811- $ p_ocp, $ e2 for $ j in 0 : grid_size; lcon= ($ e1[1 ]), ucon= ($ e3[1 ])
812- )),
813- )
822+ quote
823+ length ($ e1) == length ($ e3) || throw (" wrong bound dimension" ) # (vs. __throw) since raised at runtime
824+ if length ($ e1) == 1
825+ $ pref. constraint ($ p_ocp, $ e2 for $ j in 0 : grid_size; lcon= ($ e1[1 ]), ucon= ($ e3[1 ]))
826+ else
827+ for $ l ∈ 1 : length ($ e1)
828+ $ pref. constraint ($ p_ocp, $ e2[$ l] for $ j in 0 : grid_size; lcon= ($ e1[$ l]), ucon= ($ e3[$ l]))
829+ end
830+ end
831+ end
814832 end
815833 _ => return __throw (" bad constraint declaration" , p. lnum, p. line)
816834 end
@@ -827,6 +845,9 @@ function p_dynamics!(
827845 isnothing (p. t) && return __throw (" time not yet declared" , p. lnum, p. line)
828846 x ≠ p. x && return __throw (" wrong state $x for dynamics" , p. lnum, p. line)
829847 t ≠ p. t && return __throw (" wrong time $t for dynamics" , p. lnum, p. line)
848+ p. is_global_dyn && return __throw (" dynamics already defined" , p. lnum, p. line)
849+ p. is_coord_dyn && return __throw (" dynamics already partially defined" , p. lnum, p. line)
850+ p. is_global_dyn = true
830851 xut = __symgen (:xut )
831852 ee = replace_call (e, [p. x, p. u], p. t, [xut, xut])
832853 has (ee, p. t) && (p. is_autonomous = false )
@@ -852,7 +873,51 @@ function p_dynamics_fun!(p, p_ocp, x, t, e)
852873end
853874
854875function p_dynamics_exa! (p, p_ocp, x, t, e)
855- return __throw (" dynamics must be defined coordinatewise" , p. lnum, p. line) # note: scalar case is redirected before coordinatewise case
876+ pref = prefix_exa ()
877+ xt = __symgen (:xt )
878+ ut = __symgen (:ut )
879+ e = replace_call (e, [p. x, p. u], p. t, [xt, ut])
880+ j1 = __symgen (:j )
881+ j2 = :($ j1 + 1 )
882+ j12 = :($ j1 + 0.5 )
883+ k = __symgen (:k )
884+ ej1 = subs2 (e, xt, p. x, j1)
885+ ej1 = subs (ej1, xt, :([$ (p. x)[$ k, $ j1] for $ k ∈ 1 : $ (p. dim_x)]))
886+ ej1 = subs2 (ej1, ut, p. u, j1)
887+ ej1 = subs (ej1, ut, :([$ (p. u)[$ k, $ j1] for $ k ∈ 1 : $ (p. dim_u)]))
888+ ej1 = subs (ej1, p. t, :($ (p. t0) + $ j1 * $ (p. dt)))
889+ ej2 = subs2 (e, xt, p. x, j2)
890+ ej2 = subs (ej2, xt, :([$ (p. x)[$ k, $ j2] for $ k ∈ 1 : $ (p. dim_x)]))
891+ ej2 = subs2 (ej2, ut, p. u, j2)
892+ ej2 = subs (ej2, ut, :([$ (p. u)[$ k, $ j2] for $ k ∈ 1 : $ (p. dim_u)]))
893+ ej2 = subs (ej2, p. t, :($ (p. t0) + $ j2 * $ (p. dt)))
894+ ej12 = subs2m (e, xt, p. x, j1)
895+ ej12 = subs (ej12, xt, :([(($ (p. x)[$ k, $ j1] + $ (p. x)[$ k, $ j1 + 1 ]) / 2 ) for $ k ∈ 1 : $ (p. dim_x)]))
896+ ej12 = subs2 (ej12, ut, p. u, j1)
897+ ej12 = subs (ej12, ut, :([$ (p. u)[$ k, $ j1] for $ k ∈ 1 : $ (p. dim_u)]))
898+ ej12 = subs (ej12, p. t, :($ (p. t0) + $ j12 * $ (p. dt)))
899+ dxj = :([$ (p. x)[$ k, $ j2] - $ (p. x)[$ k, $ j1] for $ k ∈ 1 : $ (p. dim_x)])
900+ i = __symgen (:i )
901+ code = quote
902+ for $ i ∈ 1 : $ (p. dim_x)
903+ $ (p. dyn_con)[$ i] = if scheme == :euler # dyn_con already defined outside try catch
904+ $ pref. constraint ($ p_ocp, $ dxj[$ i] - $ (p. dt) * $ ej1[$ i] for $ j1 in 0 : grid_size- 1 )
905+ elseif scheme ∈ (:euler_implicit , :euler_b ) # euler_b is deprecated
906+ $ pref. constraint ($ p_ocp, $ dxj[$ i] - $ (p. dt) * $ ej2[$ i] for $ j1 in 0 : grid_size- 1 )
907+ elseif scheme == :midpoint
908+ $ pref. constraint ($ p_ocp, $ dxj[$ i] - $ (p. dt) * $ ej12[$ i] for $ j1 in 0 : grid_size- 1 )
909+ elseif scheme ∈ (:trapeze , :trapezoidal ) # trapezoidal is deprecated
910+ $ pref. constraint (
911+ $ p_ocp, $ dxj[$ i] - $ (p. dt) * ($ ej1[$ i] + $ ej2[$ i]) / 2 for $ j1 in 0 : grid_size- 1
912+ )
913+ else
914+ throw (
915+ " unknown numerical scheme: $scheme (possible choices are :euler, :euler_implicit, :midpoint, :trapeze)" ,
916+ ) # (vs. __throw) since raised at runtime (and __wrap-ped)
917+ end
918+ end
919+ end
920+ return __wrap (code, p. lnum, p. line)
856921end
857922
858923function p_dynamics_coord! (
@@ -865,6 +930,8 @@ function p_dynamics_coord!(
865930 isnothing (p. t) && return __throw (" time not yet declared" , p. lnum, p. line)
866931 x ≠ p. x && return __throw (" wrong state $x for dynamics" , p. lnum, p. line)
867932 t ≠ p. t && return __throw (" wrong time $t for dynamics" , p. lnum, p. line)
933+ p. is_global_dyn && return __throw (" dynamics already defined" , p. lnum, p. line)
934+ p. is_coord_dyn = true
868935 xut = __symgen (:xut )
869936 ee = replace_call (e, [p. x, p. u], p. t, [xut, xut])
870937 has (ee, p. t) && (p. is_autonomous = false )
@@ -881,7 +948,7 @@ function p_dynamics_coord_fun!(p, p_ocp, x, i, t, e)
881948 args = [r, p. t, xt, ut, p. v]
882949 code = quote
883950 function $fun ($ (args... ))
884- @views $ r[:] .= $ e # note that i can be a full range (allowed for :fun in CTModels, not for :exa)
951+ @views $ r[:] .= $ e # note that i can be a full range for :fun, still todo for :exa
885952 return nothing
886953 end
887954 $ pref. dynamics! ($ p_ocp, $ i, $ fun)
@@ -890,11 +957,12 @@ function p_dynamics_coord_fun!(p, p_ocp, x, i, t, e)
890957end
891958
892959function p_dynamics_coord_exa! (p, p_ocp, x, i, t, e)
960+ return __throw (" dynamics coordinate $i should be an integer" , p. lnum, p. line)
961+ end
962+
963+ function p_dynamics_coord_exa! (p, p_ocp, x, i:: Integer , t, e) # debug: also also add coord = range for :exa
893964 pref = prefix_exa ()
894- i isa Integer ||
895- return __throw (" dynamics coordinate $i should be an integer" , p. lnum, p. line)
896- i ∈ p. dyn_coords &&
897- return __throw (" dynamics coordinate $i already defined" , p. lnum, p. line)
965+ i ∈ p. dyn_coords && return __throw (" dynamics coordinate $i already defined" , p. lnum, p. line)
898966 append! (p. dyn_coords, i)
899967 xt = __symgen (:xt )
900968 ut = __symgen (:ut )
@@ -920,7 +988,7 @@ function p_dynamics_coord_exa!(p, p_ocp, x, i, t, e)
920988 ej12 = subs (ej12, p. t, :($ (p. t0) + $ j12 * $ (p. dt)))
921989 dxij = :($ (p. x)[$ i, $ j2] - $ (p. x)[$ i, $ j1])
922990 code = quote
923- if scheme == :euler
991+ $ (p . dyn_con)[ $ i] = if scheme == :euler # dyn_con already defined outside try catch
924992 $ pref. constraint ($ p_ocp, $ dxij - $ (p. dt) * $ ej1 for $ j1 in 0 : grid_size- 1 )
925993 elseif scheme ∈ (:euler_implicit , :euler_b ) # euler_b is deprecated
926994 $ pref. constraint ($ p_ocp, $ dxij - $ (p. dt) * $ ej2 for $ j1 in 0 : grid_size- 1 )
@@ -936,10 +1004,7 @@ function p_dynamics_coord_exa!(p, p_ocp, x, i, t, e)
9361004 ) # (vs. __throw) since raised at runtime (and __wrap-ped)
9371005 end
9381006 end
939- dyn_con = Symbol (:dyn_con , p. x) # named constraint to allow retrieval of the dynamics multiplier that approximates the adjoint state
940- code = __wrap (code, p. lnum, p. line)
941- code = :($ dyn_con[$ i] = $ code) # affectation must be done outside try ... catch (otherwise declaration known only to try local scope)
942- return code
1007+ return __wrap (code, p. lnum, p. line)
9431008end
9441009
9451010function p_lagrange! (p, p_ocp, e, type; log= false , backend= __default_parsing_backend ())
@@ -1305,16 +1370,16 @@ function def_exa(e; log=false)
13051370 p = ParsingInfo ()
13061371 code = parse! (p, p_ocp, e; log= log, backend= :exa )
13071372 dyn_check = quote
1308- ! isempty ($ (p. dyn_coords)) || throw ($ e_pref. ParsingError (" dynamics not defined" ))
1309- sort ($ (p. dyn_coords)) == 1 : ($ (p. dim_x)) ||
1310- throw ($ e_pref. ParsingError (" some coordinates of dynamics undefined" ))
1373+ ! $ p. is_global_dyn && ! $ p. is_coord_dyn && throw ($ e_pref. ParsingError (" dynamics not defined" )) # not $(p.xxxx) as these infos are known statically
1374+ if $ p. is_coord_dyn # same: also known statically
1375+ sort ($ (p. dyn_coords)) == 1 : ($ (p. dim_x)) || throw ($ e_pref. ParsingError (" some coordinates of dynamics undefined" ))
1376+ end
13111377 end
13121378 default_scheme = QuoteNode (__default_scheme_exa ())
13131379 default_grid_size = __default_grid_size_exa ()
13141380 default_backend = __default_backend_exa ()
13151381 default_init = __default_init_exa ()
13161382 default_base_type = __default_base_type_exa ()
1317- dyn_con = Symbol (:dyn_con , p. x)
13181383
13191384 getter = quote
13201385 function (sol; val)
@@ -1327,7 +1392,7 @@ function def_exa(e; log=false)
13271392 elseif val == :costate
13281393 px = zeros (base_type, $ (p. dim_x), grid_size)
13291394 for i in 1 : ($ (p. dim_x))
1330- px[i, :] = Array ($ pref. multipliers (sol, $ dyn_con[i])) # Array to copy from GPU
1395+ px[i, :] = Array ($ pref. multipliers (sol, $ (p . dyn_con) [i])) # Array to copy from GPU
13311396 end
13321397 px
13331398 elseif val == :state_l
0 commit comments