Skip to content

Commit 870a160

Browse files
authored
Merge pull request #2788 from devitocodes/strict-mul-eval-at
api: add mul interp mode
2 parents 347d2c7 + cffdd95 commit 870a160

20 files changed

Lines changed: 2136 additions & 52 deletions

devito/core/cpu.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ def _normalize_kwargs(cls, **kwargs):
111111
)
112112

113113
kwargs['options'].update(o)
114+
kwargs['sym_options'] = cls._normalize_sym_kwargs(**kwargs)
114115

115116
return kwargs
116117

devito/core/gpu.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@ def _normalize_kwargs(cls, **kwargs):
121121
)
122122

123123
kwargs['options'].update(o)
124+
kwargs['sym_options'] = cls._normalize_sym_kwargs(**kwargs)
124125

125126
return kwargs
126127

devito/core/operator.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,30 @@ class BasicOperator(Operator):
171171
The target language constructor, to be specified by subclasses.
172172
"""
173173

174+
# ------------------------------------------------------------------
175+
# Symbolic-level option defaults (`sym_opt`).
176+
# These steer mathematical choices made during expression lowering,
177+
# *not* code generation or performance. They are kept separate from
178+
# the `opt` options above to keep the two concerns distinct.
179+
# ------------------------------------------------------------------
180+
181+
INTERP_MODE = 'direct'
182+
"""
183+
Default for the `sym_opt={'interp-mode': ...}` option. Controls how
184+
a product of fields living at different staggered locations is mapped
185+
onto a target location:
186+
187+
* `'direct'` (default): each factor is interpolated to the target
188+
independently. Cheapest stencil.
189+
* `'symmetric'`: factors are first gathered at a common "block"
190+
location, multiplied there, and the result is interpolated once to
191+
the target. Preserves the `I A I^T` matrix structure, so the
192+
discrete operator stays self-adjoint when the continuous one is
193+
(e.g. the elastic stiffness `sigma = C eps`).
194+
195+
See `examples/userapi/08_staggered_interp.ipynb` for a worked example.
196+
"""
197+
174198
@classmethod
175199
def _normalize_kwargs(cls, **kwargs):
176200
# Will be populated with dummy values; this method is actually overridden
@@ -188,12 +212,30 @@ def _normalize_kwargs(cls, **kwargs):
188212
)
189213

190214
kwargs['options'].update(o)
215+
kwargs['sym_options'] = cls._normalize_sym_kwargs(**kwargs)
191216

192217
return kwargs
193218

219+
@classmethod
220+
def _normalize_sym_kwargs(cls, **kwargs):
221+
"""
222+
Fill in defaults and validate keys for the `sym_opt` dict passed to
223+
the Operator. Returns the normalized `sym_options` dict.
224+
"""
225+
so = dict(kwargs.get('sym_options', {}))
226+
out = {'interp-mode': so.pop('interp-mode', cls.INTERP_MODE)}
227+
228+
if so:
229+
raise InvalidOperator(
230+
f'Unrecognized symbolic options: [{", ".join(list(so))}]'
231+
)
232+
233+
return out
234+
194235
@classmethod
195236
def _check_kwargs(cls, **kwargs):
196237
oo = kwargs['options']
238+
so = kwargs['sym_options']
197239

198240
if oo['mpi'] and oo['mpi'] not in cls.MPI_MODES:
199241
raise InvalidOperator(f"Unsupported MPI mode `{oo['mpi']}`")
@@ -209,6 +251,9 @@ def _check_kwargs(cls, **kwargs):
209251
if oo['errctl'] not in (None, False, 'basic', 'max'):
210252
raise InvalidOperator("Illegal `errctl` value")
211253

254+
if so['interp-mode'] not in ('direct', 'symmetric'):
255+
raise InvalidOperator("Illegal `interp-mode` value")
256+
212257
def _autotune(self, args, setup):
213258
if setup in [False, 'off']:
214259
return args

devito/finite_differences/derivative.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,14 @@ class Derivative(sympy.Derivative, Differentiable, Pickable):
8989
evaluation are `x0`, `fd_order` and `side`.
9090
"""
9191

92-
_fd_priority = 3
92+
@cached_property
93+
def _fd_priority(self):
94+
# A Derivative inherits the priority of its underlying expression, so
95+
# that `highest_priority(C*v.dx)` and `highest_priority((C*v).dx)`
96+
# agree on the gather location and the two gathering paths
97+
# (`_gather_for_diff` and `Mul._eval_at(interp_mode='symmetric')`)
98+
# produce consistent answers.
99+
return getattr(self.expr, '_fd_priority', 0)
93100

