diff --git a/src/lambdify.jl b/src/lambdify.jl index 77e2217..06a0d0d 100644 --- a/src/lambdify.jl +++ b/src/lambdify.jl @@ -43,34 +43,34 @@ __ZERO__(xs...) = 0 __HEAVISIDE__ = (a...) -> (a[1] < 0 ? 0 : (a[1] > 0 ? 1 : (length(a) > 1 ? a[2] : NaN))) # __SYMPY__ALL__, fn_map = Dict( - "Add" => :+, - "Sub" => :-, - "Mul" => :*, # :(SymPy.__PROD__) - "Div" => :/, - "Pow" => :^, - "re" => :real, - "im" => :imag, - "Abs" => :abs, - "Min" => :min, - "Max" => :max, - "Poly" => :identity, - "Piecewise" => :(SymPy._piecewise), - "Order" => :(SymPy.__ZERO__), # :(as...) -> 0, - "And" => :(SymPy.__ALL__), #:((as...) -> all(as)), #:(&), - "Or" => :(SymPy.__ANY__), #:((as...) -> any(as)), #:(|), - "Less" => :(<), - "LessThan" => :(<=), - "StrictLessThan" => :(<), - "Equal" => :(==), - "Equality" => :(==), - "Unequality" => :(!==), - "StrictGreaterThan" => :(>), - "GreaterThan" => :(>=), - "Greater" => :(>), + "Add" => :+, + "Sub" => :-, + "Mul" => :*, # :(SymPy.__PROD__) + "Div" => :/, + "Pow" => :^, + "re" => :real, + "im" => :imag, + "Abs" => :abs, + "Min" => :min, + "Max" => :max, + "Poly" => :identity, + "Piecewise" => :(SymPy._piecewise), + "Order" => :(SymPy.__ZERO__), # :(as...) -> 0, + "And" => :(SymPy.__ALL__), #:((as...) -> all(as)), #:(&), + "Or" => :(SymPy.__ANY__), #:((as...) -> any(as)), #:(|), + "Less" => :(<), + "LessThan" => :(<=), + "StrictLessThan" => :(<), + "Equal" => :(==), + "Equality" => :(==), + "Unequality" => :(!==), + "StrictGreaterThan" => :(>), + "GreaterThan" => :(>=), + "Greater" => :(>), "conjugate" => :conj, "atan2" => :atan, - "Heaviside" => :(SymPy.__HEAVISIDE__) - ) + "Heaviside" => :(SymPy.__HEAVISIDE__), +) map_fn(key, fn_map) = haskey(fn_map, key) ? fn_map[key] : Symbol(key) @@ -89,7 +89,7 @@ function walk_expression(ex; values=Dict(), fns=Dict()) return walk_expression(rhs(ex), values=values, fns=fns) end - if fn == "Symbol" || fn == "Dummy" + if fn == "Symbol" || fn == "Dummy" || fn == "IndexedBase" str_ex = string(ex) return get(vals_map, str_ex, Symbol(str_ex)) elseif fn in ["Integer" , "Float"] @@ -104,6 +104,8 @@ function walk_expression(ex; values=Dict(), fns=Dict()) return (val, walk_expression(cond, values=values, fns=fns)) elseif fn == "Tuple" return walk_expression.(Introspection.args(ex), values=values, fns=fns) + elseif fn == "Indexed" + return Expr(:ref, [walk_expression(a, values=values, fns=fns) for a in Introspection.args(ex)]...) elseif haskey(vals_map, fn) return vals_map[fn] end diff --git a/test/tests.jl b/test/tests.jl index 0d69552..2af11f4 100644 --- a/test/tests.jl +++ b/test/tests.jl @@ -809,3 +809,5 @@ end @test limit(ceil(x), x=>0, dir="+") != limit(ceil(x), x=>0, dir="-") @test limit(floor(x), x=>0, dir="+") != limit(floor(x), x=>0, dir="-") end + +@test SymPy.convert_expr(sympy.Indexed(sympy.IndexedBase(:x), 1, -2)) == :(x[1, -2])