Skip to content

Commit ff9d3fd

Browse files
feat: add u_arg kwarg to build_function_wrapper
This allows marking the `u` argument of a generated function so it is specially named. Co-Authored-By: Claude <noreply@anthropic.com>
1 parent a84e002 commit ff9d3fd

3 files changed

Lines changed: 23 additions & 4 deletions

File tree

lib/ModelingToolkitBase/src/ModelingToolkitBase.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -386,6 +386,7 @@ function __init__()
386386
SU.hashcons(unwrap(ASSERTION_LOG_VARIABLE), true)
387387
SU.hashcons(DDE_AT_IDX_SYM, true)
388388
SU.hashcons(DDE_DELAY_SYM, true)
389+
SU.hashcons(MTKUNKNOWNS_ARG, true)
389390
return nothing
390391
end
391392

lib/ModelingToolkitBase/src/systems/abstractsystem.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ function wrap_assignments(isscalar, assignments; let_block = false)
7474
end
7575

7676
const MTKPARAMETERS_ARG = SSym(:___mtkparameters___; type = Vector{Vector{Any}}, shape = SymbolicUtils.Unknown(1))
77+
const MTKUNKNOWNS_ARG = SSym(:___mtkunknowns___; type = Vector{Real}, shape = SymbolicUtils.Unknown(1))
7778

7879
"""
7980
$(TYPEDSIGNATURES)

lib/ModelingToolkitBase/src/systems/codegen_utils.jl

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,10 @@ generated functions, and `args` are the arguments.
256256
`MTKParameters` object are present. These are collapsed into a single argument and
257257
destructured inside the function. `p_start` must also be provided for non-split systems
258258
since it is used by `wrap_delays`.
259+
- `u_arg`: The index in `args` of the argument corresponding to `unknowns(sys)` (the `u`
260+
vector). If `-1` (the default), the u vector is not treated specially. Otherwise, the
261+
argument must be a `Vector` and is wrapped in a `DestructuredArgs` with the common
262+
identifier `MTKUNKNOWNS_ARG`, giving it the predictable name `___mtkunknowns___`.
259263
- `compress_args`: A list of argument ranges that end before `p_start`.
260264
Each range will be compressed into a single argument to the function. For example,
261265
If there are 5 elements in `args` and `compress_args = [2:3]`, then the generated function
@@ -296,7 +300,7 @@ All other keyword arguments are forwarded to `build_function`.
296300
Base.@nospecializeinfer function build_function_wrapper(
297301
sys::AbstractSystem, @nospecialize(expr), @nospecialize(args...); p_start = 2,
298302
p_end = is_time_dependent(sys) ? length(args) - 1 : length(args), compress_args = UnitRange{Int}[],
299-
non_standard_param_layout = false,
303+
non_standard_param_layout = false, u_arg::Integer = -1,
300304
wrap_delays = is_dde(sys), histfn = DDE_HISTORY_FUN, histfn_symbolic = histfn, wrap_code = identity,
301305
add_observed = true, obsidxs_to_use = nothing,
302306
create_bindings = false, output_type = nothing, mkarray = nothing,
@@ -307,6 +311,17 @@ Base.@nospecializeinfer function build_function_wrapper(
307311
obs = observed(sys)
308312
args = Vector{Any}(collect(args))
309313
assignments = Assignment[]
314+
315+
if u_arg != -1
316+
args[u_arg] isa AbstractVector ||
317+
throw(ArgumentError("argument at u_arg = $u_arg must be a Vector, got $(typeof(args[u_arg]))"))
318+
end
319+
320+
u_argument_name = if u_arg == -1
321+
generated_argument_name
322+
else
323+
i -> i == u_arg ? :___mtkunknowns___ : generated_argument_name(i)
324+
end
310325
# turn delayed unknowns into calls to the history function
311326
if wrap_delays
312327
param_arg = is_split(sys) ? MTKPARAMETERS_ARG : generated_argument_name(p_start)
@@ -343,7 +358,7 @@ Base.@nospecializeinfer function build_function_wrapper(
343358

344359
# assignments for reconstructing scalarized array symbolics
345360
if non_standard_param_layout
346-
append!(assignments, array_variable_assignments(args...; filter_vars = required_arrvars))
361+
append!(assignments, array_variable_assignments(args...; filter_vars = required_arrvars, argument_name = u_argument_name))
347362
else
348363
cached = check_mutable_cache(sys, ParameterArrayAssignments, ParameterArrayAssignments, nothing)
349364
if cached isa ParameterArrayAssignments
@@ -357,14 +372,16 @@ Base.@nospecializeinfer function build_function_wrapper(
357372
param_var_to_arridxs; buffer_offset = p_start - 1, filter_vars = required_arrvars
358373
)
359374
)
360-
other_assigns = array_variable_assignments(args...; ignore_vars = keys(param_var_to_arridxs), filter_vars = required_arrvars)
375+
other_assigns = array_variable_assignments(args...; ignore_vars = keys(param_var_to_arridxs), filter_vars = required_arrvars, argument_name = u_argument_name)
361376
append!(assignments, other_assigns)
362377
end
363378
append!(assignments, extra_assignments)
364379

365380
for (i, arg) in enumerate(args)
366381
# Make sure to use the proper names for arguments
367-
args[i] = if symbolic_type(arg) == NotSymbolic() && arg isa AbstractArray
382+
args[i] = if u_arg != -1 && i == u_arg
383+
DestructuredArgs(arg, MTKUNKNOWNS_ARG; create_bindings)
384+
elseif symbolic_type(arg) == NotSymbolic() && arg isa AbstractArray
368385
DestructuredArgs(arg, generated_argument_name(i); create_bindings)
369386
else
370387
arg

0 commit comments

Comments
 (0)