Skip to content

Commit 8cc5c13

Browse files
committed
api: fix symmetric interp mode and add lots of tests
1 parent 5e25bfb commit 8cc5c13

9 files changed

Lines changed: 1333 additions & 78 deletions

File tree

devito/core/cpu.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def _normalize_kwargs(cls, **kwargs):
8787

8888
# Code generation options for derivatives
8989
o['expand'] = oo.pop('expand', cls.EXPAND)
90-
o['eval-mul-first'] = oo.pop('eval-mul-first', cls.MUL_FIRST)
90+
o['interp-mode'] = oo.pop('interp-mode', cls.INTERP_MODE)
9191
o['deriv-collect'] = oo.pop('deriv-collect', cls.DERIV_COLLECT)
9292
o['deriv-schedule'] = oo.pop('deriv-schedule', cls.DERIV_SCHEDULE)
9393
o['deriv-unroll'] = oo.pop('deriv-unroll', False)

devito/core/gpu.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def _normalize_kwargs(cls, **kwargs):
102102

103103
# Code generation options for derivatives
104104
o['expand'] = oo.pop('expand', cls.EXPAND)
105-
o['eval-mul-first'] = oo.pop('eval-mul-first', cls.MUL_FIRST)
105+
o['interp-mode'] = oo.pop('interp-mode', cls.INTERP_MODE)
106106
o['deriv-collect'] = oo.pop('deriv-collect', cls.DERIV_COLLECT)
107107
o['deriv-schedule'] = oo.pop('deriv-schedule', cls.DERIV_SCHEDULE)
108108
o['deriv-unroll'] = oo.pop('deriv-unroll', False)

devito/core/operator.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -125,10 +125,23 @@ class BasicOperator(Operator):
125125
finite-difference derivatives.
126126
"""
127127

128-
MUL_FIRST = False
129-
"""
130-
When evaluating expressions location, prioritize multiplication
131-
operations.
128+
INTERP_MODE = 'direct'
129+
"""
130+
Interpolation mode used by ``Mul._eval_at`` when projecting a multi-factor
131+
expression onto a target staggered location:
132+
133+
* ``'direct'`` (default): each factor is shifted to ``func``'s location
134+
independently (``Function._eval_at`` per arg). Cheapest stencil; the
135+
mode to pick unless you need an explicitly self-adjoint discretization.
136+
* ``'symmetric'``: when every factor lives at a staggering different from
137+
``func``'s, the symmetric form ``I * (a * I^T * b)`` is built — all
138+
factors are gathered at the highest-priority "block" location via
139+
``I^T``, multiplied there, and the product is interpolated to ``func``
140+
via ``I``. Use this for operators whose continuous form decomposes as
141+
``I * A * I^T`` (e.g. the elastic stiffness ``σ = C ε``).
142+
143+
See ``examples/userapi/08_staggered_interp.ipynb`` for the maths and a
144+
worked elastic-stiffness example.
132145
"""
133146

134147
DERIV_COLLECT = True

devito/finite_differences/derivative.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -472,7 +472,7 @@ def T(self):
472472

473473
return self._rebuild(transpose=adjoint)
474474

