@@ -256,6 +256,10 @@ generated functions, and `args` are the arguments.
256256 `MTKParameters` object are present. These are collapsed into a single argument and
257257 destructured inside the function. `p_start` must also be provided for non-split systems
258258 since it is used by `wrap_delays`.
259+ - `u_arg`: The index in `args` of the argument corresponding to `unknowns(sys)` (the `u`
260+ vector). If `-1` (the default), the u vector is not treated specially. Otherwise, the
261+ argument must be a `Vector` and is wrapped in a `DestructuredArgs` with the common
262+ identifier `MTKUNKNOWNS_ARG`, giving it the predictable name `___mtkunknowns___`.
259263- `compress_args`: A list of argument ranges that end before `p_start`.
260264 Each range will be compressed into a single argument to the function. For example,
261265 If there are 5 elements in `args` and `compress_args = [2:3]`, then the generated function
@@ -296,7 +300,7 @@ All other keyword arguments are forwarded to `build_function`.
296300Base. @nospecializeinfer function build_function_wrapper (
297301 sys:: AbstractSystem , @nospecialize (expr), @nospecialize (args... ); p_start = 2 ,
298302 p_end = is_time_dependent (sys) ? length (args) - 1 : length (args), compress_args = UnitRange{Int}[],
299- non_standard_param_layout = false ,
303+ non_standard_param_layout = false , u_arg :: Integer = - 1 ,
300304 wrap_delays = is_dde (sys), histfn = DDE_HISTORY_FUN, histfn_symbolic = histfn, wrap_code = identity,
301305 add_observed = true , obsidxs_to_use = nothing ,
302306 create_bindings = false , output_type = nothing , mkarray = nothing ,
@@ -307,6 +311,17 @@ Base.@nospecializeinfer function build_function_wrapper(
307311 obs = observed (sys)
308312 args = Vector {Any} (collect (args))
309313 assignments = Assignment[]
314+
315+ if u_arg != - 1
316+ args[u_arg] isa AbstractVector ||
317+ throw (ArgumentError (" argument at u_arg = $u_arg must be a Vector, got $(typeof (args[u_arg])) " ))
318+ end
319+
320+ u_argument_name = if u_arg == - 1
321+ generated_argument_name
322+ else
323+ i -> i == u_arg ? :___mtkunknowns___ : generated_argument_name (i)
324+ end
310325 # turn delayed unknowns into calls to the history function
311326 if wrap_delays
312327 param_arg = is_split (sys) ? MTKPARAMETERS_ARG : generated_argument_name (p_start)
@@ -343,7 +358,7 @@ Base.@nospecializeinfer function build_function_wrapper(
343358
344359 # assignments for reconstructing scalarized array symbolics
345360 if non_standard_param_layout
346- append! (assignments, array_variable_assignments (args... ; filter_vars = required_arrvars))
361+ append! (assignments, array_variable_assignments (args... ; filter_vars = required_arrvars, argument_name = u_argument_name ))
347362 else
348363 cached = check_mutable_cache (sys, ParameterArrayAssignments, ParameterArrayAssignments, nothing )
349364 if cached isa ParameterArrayAssignments
@@ -357,14 +372,16 @@ Base.@nospecializeinfer function build_function_wrapper(
357372 param_var_to_arridxs; buffer_offset = p_start - 1 , filter_vars = required_arrvars
358373 )
359374 )
360- other_assigns = array_variable_assignments (args... ; ignore_vars = keys (param_var_to_arridxs), filter_vars = required_arrvars)
375+ other_assigns = array_variable_assignments (args... ; ignore_vars = keys (param_var_to_arridxs), filter_vars = required_arrvars, argument_name = u_argument_name )
361376 append! (assignments, other_assigns)
362377 end
363378 append! (assignments, extra_assignments)
364379
365380 for (i, arg) in enumerate (args)
366381 # Make sure to use the proper names for arguments
367- args[i] = if symbolic_type (arg) == NotSymbolic () && arg isa AbstractArray
382+ args[i] = if u_arg != - 1 && i == u_arg
383+ DestructuredArgs (arg, MTKUNKNOWNS_ARG; create_bindings)
384+ elseif symbolic_type (arg) == NotSymbolic () && arg isa AbstractArray
368385 DestructuredArgs (arg, generated_argument_name (i); create_bindings)
369386 else
370387 arg
0 commit comments