Skip to content

Commit 3ea0c9c

Browse files
committed
fixed bug with retrieval of free symbols
1 parent 51ceb56 commit 3ea0c9c

4 files changed

Lines changed: 68 additions & 9 deletions

File tree

doc/source/changelog.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,12 @@
22
Changelog
33
==========
44

5+
v3.4.3
6+
======
7+
8+
* Fixed bug with retrieval of free symbols from symbolic expressions.
9+
10+
511
v3.4.2
612
======
713

spb/_version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "3.4.2"
1+
__version__ = "3.4.3"

spb/utils.py

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from sympy import (
33
Tuple, sympify, Expr, Dummy, sin, cos, Symbol, Indexed, ImageSet,
44
FiniteSet, Basic, Float, Integer, Rational, Poly, fraction, exp,
5-
NumberSymbol
5+
NumberSymbol, IndexedBase
66
)
77
from sympy.vector import BaseScalar
88
from sympy.core.function import AppliedUndef
@@ -170,11 +170,35 @@ def _get_free_symbols(exprs):
170170
if all(callable(e) for e in exprs):
171171
return set()
172172

173-
free = set().union(*[e.atoms(Indexed) for e in exprs])
174-
free = free.union(*[e.atoms(AppliedUndef) for e in exprs])
175-
if len(free) > 0:
176-
return free
177-
return set().union(*[e.free_symbols for e in exprs])
173+
# NOTE:
174+
# 1. srepr(IndexedBase("a")) is "IndexedBase(Symbol('a'))"
175+
# So, if expr = IndexedBase("a")[0] + 1, it follows that
176+
# expr.free_symbols is {IndexedBase("a")[0], Symbol("a")}
177+
# This must be filtered to {IndexedBase("a")[0]}
178+
# 2. Let a = IndexedBase("a"). Even though as of sympy 1.14.0 it is
179+
# possible to write expressions like a + 1, for simplicity,
180+
# I don't allow them, because of Note 1, which would increase
181+
# complexity in this code.
182+
183+
undefined_func = set().union(*[e.atoms(AppliedUndef) for e in exprs])
184+
undefined_func_args = set().union(*[f.args for f in undefined_func])
185+
indexed_base = set().union(*[e.atoms(IndexedBase) for e in exprs])
186+
indexed_base_args = set().union(*[i.args for i in indexed_base])
187+
188+
# select all free symbols, be them instances of Symbol, Indexed
189+
# or the arguments of IndexedBase
190+
free_symbols = set().union(*[e.free_symbols for e in exprs])
191+
# remove instances of IndexedBase
192+
free_symbols = free_symbols.difference(indexed_base)
193+
# remove free symbols that are arguments of applied undef functions
194+
# it is unlikely that these symbols are being used as parameters as well.
195+
free_symbols = free_symbols.difference(undefined_func_args)
196+
# remove free symbols that are arguments of indexed base
197+
free_symbols = free_symbols.difference(indexed_base_args)
198+
199+
free = free_symbols.union(undefined_func)
200+
201+
return free
178202

179203

180204
def _check_arguments(args, nexpr, npar, **kwargs):

tests/test_utils.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,10 @@
99
_create_missing_ranges, _plot_sympify,
1010
_validate_kwargs, prange, extract_solution,
1111
tf_to_control, tf_to_sympy, is_discrete_time, tf_find_time_delay,
12-
is_number
12+
is_number, _get_free_symbols
1313
)
1414
from sympy import (
15-
symbols, Expr, Tuple, Integer, sin, cos, Matrix,
15+
symbols, Expr, Tuple, Integer, sin, cos, Matrix, Function, IndexedBase,
1616
I, Polygon, solveset, FiniteSet, ImageSet, exp, Rational, Float, pi
1717
)
1818
from sympy.external import import_module
@@ -571,3 +571,32 @@ def test_tf_find_time_delay():
571571
)
572572
def test_is_number(num, expected):
573573
assert is_number(num) is expected
574+
575+
576+
def test_get_free_symbols():
577+
x, y, z, t = symbols("x, y, z, t")
578+
f = Function("f")(t)
579+
g = Function("f")(x)
580+
w = IndexedBase("w")
581+
582+
e = x + y + 1
583+
assert _get_free_symbols(e) == {x, y}
584+
585+
e = f + 1
586+
assert _get_free_symbols(e) == {f}
587+
588+
e = w[0] + 1
589+
assert _get_free_symbols(e) == {w[0]}
590+
591+
e = w + 1
592+
assert _get_free_symbols(e) == set()
593+
594+
e = x + y + z + f
595+
assert _get_free_symbols(e) == {x, y, z, f}
596+
597+
e = f + g
598+
assert _get_free_symbols(e) == {f, g}
599+
600+
e = f + g + y
601+
assert _get_free_symbols(e) == {f, g, y}
602+

0 commit comments

Comments
 (0)