475-
def _eval_at(self, func, mul_first=False, **kwargs):
475+
def _eval_at(self, func, interp_mode='direct', **kwargs):
476476
"""
477477
Evaluates the derivative at the location of `func`. It is necessary for staggered
478478
setup where one could have Eq(u(x + h_x/2), v(x).dx)) in which case v(x).dx
@@ -522,7 +522,7 @@ def _eval_at(self, func, mul_first=False, **kwargs):
522522
return self._rebuild(self.expr, **rkw)
523523
args = [self.expr.func(*v) for v in mapper.values()]
524524
args.extend([a for a in self.expr.args if a not in self.expr._args_diff])
525-
args = [self._rebuild(a)._eval_at(func, mul_first=mul_first, **kwargs)
525+
args = [self._rebuild(a)._eval_at(func, interp_mode=interp_mode, **kwargs)
526526
for a in args]
527527
return self.expr.func(*args)
528528
elif self.expr.is_Mul:

devito/finite_differences/differentiable.py

Lines changed: 109 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from collections import ChainMap
2+
from contextlib import suppress
23
from functools import cached_property, singledispatch
34
from itertools import product
45

@@ -185,8 +186,10 @@ def coefficients(self):
185186
return sorted(coefficients, key=key, reverse=True)[0]
186187

187188
def _eval_at(self, func, **kwargs):
188-
return self.func(*[getattr(a, '_eval_at', lambda x, **kw: a)(func, **kwargs)
189-
for a in self.args])
189+
return self.func(*[
190+
getattr(a, '_eval_at', lambda x, **kw: a)(func, **kwargs) # noqa: B023
191+
for a in self.args # false positive: lambda is invoked in-place
192+
])
190193

191194
def _subs(self, old, new, **hints):
192195
if old == self:
@@ -668,67 +671,61 @@ def _gather_for_diff(self):
668671
other = self.func(*other)._eval_at(highest_priority(self))
669672
return self.func(other, *derivs)
670673

671-
def _eval_at(self, func, mul_first=False, **kwargs):
672-
# Dont evaluate mul first
673-
if not mul_first:
674-
return super()._eval_at(func, mul_first=mul_first)
674+
def _eval_at(self, func, interp_mode='direct', **kwargs):
675+
"""
676+
Evaluate a Mul at the location of ``func``.
675677
676-
# Same staggering, no need to interpolate
677-
if self.staggered == func.staggered:
678-
return self
678+
Two modes:
679679
680-
# Get highest priority function for evaluation
681-
func0 = highest_priority(self, ref=func)
680+
- ``interp_mode='direct'`` (default): per-arg evaluation; each factor is
681+
independently evaluated at ``func``'s location via
682+
``Differentiable._eval_at``.
682683
683-
# Not a basic a*b*c... expression, expand
684-
if any(isinstance(f, DifferentiableOp) for f in self.args):
685-
return diffify(self._eval_expand_mul())._eval_at(func, mul_first=mul_first)
684+
- ``interp_mode='symmetric'``: when every Differentiable factor has a
685+
staggering different from ``func``'s, apply the ``I * (a * I^T * b)``
686+
form:
686687
687-
# Split Derivative and Differentiable args
688-
derivs, other = split(self.args, lambda e: isinstance(e, sympy.Derivative))
688+
1. Pick a ``block`` location (the highest-priority factor's
689+
staggering). Each factor not at ``block`` is brought there via
690+
``I^T`` (an explicit 0-order FD interpolation operator).
691+
Derivatives additionally set ``x0`` on their own derivative
692+
dimensions to ``func``'s indices.
693+
2. The product is formed at ``block``'s location.
694+
3. The whole product is interpolated to ``func`` via ``I`` (an
695+
explicit 0-order FD operator).
689696
690-
# Evaluate all at highest priority function
691-
if derivs:
692-
derivs = self.func(*[d._eval_at(func, mul_first=mul_first) for d in derivs])
693-
else:
694-
derivs = 1
695-
696-
if not other:
697-
return derivs
698-
expr = self.func(*other)
699-
700-
# Non differentiable expr (e.g., number)
701-
if not isinstance(expr, Differentiable):
702-
return self.func(derivs, expr)
703-
704-
# Evaluate expression at func_args
705-
print(f"\nEvaluating expr {expr} at func0 {func0} for func {func} from {self}")
706-
expr = Differentiable._eval_at(expr, func0, mul_first=False)
707-
708-
# Interpolate derivatives at func0
709-
x0 = {d: v for d, v in func0.indices_ref.getters.items()
710-
if not d.is_Time and v is not func.indices_ref.getters.get(d, d)}
711-
if x0 and not derivs == 1:
712-
print(f"Interpolating derivs {derivs} x0={derivs.x0} at {x0}")
713-
derivs = derivs.diff(*x0.keys(), deriv_order=(0,)*len(x0),
714-
fd_order=(self.interp_order,)*len(x0),
715-
x0=x0)
716-
newexpr = self.func(derivs, expr)
717-
718-
# Finally at func
719-
if not func.staggered == func0.staggered:
720-
x0_f = {d: v for d, v in func.indices_ref.getters.items()
721-
if not d.is_Time and v is not func0.indices_ref.getters.get(d)}
722-
if x0_f:
723-
print(f"Final interpolation of derivs {self.func(derivs, expr)} at func {x0_f}")
724-
return newexpr.diff(*x0_f.keys(), deriv_order=(0,)*len(x0_f),
725-
fd_order=(self.interp_order,)*len(x0_f),
726-
x0=x0_f)
697+
When the trigger does not hold (e.g. some factor already matches
698+
``func``'s staggering), we fall back to ``direct``.
699+
"""
700+
if interp_mode != 'symmetric':
701+
return super()._eval_at(func, **kwargs)
702+
703+
diff_args = [a for a in self.args if isinstance(a, Differentiable)]
704+
other_args = [a for a in self.args if not isinstance(a, Differentiable)]
705+
706+
# Symmetric form requires every Differentiable factor to differ from
707+
# func; otherwise direct evaluation is cleaner and equivalent.
708+
if len(diff_args) < 2 or \
709+
any(a.staggered == func.staggered for a in diff_args):
710+
return super()._eval_at(func, **kwargs)
711+
712+
block_indices = highest_priority(self).indices_ref
713+
714+
# Bring each factor to block's location (I^T where needed)
715+
new_factors = list(other_args)
716+
for a in diff_args:
717+
if isinstance(a, sympy.Derivative):
718+
source = _post_x0_indices(a, func)
719+
a = a._rebuild(x0={dim: func.indices_ref[dim] for dim in a.dims
720+
if dim in func.indices_ref.getters})
727721
else:
728-
return newexpr
729-
else:
730-
# Return the full expression with Derivatives
731-
return newexpr
722+
source = a.indices_ref
723+
new_factors.append(_interp_at(a, source, block_indices,
724+
self.interp_order))
725+
726+
# Final I from block's location to func
727+
return _interp_at(self.func(*new_factors), block_indices,
728+
func.indices_ref, self.interp_order)
732729

733730

734731
class Pow(DifferentiableOp, sympy.Pow):
@@ -1255,6 +1252,61 @@ def _diff2sympy(obj):
12551252
evalf_table[Pow] = evalf_table[sympy.Pow]
12561253

12571254

1255+
def _interp_mapper(source, target, dims):
1256+
"""
1257+
Build a ``{dim: target_index}`` mapper for dimensions in ``dims`` where
1258+
``source[dim]`` differs from ``target[dim]``.
1259+
1260+
``source`` and ``target`` are dict-like ``{dim: index_expr}`` (e.g. a plain
1261+
dict or a ``DimensionTuple``). Dimensions missing from either side are
1262+
skipped silently.
1263+
"""
1264+
mapper = {}
1265+
for d in dims:
1266+
try:
1267+
s = source[d]
1268+
t = target[d]
1269+
except (KeyError, IndexError):
1270+
continue
1271+
if s is not t:
1272+
mapper[d] = t
1273+
return mapper
1274+
1275+
1276+
def _interp_at(expr, source, target, interp_order):
1277+
"""
1278+
Build a symbolic 0-order FD interpolation operator on ``expr`` that maps
1279+
values from ``source`` indices to ``target`` indices, only on the
1280+
dimensions where the two locations differ.
1281+
"""
1282+
if not isinstance(expr, Differentiable):
1283+
return expr
1284+
mapper = _interp_mapper(source, target, expr.dimensions)
1285+
if not mapper:
1286+
return expr
1287+
return expr.diff(*mapper.keys(),
1288+
deriv_order=(0,) * len(mapper),
1289+
fd_order=(interp_order,) * len(mapper),
1290+
x0=mapper)
1291+
1292+
1293+
def _post_x0_indices(deriv, func):
1294+
"""
1295+
Conceptual indices of ``deriv`` after setting ``x0`` on its own derivative
1296+
dimensions to ``func``'s indices. Derivative dims take ``func``'s indices;
1297+
other dims keep the underlying expression's natural location (so that
1298+
``interp_for_fd`` does not introduce a spurious second shift).
1299+
"""
1300+
ref = {}
1301+
for dim in deriv.dimensions:
1302+
if dim in deriv.dims and dim in func.indices_ref.getters:
1303+
ref[dim] = func.indices_ref[dim]
1304+
else:
1305+
with suppress(KeyError):
1306+
ref[dim] = deriv.indices_ref[dim]
1307+
return ref
1308+
1309+
12581310
# Interpolation for finite differences
12591311
@singledispatch
12601312
def interp_for_fd(expr, x0, **kwargs):

devito/operator/operator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -341,7 +341,7 @@ def _lower_exprs(cls, expressions, **kwargs):
341341
* Shift indices for domain alignment.
342342
"""
343343
expand = kwargs['options'].get('expand', True)
344-
mul_first = kwargs['options'].get('eval-mul-first', False)
344+
interp_mode = kwargs['options'].get('interp-mode', 'direct')
345345

