Skip to content

Commit 7e5d4ea

Browse files
jbcaillauclaude
andcommitted
Add bare symbol expansion after all subs2 calls in onepass.jl
Implements the pattern from p_constraint_exa! boundary case across all functions that use subs2, enabling bare symbol handling (e.g., x(t) → [x[k, j] for k ∈ 1:dim_x]). Changes to src/onepass.jl: 1. p_mayer_exa! (lines 1030-1044): +3 lines - Added k = __symgen(:k) - Added subs for x0 and xf bare symbols 2. p_constraint_exa! path constraints (lines 797-815): +3 lines - Added k = __symgen(:k) - Added subs for xt and ut bare symbols 3. p_lagrange_exa! (lines 968-985): +4 lines - Added k = __symgen(:k) - Added subs for xt and ut bare symbols in ej1 and ej12 4. p_dynamics_coord_exa! (lines 900-920): +6 lines - Added k = __symgen(:k) - Added subs for xt and ut bare symbols in ej1, ej2, and ej12 Total: 16 lines added across 4 functions Pattern applied: - subs2 handles indexed cases: x[i] → y[i, j] and x[1:3] → [y[k, j] for k ∈ 1:3] - subs handles bare symbols: x → [y[k, j] for k ∈ 1:dim] - Same symbol k used for both state and control (sequential application) This completes the implementation of tensor support for bare symbols in the ExaModels backend. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
1 parent 429619d commit 7e5d4ea

1 file changed

Lines changed: 19 additions & 0 deletions

File tree

src/onepass.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -695,10 +695,13 @@ function p_constraint_exa!(p, p_ocp, e1, e2, e3, c_type, label)
695695
code = :(length($e1) == length($e3) == 1 || throw("this constraint must be scalar")) # (vs. __throw) since raised at runtime
696696
x0 = __symgen(:x0)
697697
xf = __symgen(:xf)
698+
k = __symgen(:k)
698699
e2 = replace_call(e2, p.x, p.t0, x0)
699700
e2 = replace_call(e2, p.x, p.tf, xf)
700701
e2 = subs2(e2, x0, p.x, 0)
702+
e2 = subs(e2, x0, :([$(p.x)[$k, 0] for $k 1:$(p.dim_x)]))
701703
e2 = subs2(e2, xf, p.x, :grid_size)
704+
e2 = subs(e2, xf, :([$(p.x)[$k, :grid_size] for $k 1:$(p.dim_x)]))
702705
concat(code, :($pref.constraint($p_ocp, $e2; lcon=($e1), ucon=($e3)))) # debug: to vectorise
703706
end
704707
(:initial, rg) => begin
@@ -797,8 +800,11 @@ function p_constraint_exa!(p, p_ocp, e1, e2, e3, c_type, label)
797800
ut = __symgen(:ut)
798801
e2 = replace_call(e2, [p.x, p.u], p.t, [xt, ut])
799802
j = __symgen(:j)
803+
k = __symgen(:k)
800804
e2 = subs2(e2, xt, p.x, j)
805+
e2 = subs(e2, xt, :([$(p.x)[$k, $j] for $k 1:$(p.dim_x)]))
801806
e2 = subs2(e2, ut, p.u, j)
807+
e2 = subs(e2, ut, :([$(p.u)[$k, $j] for $k 1:$(p.dim_u)]))
802808
e2 = subs(e2, p.t, :($(p.t0) + $j * $(p.dt)))
803809
concat(
804810
code,
@@ -897,14 +903,20 @@ function p_dynamics_coord_exa!(p, p_ocp, x, i, t, e)
897903
j1 = __symgen(:j)
898904
j2 = :($j1 + 1)
899905
j12 = :($j1 + 0.5)
906+
k = __symgen(:k)
900907
ej1 = subs2(e, xt, p.x, j1)
908+
ej1 = subs(ej1, xt, :([$(p.x)[$k, $j1] for $k 1:$(p.dim_x)]))
901909
ej1 = subs2(ej1, ut, p.u, j1)
910+
ej1 = subs(ej1, ut, :([$(p.u)[$k, $j1] for $k 1:$(p.dim_u)]))
902911
ej1 = subs(ej1, p.t, :($(p.t0) + $j1 * $(p.dt)))
903912
ej2 = subs2(e, xt, p.x, j2)
913+
ej2 = subs(ej2, xt, :([$(p.x)[$k, $j2] for $k 1:$(p.dim_x)]))
904914
ej2 = subs2(ej2, ut, p.u, j2)
915+
ej2 = subs(ej2, ut, :([$(p.u)[$k, $j2] for $k 1:$(p.dim_u)]))
905916
ej2 = subs(ej2, p.t, :($(p.t0) + $j2 * $(p.dt)))
906917
ej12 = subs5(e, xt, p.x, j1)
907918
ej12 = subs2(ej12, ut, p.u, j1)
919+
ej12 = subs(ej12, ut, :([$(p.u)[$k, $j1] for $k 1:$(p.dim_u)]))
908920
ej12 = subs(ej12, p.t, :($(p.t0) + $j12 * $(p.dt)))
909921
dxij = :($(p.x)[$i, $j2] - $(p.x)[$i, $j1])
910922
code = quote
@@ -967,11 +979,15 @@ function p_lagrange_exa!(p, p_ocp, e, type)
967979
j1 = __symgen(:j)
968980
j2 = :($j1 + 1)
969981
j12 = :($j1 + 0.5)
982+
k = __symgen(:k)
970983
ej1 = subs2(e, xt, p.x, j1)
984+
ej1 = subs(ej1, xt, :([$(p.x)[$k, $j1] for $k 1:$(p.dim_x)]))
971985
ej1 = subs2(ej1, ut, p.u, j1)
986+
ej1 = subs(ej1, ut, :([$(p.u)[$k, $j1] for $k 1:$(p.dim_u)]))
972987
ej1 = subs(ej1, p.t, :($(p.t0) + $j1 * $(p.dt)))
973988
ej12 = subs5(e, xt, p.x, j1)
974989
ej12 = subs2(ej12, ut, p.u, j1)
990+
ej12 = subs(ej12, ut, :([$(p.u)[$k, $j1] for $k 1:$(p.dim_u)]))
975991
ej12 = subs(ej12, p.t, :($(p.t0) + $j12 * $(p.dt)))
976992
code = quote
977993
if scheme == :euler
@@ -1028,10 +1044,13 @@ function p_mayer_exa!(p, p_ocp, e, type)
10281044
pref = prefix_exa()
10291045
x0 = __symgen(:x0)
10301046
xf = __symgen(:xf)
1047+
k = __symgen(:k)
10311048
e = replace_call(e, p.x, p.t0, x0)
10321049
e = replace_call(e, p.x, p.tf, xf)
10331050
e = subs2(e, x0, p.x, 0)
1051+
e = subs(e, x0, :([$(p.x)[$k, 0] for $k 1:$(p.dim_x)]))
10341052
e = subs2(e, xf, p.x, :grid_size)
1053+
e = subs(e, xf, :([$(p.x)[$k, :grid_size] for $k 1:$(p.dim_x)]))
10351054
# now, x[i](t0) has been replaced by x[i, 0] and x[i](tf) by x[i, grid_size]
10361055
code = :($pref.objective($p_ocp, $e))
10371056
return __wrap(code, p.lnum, p.line)

0 commit comments

Comments
 (0)