diff --git a/Project.toml b/Project.toml index 22b20989..332b0e83 100644 --- a/Project.toml +++ b/Project.toml @@ -41,7 +41,7 @@ Optim = "1, 2" NLSolversBase = "7, 8" PrecompileTools = "1.2.1" Reexport = "1.2.2" -SymbolicUtils = "4" +SymbolicUtils = "4.1" Zygote = "0.7" julia = "1.10" Random = "1" diff --git a/ext/DynamicExpressionsLoopVectorizationExt.jl b/ext/DynamicExpressionsLoopVectorizationExt.jl index 7b41f767..2db4e07a 100644 --- a/ext/DynamicExpressionsLoopVectorizationExt.jl +++ b/ext/DynamicExpressionsLoopVectorizationExt.jl @@ -215,7 +215,13 @@ function bumper_kern!( op::F, cumulators::Tuple{Vararg{Any,degree}}, ::EvalOptions{true,true,early_exit} ) where {F,degree,early_exit} cumulator_1 = first(cumulators) - @turbo @. cumulator_1 = op(cumulators...) + + # Avoid `@turbo @.` here: older LoopVectorization versions (used by downgrade-compat) + # can error during macro expansion on vararg tuple construction. + @inbounds for j in eachindex(cumulator_1) + cumulator_1[j] = op(map(c -> c[j], cumulators)...) + end + return cumulator_1 end diff --git a/ext/DynamicExpressionsSymbolicUtilsExt.jl b/ext/DynamicExpressionsSymbolicUtilsExt.jl index 637faa94..b4d5cc55 100644 --- a/ext/DynamicExpressionsSymbolicUtilsExt.jl +++ b/ext/DynamicExpressionsSymbolicUtilsExt.jl @@ -98,19 +98,30 @@ function parse_tree_to_eqs( # Convert children to symbolic form sym_children = map(x -> parse_tree_to_eqs(x, operators, index_functions), children) - # Only a small subset of functions have symbolic methods in SymbolicUtils. - # For unsupported operators: - # - when `index_functions=true`, we encode them as function-like symbols so they - # can be round-tripped back to operators via their name/arity. - # - when `index_functions=false`, we throw a clear error, since attempting to - # construct a SymbolicUtils term headed by an arbitrary function object can - # fail with a MethodError. + # SymbolicUtils only defines methods for some Base/stdlib functions, but + # user-defined operators may still be traceable if they are written in terms + # of symbolic-friendly primitives (e.g. `pow2(x) = x*x`). + # + # Strategy: + # - when `index_functions=true`, we encode operators as function-like symbols so + # they can be round-tripped back to operators via their name/arity. + # - when `index_functions=false`, we attempt to *trace* the operator by calling + # it on symbolic children. If this fails with a MethodError, throw a clear + # error instead of a cryptic MethodError. if !(op ∈ SUPPORTED_OPS) if index_functions op = _sym_fn(Symbol(op), tree.degree) return subs_bad(op(sym_children...)) else - throw(error("Unsupported operation $(op) in SymbolicUtils conversion")) + traced = try + op(sym_children...) + catch e + if e isa MethodError + throw(error("Unsupported operation $(op) in SymbolicUtils conversion")) + end + rethrow() + end + return subs_bad(traced) end end diff --git a/test/test_symbolic_utils.jl b/test/test_symbolic_utils.jl index 08bfe034..8a3aa5a5 100644 --- a/test/test_symbolic_utils.jl +++ b/test/test_symbolic_utils.jl @@ -149,10 +149,47 @@ end expr_custom = parse_expression( :(myop(x)); unary_operators=[myop], binary_operators=[+, *, /], variable_names=["x"] ) - @test_throws ErrorException node_to_symbolic( + + # If the custom operator is traceable (i.e. it operates on symbolic inputs), + # `index_functions=false` should still work by tracing it. + eqn_custom_traced = node_to_symbolic( expr_custom, operators_custom; index_functions=false ) + expr_custom_traced_rt = symbolic_to_node( + eqn_custom_traced, operators_custom; variable_names=["x"] + ) + @test eval_expr(expr_custom_traced_rt, operators_custom, ["x"], X1) == [3.0] + + # If you want to preserve the custom operator itself for round-tripping, use + # `index_functions=true`. eqn_custom = node_to_symbolic(expr_custom, operators_custom; index_functions=true) expr_custom_rt = symbolic_to_node(eqn_custom, operators_custom; variable_names=["x"]) @test eval_expr(expr_custom_rt, operators_custom, ["x"], X1) == [3.0] + + # If a custom operator cannot be traced (e.g. no method for symbolic inputs), + # `index_functions=false` should throw a clear error rather than a MethodError. + float_only(x::Float64) = x + 1 + operators_float_only = OperatorEnum(; + unary_operators=(float_only,), binary_operators=(+, *, /) + ) + expr_float_only = parse_expression( + :(float_only(x)); + unary_operators=[float_only], + binary_operators=[+, *, /], + variable_names=["x"], + ) + @test_throws ErrorException node_to_symbolic( + expr_float_only, operators_float_only; index_functions=false + ) + + # Non-MethodError exceptions should be rethrown, so downstream code sees the + # original failure. + boom(x) = throw(ArgumentError("boom")) + operators_boom = OperatorEnum(; unary_operators=(boom,), binary_operators=(+, *, /)) + expr_boom = parse_expression( + :(boom(x)); unary_operators=[boom], binary_operators=[+, *, /], variable_names=["x"] + ) + @test_throws ArgumentError node_to_symbolic( + expr_boom, operators_boom; index_functions=false + ) end