346346
# Specialization is performed on unevaluated expressions
347347
expressions = cls._specialize_dsl(expressions, **kwargs)
@@ -352,7 +352,7 @@ def _lower_exprs(cls, expressions, **kwargs):
352352
# ModuloDimensions
353353
if not expand:
354354
expand = lambda d: d.is_Stepping
355-
expressions = flatten([i._evaluate(expand=expand, mul_first=mul_first)
355+
expressions = flatten([i._evaluate(expand=expand, interp_mode=interp_mode)
356356
for i in expressions])
357357

358358
# Scalarize the tensor equations, if any

devito/types/dense.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from devito.deprecations import deprecations
1616
from devito.exceptions import InvalidArgument
1717
from devito.finite_differences import Differentiable, generate_fd_shortcuts
18+
from devito.finite_differences.differentiable import _interp_mapper
1819
from devito.finite_differences.tools import fd_weights_registry
1920
from devito.logger import debug, warning
2021
from devito.mpi import MPI
@@ -1122,16 +1123,12 @@ def _eval_at(self, func, **kwargs):
11221123
if self.staggered == func.staggered or self.interp_order == 0:
11231124
return self
11241125

1125-
mapper = {}
1126-
for d in self.dimensions:
1127-
try:
1128-
if self.indices_ref[d] is not func.indices_ref[d]:
1129-
f_idx = func.indices_ref[d]._subs(func.dimensions[d], d)
1130-
mapper[self.indices_ref[d]] = f_idx
1131-
except KeyError:
1132-
pass
1133-
1134-
return self.subs(mapper)
1126+
# Dims where self and func indices differ -> {dim: func_idx}
1127+
diff = _interp_mapper(self.indices_ref, func.indices_ref, self.dimensions)
1128+
# Translate into a subs mapper {self_idx: func_idx} aligned on self's dims
1129+
subs_map = {self.indices_ref[d]: t._subs(func.dimensions[d], d)
1130+
for d, t in diff.items()}
1131+
return self.subs(subs_map)
11351132

11361133
@classmethod
11371134
def __staggered_setup__(cls, dimensions, staggered=None, **kwargs):

0 commit comments

Comments
 (0)