94101
__rargs__ = ('expr', '*dims')
95102
__rkwargs__ = ('side', 'deriv_order', 'fd_order', 'transpose', '_ppsubs',
@@ -472,7 +479,7 @@ def T(self):
472479

473480
return self._rebuild(transpose=adjoint)
474481

475-
def _eval_at(self, func):
482+
def _eval_at(self, func, interp_mode='direct', **kwargs):
476483
"""
477484
Evaluates the derivative at the location of `func`. It is necessary for staggered
478485
setup where one could have Eq(u(x + h_x/2), v(x).dx)) in which case v(x).dx
@@ -525,7 +532,8 @@ def _eval_at(self, func):
525532
return self._rebuild(self.expr, **rkw)
526533
args = [self.expr.func(*v) for v in mapper.values()]
527534
args.extend([a for a in self.expr.args if a not in self.expr._args_diff])
528-
args = [self._rebuild(a)._eval_at(func) for a in args]
535+
args = [self._rebuild(a)._eval_at(func, interp_mode=interp_mode, **kwargs)
536+
for a in args]
529537
return self.expr.func(*args)
530538
elif self.expr.is_Mul:
531539
# For Mul, We treat the basic case `u(x + h_x/2) * v(x) which is what appear

devito/finite_differences/differentiable.py

Lines changed: 66 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
# Moved in 1.13
1616
from sympy.core.basic import ordering_of_classes
1717

18+
from devito.finite_differences.interpolation import interp_at, post_x0_indices
1819
from devito.finite_differences.tools import coeff_priority, make_shift_x0
1920
from devito.logger import warning
2021
from devito.tools import (
@@ -140,10 +141,6 @@ def indices_ref(self):
140141
return DimensionTuple(*self.dimensions, getters=self.dimensions)
141142
return highest_priority(self).indices_ref
142143

143-
@cached_property
144-
def is_Staggered(self):
145-
return any([getattr(i, 'is_Staggered', False) for i in self._args_diff])
146-
147144
@cached_property
148145
def is_TimeDependent(self):
149146
return any(i.is_Time for i in self.dimensions)
@@ -184,13 +181,11 @@ def coefficients(self):
184181
key = lambda x: coeff_priority.get(x, -1)
185182
return sorted(coefficients, key=key, reverse=True)[0]
186183

187-
def _eval_at(self, func):
188-
if not func.is_Staggered:
189-
# Cartesian grid, do no waste time
190-
return self
184+
def _eval_at(self, func, **kwargs):
191185
return self.func(*[
192-
getattr(a, '_eval_at', lambda x: a)(func) for a in self.args # noqa: B023
193-
]) # false positive
186+
getattr(a, '_eval_at', lambda x, **kw: a)(func, **kwargs) # noqa: B023
187+
for a in self.args # false positive: lambda is invoked in-place
188+
])
194189

195190
def _subs(self, old, new, **hints):
196191
if old == self:
@@ -669,6 +664,63 @@ def _gather_for_diff(self):
669664
other = self.func(*other)._eval_at(highest_priority(self))
670665
return self.func(other, *derivs)
671666

667+
def _eval_at(self, func, interp_mode='direct', **kwargs):
668+
"""
669+
Evaluate a Mul at the location of `func`.
670+
671+
Two modes:
672+
673+
- `interp_mode='direct'` (default): per-arg evaluation; each factor is
674+
independently evaluated at `func`'s location via
675+
`Differentiable._eval_at`.
676+
677+
- `interp_mode='symmetric'`: when every Differentiable factor has a
678+
staggering different from `func`'s, apply the `I * (a * I^T * b)`
679+
form:
680+
681+
1. Pick a `block` location -- the highest-priority factor's
682+
staggering (NODE is the highest priority, so coefficient-like
683+
NODE factors win, as in the `I * C * I^T` elastic stiffness
684+
pattern). Each factor not at the block is brought there via
685+
`I^T` (an explicit 0-order FD interpolation operator).
686+
Derivatives additionally set `x0` on their own derivative
687+
dimensions to `func`'s indices.
688+
2. The product is formed at `block`'s location.
689+
3. The whole product is interpolated to `func` via `I` (an
690+
explicit 0-order FD operator).
691+
692+
When the trigger does not hold (e.g. some factor already matches
693+
`func`'s staggering), we fall back to `direct`.
694+
"""
695+
if interp_mode != 'symmetric':
696+
return super()._eval_at(func, **kwargs)
697+
698+
diff, other = split(self.args, lambda a: isinstance(a, Differentiable))
699+
700+
# Symmetric form requires every Differentiable factor to differ from
701+
# func; otherwise direct evaluation is cleaner and equivalent.
702+
if len(diff) < 2 or \
703+
any(a.staggered == func.staggered for a in diff):
704+
return super()._eval_at(func, **kwargs)
705+
706+
block_indices = highest_priority(self).indices_ref
707+
708+
# Bring each factor to block's location (I^T where needed)
709+
new_factors = list(other)
710+
for a in diff:
711+
if isinstance(a, sympy.Derivative):
712+
source = post_x0_indices(a, func)
713+
a = a._rebuild(x0={dim: func.indices_ref[dim] for dim in a.dims
714+
if dim in func.indices_ref.getters})
715+
else:
716+
source = a.indices_ref
717+
new_factors.append(interp_at(a, source, block_indices,
718+
self.interp_order))
719+
720+
# Final I from block's location to func
721+
return interp_at(self.func(*new_factors), block_indices,
722+
func.indices_ref, self.interp_order)
723+
672724

673725
class Pow(DifferentiableOp, sympy.Pow):
674726
_fd_priority = 0
@@ -1020,7 +1072,7 @@ def _subs(self, old, new, **hints):
10201072

10211073
class DiffDerivative(IndexDerivative, DifferentiableOp):
10221074

1023-
def _eval_at(self, func):
1075+
def _eval_at(self, func, **kwargs):
10241076
# Like EvalDerivative, a DiffDerivative must have already been evaluated
10251077
# at a valid x0 and should not be re-evaluated at a different location
10261078
return self
@@ -1074,7 +1126,7 @@ def _new_rawargs(self, *args, **kwargs):
10741126
kwargs.pop('is_commutative', None)
10751127
return self.func(*args, **kwargs)
10761128

1077-
def _eval_at(self, func):
1129+
def _eval_at(self, func, **kwargs):
10781130
# An EvalDerivative must have already been evaluated at a valid x0
10791131
# and should not be re-evaluated at a different location
10801132
return self
@@ -1092,7 +1144,7 @@ class diffify:
10921144
10931145
Notes
10941146
-----
1095-
The name "diffify" stems from SymPy's "simpify", which has an analogous task --
1147+
The name "diffify" stems from SymPy's "simplify", which has an analogous task --
10961148
converting all arguments into SymPy core objects.
10971149
"""
10981150

@@ -1240,7 +1292,7 @@ def _(expr, x0, **kwargs):
12401292
@interp_for_fd.register(AbstractFunction)
12411293
def _(expr, x0, **kwargs):
12421294
x0_expr = {d: v for d, v in x0.items() if v.has(d)
1243-
and expr.indices[d] is not v}
1295+
and expr.indices.get(d, v) is not v}
12441296
if x0_expr:
12451297
return expr.subs({expr.indices[d]: v for d, v in x0_expr.items()})
12461298
else:
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
from contextlib import suppress
2+
3+
__all__ = ['interp_at', 'interp_mapper', 'post_x0_indices']
4+
5+
6+
def interp_mapper(source, target, dims):
7+
"""
8+
Build a `{dim: target_index}` mapper for dimensions in `dims` where
9+
`source[dim]` differs from `target[dim]`.
10+
11+
`source` and `target` are dict-like `{dim: index_expr}` (e.g. a plain
12+
dict or a `DimensionTuple`). Dimensions missing from either side are
13+
skipped silently.
14+
"""
15+
mapper = {}
16+
for d in dims:
17+
try:
18+
s = source[d]
19+
t = target[d]
20+
except (KeyError, IndexError):
21+
continue
22+
if s is not t:
23+
mapper[d] = t
24+
return mapper
25+
26+
27+
def interp_at(expr, source, target, interp_order):
28+
"""
29+
Build a symbolic 0-order FD interpolation operator on `expr` that maps
30+
values from `source` indices to `target` indices, only on the
31+
dimensions where the two locations differ.
32+
"""
33+
from devito.finite_differences.differentiable import Differentiable
34+
35+
if not isinstance(expr, Differentiable):
36+
return expr
37+
38+
mapper = interp_mapper(source, target, expr.dimensions)
39+
if not mapper:
40+
return expr
41+
42+
return expr.diff(*mapper.keys(),
43+
deriv_order=(0,) * len(mapper),
44+
fd_order=(interp_order,) * len(mapper),
45+
x0=mapper)
46+
47+
48+
def post_x0_indices(deriv, func):
49+
"""
50+
Conceptual indices of `deriv` after setting `x0` on its own derivative
51+
dimensions to `func`'s indices. Derivative dims take `func`'s indices;
52+
other dims keep the underlying expression's natural location (so that
53+
`interp_for_fd` does not introduce a spurious second shift).
54+
"""
55+
ref = {}
56+
for dim in deriv.dimensions:
57+
if dim in deriv.dims and dim in func.dimensions:
58+
ref[dim] = func.indices_ref[dim]
59+
else:
60+
with suppress(KeyError):
61+
ref[dim] = deriv.indices_ref[dim]
62+
return ref

devito/finite_differences/tools.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,7 @@ def generate_indices(expr, dim, order, side=None, matvec=None, x0=None, nweights
292292
o_min = int(np.ceil(mid - r)) + side.val
293293
o_max = int(np.floor(mid + r)) + side.val
294294
if o_max == o_min:
295-
if dim.is_Time or not expr.is_Staggered:
295+
if dim.is_Time or not bool(expr.staggered):
296296
o_max += 1
297297
else:
298298
o_min -= 1

devito/operator/operator.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,16 @@ class Operator(Callable):
6868
Symbolic substitutions to be applied to ``expressions``.
6969
* opt : str
7070
The performance optimization level. Defaults to ``configuration['opt']``.
71+
* sym_opt : dict
72+
Symbolic-level options controlling mathematical choices made during
73+
expression lowering (e.g. how staggered multi-factor products are
74+
interpolated). Distinct from ``opt``, which controls code generation
75+
and performance. Accepted keys:
76+
77+
- ``'interp-mode'`` (``'direct'`` | ``'symmetric'``): selects the
78+
interpolation strategy used by ``Mul._eval_at`` when projecting a
79+
multi-factor expression onto a target staggered location. See the
80+
tutorial at ``examples/userapi/08_staggered_interp.ipynb``.
7181
* language : str
7282
The target language for shared-memory parallelism. Defaults to
7383
``configuration['language']``.
@@ -235,6 +245,7 @@ def _build(cls, expressions, **kwargs):
235245
# Potentially required for lazily allocated Functions
236246
op._mode = kwargs['mode']
237247
op._options = kwargs['options']
248+
op._sym_options = kwargs['sym_options']
238249
op._allocator = kwargs['allocator']
239250
op._platform = kwargs['platform']
240251

@@ -342,6 +353,7 @@ def _lower_exprs(cls, expressions, **kwargs):
342353
* Shift indices for domain alignment.
343354
"""
344355
expand = kwargs['options'].get('expand', True)
356+
interp_mode = kwargs.get('sym_options', {}).get('interp-mode', 'direct')
345357

346358
# Specialization is performed on unevaluated expressions
347359
expressions = cls._specialize_dsl(expressions, **kwargs)
@@ -352,7 +364,8 @@ def _lower_exprs(cls, expressions, **kwargs):
352364
# ModuloDimensions
353365
if not expand:
354366
expand = lambda d: d.is_Stepping
355-
expressions = flatten([i._evaluate(expand=expand) for i in expressions])
367+
expressions = flatten([i._evaluate(expand=expand, interp_mode=interp_mode)
368+
for i in expressions])
356369

357370
# Scalarize the tensor equations, if any
358371
expressions = [j for i in expressions for j in i._flatten]
@@ -1661,6 +1674,12 @@ def parse_kwargs(**kwargs):
16611674
mode = 'noop'
16621675
kwargs['mode'] = mode
16631676

1677+
# `sym_opt` -- symbolic-level options (mathematical choices, not codegen)
1678+
sym_opt = kwargs.pop('sym_opt', None) or {}
1679+
if not isinstance(sym_opt, (dict, frozendict)):
1680+
raise InvalidOperator(f"Illegal `sym_opt={str(sym_opt)}`")
1681+
kwargs['sym_options'] = dict(sym_opt)
1682+
16641683
# `platform`
16651684
platform = kwargs.get('platform')
16661685
if platform is not None:

0 commit comments

Comments
 (0)