Skip to content

Commit 7691a29

Browse files
Merge pull request #4508 from SciML/as/rm-extra-reshape-view
refactor: remove unnecessary `reshape(view(...))` in codegen
2 parents 963fb0b + 1ac0700 commit 7691a29

2 files changed

Lines changed: 29 additions & 7 deletions

File tree

lib/ModelingToolkitBase/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ SpecialFunctions = "1, 2"
168168
StaticArrays = "1.9.14"
169169
StochasticDiffEq = "6.82.0, 7"
170170
SymbolicIndexingInterface = "0.3.39"
171-
SymbolicUtils = "4.23.1"
171+
SymbolicUtils = "4.27"
172172
Symbolics = "7.18.1"
173173
UnPack = "0.1, 1.0"
174174
julia = "1.9"

lib/ModelingToolkitBase/src/systems/codegen_utils.jl

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,14 @@ end
5858
return var_to_arridxs
5959
end
6060

61-
function array_variable_buffer_idxs_to_assignments(var_to_arridxs::Dict{SymbolicT, Vector{Tuple{Int, Int}}}; argument_name = generated_argument_name, buffer_offset = 0)
61+
function array_variable_buffer_idxs_to_assignments(
62+
var_to_arridxs::Dict{SymbolicT, Vector{Tuple{Int, Int}}};
63+
argument_name = generated_argument_name, buffer_offset = 0,
64+
filter_vars = nothing,
65+
)
6266
assignments = Assignment[]
6367
for (arrvar, idxs) in var_to_arridxs
68+
filter_vars === nothing || arrvar in filter_vars || continue
6469
# all elements of the array need to be present in `args` to form the
6570
# reconstructing assignment
6671
any(iszero first, idxs) && continue
@@ -126,9 +131,12 @@ reconstruct array variables if they are present scalarized in `args`.
126131
an argument to the generated function and returns the name of the argument in the
127132
generated function.
128133
"""
129-
function array_variable_assignments(args...; ignore_vars = Set{SymbolicT}(), argument_name = generated_argument_name, buffer_offset = 0)
134+
function array_variable_assignments(
135+
args...; ignore_vars = Set{SymbolicT}(), filter_vars = nothing,
136+
argument_name = generated_argument_name, buffer_offset = 0
137+
)
130138
var_to_arridxs = compute_array_variable_buffer_idxs(args; ignore_vars)
131-
return array_variable_buffer_idxs_to_assignments(var_to_arridxs; argument_name, buffer_offset)
139+
return array_variable_buffer_idxs_to_assignments(var_to_arridxs; argument_name, buffer_offset, filter_vars)
132140
end
133141

134142
"""
@@ -231,6 +239,10 @@ function should_invalidate_mutable_cache_entry(::Type{ParameterArrayAssignments}
231239
return haskey(patch, :ps)
232240
end
233241

242+
function find_arrvars_is_atomic(ex::SymbolicT)
243+
SU.default_is_atomic(ex) && Symbolics.isarraysymbolic(ex)
244+
end
245+
234246
"""
235247
$(TYPEDSIGNATURES)
236248
@@ -319,9 +331,19 @@ Base.@nospecializeinfer function build_function_wrapper(
319331
ir_info = get_ir_info(sys)
320332
expr = ir_info.obs_subber(expr)
321333

334+
ir = get_irstructure(sys)
335+
required_arrvars = Set{SymbolicT}()
336+
search_buffer = SU.IRStructureSearchBuffer(ir, required_arrvars)
337+
SU.search_variables!(search_buffer, expr; is_atomic = find_arrvars_is_atomic, recurse = !SU.default_is_atomic)
338+
# TODO: This is only required because `CacheWriter` has its body in `extra_assignments`. Rewrite
339+
# that to use `ArrayMaker` and remove this.
340+
for assign in extra_assignments
341+
SU.search_variables!(search_buffer, assign; is_atomic = find_arrvars_is_atomic, recurse = !SU.default_is_atomic)
342+
end
343+
322344
# assignments for reconstructing scalarized array symbolics
323345
if non_standard_param_layout
324-
append!(assignments, array_variable_assignments(args...))
346+
append!(assignments, array_variable_assignments(args...; filter_vars = required_arrvars))
325347
else
326348
cached = check_mutable_cache(sys, ParameterArrayAssignments, ParameterArrayAssignments, nothing)
327349
if cached isa ParameterArrayAssignments
@@ -332,10 +354,10 @@ Base.@nospecializeinfer function build_function_wrapper(
332354
end
333355
append!(
334356
assignments, array_variable_buffer_idxs_to_assignments(
335-
param_var_to_arridxs; buffer_offset = p_start - 1
357+
param_var_to_arridxs; buffer_offset = p_start - 1, filter_vars = required_arrvars
336358
)
337359
)
338-
other_assigns = array_variable_assignments(args...; ignore_vars = keys(param_var_to_arridxs))
360+
other_assigns = array_variable_assignments(args...; ignore_vars = keys(param_var_to_arridxs), filter_vars = required_arrvars)
339361
append!(assignments, other_assigns)
340362
end
341363
append!(assignments, extra_assignments)

0 commit comments

Comments
 (0)