Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ Optim = "1, 2"
NLSolversBase = "7, 8"
PrecompileTools = "1.2.1"
Reexport = "1.2.2"
SymbolicUtils = "4"
SymbolicUtils = "4.1"
Zygote = "0.7"
julia = "1.10"
Random = "1"
Expand Down
8 changes: 7 additions & 1 deletion ext/DynamicExpressionsLoopVectorizationExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,13 @@ function bumper_kern!(
op::F, cumulators::Tuple{Vararg{Any,degree}}, ::EvalOptions{true,true,early_exit}
) where {F,degree,early_exit}
cumulator_1 = first(cumulators)
@turbo @. cumulator_1 = op(cumulators...)

# Avoid `@turbo @.` here: older LoopVectorization versions (used by downgrade-compat)
# can error during macro expansion on vararg tuple construction.
@inbounds for j in eachindex(cumulator_1)
cumulator_1[j] = op(map(c -> c[j], cumulators)...)
end

return cumulator_1
end

Expand Down
27 changes: 19 additions & 8 deletions ext/DynamicExpressionsSymbolicUtilsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,19 +98,30 @@ function parse_tree_to_eqs(
# Convert children to symbolic form
sym_children = map(x -> parse_tree_to_eqs(x, operators, index_functions), children)

# Only a small subset of functions have symbolic methods in SymbolicUtils.
# For unsupported operators:
# - when `index_functions=true`, we encode them as function-like symbols so they
# can be round-tripped back to operators via their name/arity.
# - when `index_functions=false`, we throw a clear error, since attempting to
# construct a SymbolicUtils term headed by an arbitrary function object can
# fail with a MethodError.
# SymbolicUtils only defines methods for some Base/stdlib functions, but
# user-defined operators may still be traceable if they are written in terms
# of symbolic-friendly primitives (e.g. `pow2(x) = x*x`).
#
# Strategy:
# - when `index_functions=true`, we encode operators as function-like symbols so
# they can be round-tripped back to operators via their name/arity.
# - when `index_functions=false`, we attempt to *trace* the operator by calling
# it on symbolic children. If this fails with a MethodError, throw a clear
# error instead of a cryptic MethodError.
if !(op ∈ SUPPORTED_OPS)
if index_functions
op = _sym_fn(Symbol(op), tree.degree)
return subs_bad(op(sym_children...))
else
throw(error("Unsupported operation $(op) in SymbolicUtils conversion"))
traced = try
op(sym_children...)
catch e
if e isa MethodError
throw(error("Unsupported operation $(op) in SymbolicUtils conversion"))
end
rethrow()
end
Comment on lines +116 to +123
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This might be type unstable? i.e., might need to do

local traced
try
    traced = op...
catch e
    ...
end

Can you check if this is type unstable? Please show the inferred type in each instance (REPL output with Cthulhu.jl)

@MilesCranmerBot

return subs_bad(traced)
end
end

Expand Down
39 changes: 38 additions & 1 deletion test/test_symbolic_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -149,10 +149,47 @@ end
expr_custom = parse_expression(
:(myop(x)); unary_operators=[myop], binary_operators=[+, *, /], variable_names=["x"]
)
@test_throws ErrorException node_to_symbolic(

# If the custom operator is traceable (i.e. it operates on symbolic inputs),
# `index_functions=false` should still work by tracing it.
eqn_custom_traced = node_to_symbolic(
expr_custom, operators_custom; index_functions=false
)
expr_custom_traced_rt = symbolic_to_node(
eqn_custom_traced, operators_custom; variable_names=["x"]
)
@test eval_expr(expr_custom_traced_rt, operators_custom, ["x"], X1) == [3.0]

# If you want to preserve the custom operator itself for round-tripping, use
# `index_functions=true`.
eqn_custom = node_to_symbolic(expr_custom, operators_custom; index_functions=true)
expr_custom_rt = symbolic_to_node(eqn_custom, operators_custom; variable_names=["x"])
@test eval_expr(expr_custom_rt, operators_custom, ["x"], X1) == [3.0]

# If a custom operator cannot be traced (e.g. no method for symbolic inputs),
# `index_functions=false` should throw a clear error rather than a MethodError.
float_only(x::Float64) = x + 1
operators_float_only = OperatorEnum(;
unary_operators=(float_only,), binary_operators=(+, *, /)
)
expr_float_only = parse_expression(
:(float_only(x));
unary_operators=[float_only],
binary_operators=[+, *, /],
variable_names=["x"],
)
@test_throws ErrorException node_to_symbolic(
expr_float_only, operators_float_only; index_functions=false
)

# Non-MethodError exceptions should be rethrown, so downstream code sees the
# original failure.
boom(x) = throw(ArgumentError("boom"))
operators_boom = OperatorEnum(; unary_operators=(boom,), binary_operators=(+, *, /))
expr_boom = parse_expression(
:(boom(x)); unary_operators=[boom], binary_operators=[+, *, /], variable_names=["x"]
)
@test_throws ArgumentError node_to_symbolic(
expr_boom, operators_boom; index_functions=false
)
end
Loading