|
1 | 1 | from collections import ChainMap |
| 2 | +from contextlib import suppress |
2 | 3 | from functools import cached_property, singledispatch |
3 | 4 | from itertools import product |
4 | 5 |
|
@@ -185,8 +186,10 @@ def coefficients(self): |
185 | 186 | return sorted(coefficients, key=key, reverse=True)[0] |
186 | 187 |
|
187 | 188 | def _eval_at(self, func, **kwargs): |
188 | | - return self.func(*[getattr(a, '_eval_at', lambda x, **kw: a)(func, **kwargs) |
189 | | - for a in self.args]) |
| 189 | + return self.func(*[ |
| 190 | + getattr(a, '_eval_at', lambda x, **kw: a)(func, **kwargs) # noqa: B023 |
| 191 | + for a in self.args # false positive: lambda is invoked in-place |
| 192 | + ]) |
190 | 193 |
|
191 | 194 | def _subs(self, old, new, **hints): |
192 | 195 | if old == self: |
@@ -668,67 +671,61 @@ def _gather_for_diff(self): |
668 | 671 | other = self.func(*other)._eval_at(highest_priority(self)) |
669 | 672 | return self.func(other, *derivs) |
670 | 673 |
|
671 | | - def _eval_at(self, func, mul_first=False, **kwargs): |
672 | | - # Dont evaluate mul first |
673 | | - if not mul_first: |
674 | | - return super()._eval_at(func, mul_first=mul_first) |
| 674 | + def _eval_at(self, func, interp_mode='direct', **kwargs): |
| 675 | + """ |
| 676 | + Evaluate a Mul at the location of ``func``. |
675 | 677 |
|
676 | | - # Same staggering, no need to interpolate |
677 | | - if self.staggered == func.staggered: |
678 | | - return self |
| 678 | + Two modes: |
679 | 679 |
|
680 | | - # Get highest priority function for evaluation |
681 | | - func0 = highest_priority(self, ref=func) |
| 680 | + - ``interp_mode='direct'`` (default): per-arg evaluation; each factor is |
| 681 | + independently evaluated at ``func``'s location via |
| 682 | + ``Differentiable._eval_at``. |
682 | 683 |
|
683 | | - # Not a basic a*b*c... expression, expand |
684 | | - if any(isinstance(f, DifferentiableOp) for f in self.args): |
685 | | - return diffify(self._eval_expand_mul())._eval_at(func, mul_first=mul_first) |
| 684 | + - ``interp_mode='symmetric'``: when every Differentiable factor has a |
| 685 | + staggering different from ``func``'s, apply the ``I * (a * I^T * b)`` |
| 686 | + form: |
686 | 687 |
|
687 | | - # Split Derivative and Differentiable args |
688 | | - derivs, other = split(self.args, lambda e: isinstance(e, sympy.Derivative)) |
| 688 | + 1. Pick a ``block`` location (the highest-priority factor's |
| 689 | + staggering). Each factor not at ``block`` is brought there via |
| 690 | + ``I^T`` (an explicit 0-order FD interpolation operator). |
| 691 | + Derivatives additionally set ``x0`` on their own derivative |
| 692 | + dimensions to ``func``'s indices. |
| 693 | + 2. The product is formed at ``block``'s location. |
| 694 | + 3. The whole product is interpolated to ``func`` via ``I`` (an |
| 695 | + explicit 0-order FD operator). |
689 | 696 |
|
690 | | - # Evaluate all at highest priority function |
691 | | - if derivs: |
692 | | - derivs = self.func(*[d._eval_at(func, mul_first=mul_first) for d in derivs]) |
693 | | - else: |
694 | | - derivs = 1 |
695 | | - |
696 | | - if not other: |
697 | | - return derivs |
698 | | - expr = self.func(*other) |
699 | | - |
700 | | - # Non differentiable expr (e.g., number) |
701 | | - if not isinstance(expr, Differentiable): |
702 | | - return self.func(derivs, expr) |
703 | | - |
704 | | - # Evaluate expression at func_args |
705 | | - print(f"\nEvaluating expr {expr} at func0 {func0} for func {func} from {self}") |
706 | | - expr = Differentiable._eval_at(expr, func0, mul_first=False) |
707 | | - |
708 | | - # Interpolate derivatives at func0 |
709 | | - x0 = {d: v for d, v in func0.indices_ref.getters.items() |
710 | | - if not d.is_Time and v is not func.indices_ref.getters.get(d, d)} |
711 | | - if x0 and not derivs == 1: |
712 | | - print(f"Interpolating derivs {derivs} x0={derivs.x0} at {x0}") |
713 | | - derivs = derivs.diff(*x0.keys(), deriv_order=(0,)*len(x0), |
714 | | - fd_order=(self.interp_order,)*len(x0), |
715 | | - x0=x0) |
716 | | - newexpr = self.func(derivs, expr) |
717 | | - |
718 | | - # Finally at func |
719 | | - if not func.staggered == func0.staggered: |
720 | | - x0_f = {d: v for d, v in func.indices_ref.getters.items() |
721 | | - if not d.is_Time and v is not func0.indices_ref.getters.get(d)} |
722 | | - if x0_f: |
723 | | - print(f"Final interpolation of derivs {self.func(derivs, expr)} at func {x0_f}") |
724 | | - return newexpr.diff(*x0_f.keys(), deriv_order=(0,)*len(x0_f), |
725 | | - fd_order=(self.interp_order,)*len(x0_f), |
726 | | - x0=x0_f) |
| 697 | + When the trigger does not hold (e.g. some factor already matches |
| 698 | + ``func``'s staggering), we fall back to ``direct``. |
| 699 | + """ |
| 700 | + if interp_mode != 'symmetric': |
| 701 | + return super()._eval_at(func, **kwargs) |
| 702 | + |
| 703 | + diff_args = [a for a in self.args if isinstance(a, Differentiable)] |
| 704 | + other_args = [a for a in self.args if not isinstance(a, Differentiable)] |
| 705 | + |
| 706 | + # Symmetric form requires every Differentiable factor to differ from |
| 707 | + # func; otherwise direct evaluation is cleaner and equivalent. |
| 708 | + if len(diff_args) < 2 or \ |
| 709 | + any(a.staggered == func.staggered for a in diff_args): |
| 710 | + return super()._eval_at(func, **kwargs) |
| 711 | + |
| 712 | + block_indices = highest_priority(self).indices_ref |
| 713 | + |
| 714 | + # Bring each factor to block's location (I^T where needed) |
| 715 | + new_factors = list(other_args) |
| 716 | + for a in diff_args: |
| 717 | + if isinstance(a, sympy.Derivative): |
| 718 | + source = _post_x0_indices(a, func) |
| 719 | + a = a._rebuild(x0={dim: func.indices_ref[dim] for dim in a.dims |
| 720 | + if dim in func.indices_ref.getters}) |
727 | 721 | else: |
728 | | - return newexpr |
729 | | - else: |
730 | | - # Return the full expression with Derivatives |
731 | | - return newexpr |
| 722 | + source = a.indices_ref |
| 723 | + new_factors.append(_interp_at(a, source, block_indices, |
| 724 | + self.interp_order)) |
| 725 | + |
| 726 | + # Final I from block's location to func |
| 727 | + return _interp_at(self.func(*new_factors), block_indices, |
| 728 | + func.indices_ref, self.interp_order) |
732 | 729 |
|
733 | 730 |
|
734 | 731 | class Pow(DifferentiableOp, sympy.Pow): |
@@ -1255,6 +1252,61 @@ def _diff2sympy(obj): |
1255 | 1252 | evalf_table[Pow] = evalf_table[sympy.Pow] |
1256 | 1253 |
|
1257 | 1254 |
|
| 1255 | +def _interp_mapper(source, target, dims): |
| 1256 | + """ |
| 1257 | + Build a ``{dim: target_index}`` mapper for dimensions in ``dims`` where |
| 1258 | + ``source[dim]`` differs from ``target[dim]``. |
| 1259 | +
|
| 1260 | + ``source`` and ``target`` are dict-like ``{dim: index_expr}`` (e.g. a plain |
| 1261 | + dict or a ``DimensionTuple``). Dimensions missing from either side are |
| 1262 | + skipped silently. |
| 1263 | + """ |
| 1264 | + mapper = {} |
| 1265 | + for d in dims: |
| 1266 | + try: |
| 1267 | + s = source[d] |
| 1268 | + t = target[d] |
| 1269 | + except (KeyError, IndexError): |
| 1270 | + continue |
| 1271 | + if s is not t: |
| 1272 | + mapper[d] = t |
| 1273 | + return mapper |
| 1274 | + |
| 1275 | + |
| 1276 | +def _interp_at(expr, source, target, interp_order): |
| 1277 | + """ |
| 1278 | + Build a symbolic 0-order FD interpolation operator on ``expr`` that maps |
| 1279 | + values from ``source`` indices to ``target`` indices, only on the |
| 1280 | + dimensions where the two locations differ. |
| 1281 | + """ |
| 1282 | + if not isinstance(expr, Differentiable): |
| 1283 | + return expr |
| 1284 | + mapper = _interp_mapper(source, target, expr.dimensions) |
| 1285 | + if not mapper: |
| 1286 | + return expr |
| 1287 | + return expr.diff(*mapper.keys(), |
| 1288 | + deriv_order=(0,) * len(mapper), |
| 1289 | + fd_order=(interp_order,) * len(mapper), |
| 1290 | + x0=mapper) |
| 1291 | + |
| 1292 | + |
| 1293 | +def _post_x0_indices(deriv, func): |
| 1294 | + """ |
| 1295 | + Conceptual indices of ``deriv`` after setting ``x0`` on its own derivative |
| 1296 | + dimensions to ``func``'s indices. Derivative dims take ``func``'s indices; |
| 1297 | + other dims keep the underlying expression's natural location (so that |
| 1298 | + ``interp_for_fd`` does not introduce a spurious second shift). |
| 1299 | + """ |
| 1300 | + ref = {} |
| 1301 | + for dim in deriv.dimensions: |
| 1302 | + if dim in deriv.dims and dim in func.indices_ref.getters: |
| 1303 | + ref[dim] = func.indices_ref[dim] |
| 1304 | + else: |
| 1305 | + with suppress(KeyError): |
| 1306 | + ref[dim] = deriv.indices_ref[dim] |
| 1307 | + return ref |
| 1308 | + |
| 1309 | + |
1258 | 1310 | # Interpolation for finite differences |
1259 | 1311 | @singledispatch |
1260 | 1312 | def interp_for_fd(expr, x0, **kwargs): |
|
0 commit comments