Skip to content

Commit defcb16

Browse files
refactor: pass u_arg to build_function_wrapper where appropriate
1 parent d59d0fd commit defcb16

5 files changed

Lines changed: 28 additions & 22 deletions

File tree

lib/ModelingToolkitBase/src/inputoutput.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,7 @@ function generate_control_function(
313313
args = (ddvs, args...)
314314
end
315315
f = build_function_wrapper(
316-
sys, rhss, args...; p_start = 3 + implicit_dae,
316+
sys, rhss, args...; u_arg = 1 + Int(implicit_dae), p_start = 3 + implicit_dae,
317317
p_end = length(p) + 2 + implicit_dae, kwargs...
318318
)
319319
f = eval_or_rgf.(f; eval_expression, eval_module)

lib/ModelingToolkitBase/src/systems/abstractsystem.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ function generate_custom_function(
4040
dvs,
4141
p...,
4242
get_iv(sys);
43+
u_arg = 1,
4344
kwargs...,
4445
expression = Val{true}
4546
)
@@ -48,6 +49,7 @@ function generate_custom_function(
4849
sys, exprs,
4950
dvs,
5051
p...;
52+
u_arg = 1,
5153
kwargs...,
5254
expression = Val{true}
5355
)

lib/ModelingToolkitBase/src/systems/callbacks.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1002,7 +1002,7 @@ Base.@nospecializeinfer function compile_condition(
10021002
end
10031003

10041004
fs = build_function_wrapper(
1005-
sys, condit, u, p..., t; kwargs..., cse = false
1005+
sys, condit, u, p..., t; u_arg = 1, kwargs..., cse = false
10061006
)
10071007
fs = GeneratedFunctionWrapper{(2, 3, is_split(sys))}(
10081008
Val{false}, fs...; eval_expression, eval_module
@@ -1451,13 +1451,13 @@ Base.@nospecializeinfer function compile_explicit_affect(
14511451
u_up,
14521452
u_up! = build_function_wrapper(
14531453
sys, (@view rhss[is_u]), dvs, _ps..., t;
1454-
wrap_code = add_integrator_header(sys, integ, :u),
1454+
u_arg = 1, wrap_code = add_integrator_header(sys, integ, :u),
14551455
outputidxs = u_idxs, wrap_mtkparameters, iip_config = (false, true)
14561456
)
14571457
p_up,
14581458
p_up! = build_function_wrapper(
14591459
sys, (@view rhss[is_p]), dvs, _ps..., t;
1460-
wrap_code = add_integrator_header(sys, integ, :p),
1460+
u_arg = 1, wrap_code = add_integrator_header(sys, integ, :p),
14611461
outputidxs = p_idxs, wrap_mtkparameters, iip_config = (false, true)
14621462
)
14631463

lib/ModelingToolkitBase/src/systems/codegen.jl

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -108,8 +108,9 @@ function generate_rhs(
108108
p_start += 1
109109
end
110110

111+
u_arg = scalar ? -1 : (implicit_dae ? 2 : 1)
111112
res = build_function_wrapper(
112-
sys, rhss, args...; p_start, extra_assignments,
113+
sys, rhss, args...; p_start, extra_assignments, u_arg,
113114
expression = Val{true}, expression_module = eval_module, kwargs...
114115
)
115116
nargs = length(args) - length(p) + 1
@@ -147,7 +148,7 @@ function generate_diffusion_function(
147148
eqs = vec(eqs)
148149
end
149150
p = reorder_parameters(sys, ps)
150-
res = build_function_wrapper(sys, eqs, dvs, p..., get_iv(sys); kwargs...)
151+
res = build_function_wrapper(sys, eqs, dvs, p..., get_iv(sys); u_arg = 1, kwargs...)
151152
if expression == Val{true}
152153
return res
153154
end
@@ -262,7 +263,7 @@ function generate_jacobian(
262263
nargs = 3
263264
end
264265
res = build_function_wrapper(
265-
sys, jac, args...; wrap_code, expression = Val{true},
266+
sys, jac, args...; wrap_code, u_arg = 1, expression = Val{true},
266267
expression_module = eval_module, checkbounds, kwargs...
267268
)
268269
return maybe_compile_function(
@@ -309,6 +310,7 @@ function generate_tgrad(
309310
dvs,
310311
p...,
311312
get_iv(sys);
313+
u_arg = 1,
312314
expression = Val{true},
313315
expression_module = eval_module,
314316
kwargs...
@@ -392,7 +394,7 @@ function generate_W(
392394
p = reorder_parameters(sys, ps)
393395
res = build_function_wrapper(
394396
sys, W, dvs, p..., W_GAMMA, t; wrap_code,
395-
p_end = 1 + length(p), checkbounds, kwargs...
397+
u_arg = 1, p_end = 1 + length(p), checkbounds, kwargs...
396398
)
397399
return maybe_compile_function(
398400
expression, wrap_gfw, (2, 4, is_split(sys)), res; eval_expression, eval_module
@@ -432,7 +434,7 @@ function generate_dae_jacobian(
432434
p = reorder_parameters(sys, ps)
433435
res = build_function_wrapper(
434436
sys, jac, derivatives, dvs, p..., W_GAMMA, t;
435-
p_start = 3, p_end = 2 + length(p), kwargs...
437+
u_arg = 2, p_start = 3, p_end = 2 + length(p), kwargs...
436438
)
437439
return maybe_compile_function(
438440
expression, wrap_gfw, (3, 5, is_split(sys)), res; eval_expression, eval_module
@@ -694,8 +696,9 @@ function generate_cost(
694696
args = (dvs, ps...)
695697
nargs = 2
696698
end
699+
u_arg = is_time_dependent(sys) ? -1 : 1
697700
res = build_function_wrapper(
698-
sys, obj, args...; expression = Val{true}, p_start, p_end, wrap_delays,
701+
sys, obj, args...; expression = Val{true}, p_start, p_end, wrap_delays, u_arg,
699702
histfn = (p, t) -> BVP_SOLUTION(t), histfn_symbolic = BVP_SOLUTION, kwargs...
700703
)[1]
701704
if expression == Val{true}
@@ -788,7 +791,7 @@ function generate_cost_gradient(
788791
dvs = unknowns(sys)
789792
ps = reorder_parameters(sys)
790793
exprs = calculate_cost_gradient(sys; simplify)
791-
res = build_function_wrapper(sys, exprs, dvs, ps...; expression = Val{true}, kwargs...)
794+
res = build_function_wrapper(sys, exprs, dvs, ps...; u_arg = 1, expression = Val{true}, kwargs...)
792795
return maybe_compile_function(
793796
expression, wrap_gfw, (2, 2, is_split(sys)), res; eval_expression, eval_module
794797
)
@@ -847,7 +850,7 @@ function generate_cost_hessian(
847850
if sparse
848851
sparsity = similar(exprs, Float64)
849852
end
850-
res = build_function_wrapper(sys, exprs, dvs, ps...; expression = Val{true}, kwargs...)
853+
res = build_function_wrapper(sys, exprs, dvs, ps...; u_arg = 1, expression = Val{true}, kwargs...)
851854
fn = maybe_compile_function(
852855
expression, wrap_gfw, (2, 2, is_split(sys)), res; eval_expression, eval_module
853856
)
@@ -879,7 +882,7 @@ function generate_cons(
879882
cons = canonical_constraints(sys)
880883
dvs = unknowns(sys)
881884
ps = reorder_parameters(sys)
882-
res = build_function_wrapper(sys, cons, dvs, ps...; expression = Val{true}, kwargs...)
885+
res = build_function_wrapper(sys, cons, dvs, ps...; u_arg = 1, expression = Val{true}, kwargs...)
883886
return maybe_compile_function(
884887
expression, wrap_gfw, (2, 2, is_split(sys)), res; eval_expression, eval_module
885888
)
@@ -936,7 +939,7 @@ function generate_constraint_jacobian(
936939
sparsity = calculate_constraint_jacobian(
937940
sys; simplify, sparse, return_sparsity = true
938941
)
939-
res = build_function_wrapper(sys, jac, dvs, ps...; expression = Val{true}, kwargs...)
942+
res = build_function_wrapper(sys, jac, dvs, ps...; u_arg = 1, expression = Val{true}, kwargs...)
940943
fn = maybe_compile_function(
941944
expression, wrap_gfw, (2, 2, is_split(sys)), res; eval_expression, eval_module
942945
)
@@ -995,7 +998,7 @@ function generate_constraint_hessian(
995998
sparsity = calculate_constraint_hessian(
996999
sys; simplify, sparse, return_sparsity = true
9971000
)
998-
res = build_function_wrapper(sys, hess, dvs, ps...; expression = Val{true}, kwargs...)
1001+
res = build_function_wrapper(sys, hess, dvs, ps...; u_arg = 1, expression = Val{true}, kwargs...)
9991002
fn = maybe_compile_function(
10001003
expression, wrap_gfw, (2, 2, is_split(sys)), res; eval_expression, eval_module
10011004
)
@@ -1048,7 +1051,7 @@ function generate_control_jacobian(
10481051
ps = parameters(sys; initial_parameters = true)
10491052
jac = calculate_control_jacobian(sys; simplify = simplify, sparse = sparse)
10501053
p = reorder_parameters(sys, ps)
1051-
res = build_function_wrapper(sys, jac, dvs, p..., get_iv(sys); kwargs...)
1054+
res = build_function_wrapper(sys, jac, dvs, p..., get_iv(sys); u_arg = 1, kwargs...)
10521055
return maybe_compile_function(
10531056
expression, wrap_gfw, (2, 3, is_split(sys)), res; eval_expression, eval_module
10541057
)
@@ -1059,6 +1062,7 @@ function generate_rate_function(js::System, rate)
10591062
return build_function_wrapper(
10601063
js, rate, unknowns(js), p...,
10611064
get_iv(js),
1065+
u_arg = 1,
10621066
expression = Val{true},
10631067
iip_config = (true, false),
10641068
)[1]

src/systems/codegen.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -334,18 +334,18 @@ function generate_semiquadratic_functions(
334334
end
335335

336336
f1_iip = build_function_wrapper(
337-
sys, nothing, Symbolics.DEFAULT_OUTSYM, dvs, ps..., iv; p_start = 3,
337+
sys, nothing, Symbolics.DEFAULT_OUTSYM, dvs, ps..., iv; u_arg = 2, p_start = 3,
338338
extra_assignments = f1_iip_ir, expression = Val{true}, iip_config = (true, false), kwargs...
339339
)[1]
340340
f2_iip = build_function_wrapper(
341-
sys, nothing, Symbolics.DEFAULT_OUTSYM, dvs, ps..., iv; p_start = 3,
341+
sys, nothing, Symbolics.DEFAULT_OUTSYM, dvs, ps..., iv; u_arg = 2, p_start = 3,
342342
extra_assignments = f2_iip_ir, expression = Val{true}, iip_config = (true, false), kwargs...
343343
)[1]
344344
f1_oop = build_function_wrapper(
345-
sys, f1_expr, dvs, ps..., iv; expression = Val{true}, iip_config = (true, false), kwargs...
345+
sys, f1_expr, dvs, ps..., iv; u_arg = 1, expression = Val{true}, iip_config = (true, false), kwargs...
346346
)[1]
347347
f2_oop = build_function_wrapper(
348-
sys, f2_expr, dvs, ps..., iv; expression = Val{true}, iip_config = (true, false), kwargs...
348+
sys, f2_expr, dvs, ps..., iv; u_arg = 1, expression = Val{true}, iip_config = (true, false), kwargs...
349349
)[1]
350350

351351
f1 = maybe_compile_function(
@@ -511,11 +511,11 @@ function generate_semiquadratic_jacobian(
511511
oop_expr = length(terms) == 1 ? only(terms) : term(+, terms...)
512512

513513
j_iip, _ = build_function_wrapper(
514-
sys, nothing, Symbolics.DEFAULT_OUTSYM, dvs, ps..., iv; p_start = 3,
514+
sys, nothing, Symbolics.DEFAULT_OUTSYM, dvs, ps..., iv; u_arg = 2, p_start = 3,
515515
extra_assignments = iip_ir, expression = Val{true}, iip_config = (true, false), kwargs...
516516
)
517517
j_oop, _ = build_function_wrapper(
518-
sys, oop_expr, dvs, ps..., iv; expression = Val{true}, iip_config = (true, false), kwargs...
518+
sys, oop_expr, dvs, ps..., iv; u_arg = 1, expression = Val{true}, iip_config = (true, false), kwargs...
519519
)
520520
return maybe_compile_function(
521521
expression, wrap_gfw, (2, 3, is_split(sys)),

0 commit comments

Comments
 (0)