|
2 | 2 | from sympy import ( |
3 | 3 | Tuple, sympify, Expr, Dummy, sin, cos, Symbol, Indexed, ImageSet, |
4 | 4 | FiniteSet, Basic, Float, Integer, Rational, Poly, fraction, exp, |
5 | | - NumberSymbol |
| 5 | + NumberSymbol, IndexedBase |
6 | 6 | ) |
7 | 7 | from sympy.vector import BaseScalar |
8 | 8 | from sympy.core.function import AppliedUndef |
@@ -170,11 +170,35 @@ def _get_free_symbols(exprs): |
170 | 170 | if all(callable(e) for e in exprs): |
171 | 171 | return set() |
172 | 172 |
|
173 | | - free = set().union(*[e.atoms(Indexed) for e in exprs]) |
174 | | - free = free.union(*[e.atoms(AppliedUndef) for e in exprs]) |
175 | | - if len(free) > 0: |
176 | | - return free |
177 | | - return set().union(*[e.free_symbols for e in exprs]) |
| 173 | + # NOTE: |
| 174 | + # 1. srepr(IndexedBase("a")) is "IndexedBase(Symbol('a'))" |
| 175 | + # So, if expr = IndexedBase("a")[0] + 1, it follows that |
| 176 | + # expr.free_symbols is {IndexedBase("a")[0], Symbol("a")} |
| 177 | + # This must be filtered to {IndexedBase("a")[0]} |
| 178 | + # 2. Let a = IndexedBase("a"). Even though as of sympy 1.14.0 it is |
| 179 | + # possible to write expressions like a + 1, for simplicity, |
| 180 | + # I don't allow them, because of Note 1, which would increase |
| 181 | + # complexity in this code. |
| 182 | + |
| 183 | + undefined_func = set().union(*[e.atoms(AppliedUndef) for e in exprs]) |
| 184 | + undefined_func_args = set().union(*[f.args for f in undefined_func]) |
| 185 | + indexed_base = set().union(*[e.atoms(IndexedBase) for e in exprs]) |
| 186 | + indexed_base_args = set().union(*[i.args for i in indexed_base]) |
| 187 | + |
| 188 | + # select all free symbols, be them instances of Symbol, Indexed |
| 189 | + # or the arguments of IndexedBase |
| 190 | + free_symbols = set().union(*[e.free_symbols for e in exprs]) |
| 191 | + # remove instances of IndexedBase |
| 192 | + free_symbols = free_symbols.difference(indexed_base) |
| 193 | + # remove free symbols that are arguments of applied undef functions |
| 194 | + # it is unlikely that these symbols are being used as parameters as well. |
| 195 | + free_symbols = free_symbols.difference(undefined_func_args) |
| 196 | + # remove free symbols that are arguments of indexed base |
| 197 | + free_symbols = free_symbols.difference(indexed_base_args) |
| 198 | + |
| 199 | + free = free_symbols.union(undefined_func) |
| 200 | + |
| 201 | + return free |
178 | 202 |
|
179 | 203 |
|
180 | 204 | def _check_arguments(args, nexpr, npar, **kwargs): |
|
0 commit comments