Skip to content

Commit c71560b

Browse files
authored
Merge pull request #2936 from devitocodes/inject-interp
api: Fix handling of staggering in injection/interpolation
2 parents 68e2f74 + 551d050 commit c71560b

8 files changed

Lines changed: 987 additions & 810 deletions

File tree

devito/finite_differences/derivative.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -489,6 +489,9 @@ def _eval_at(self, func):
489489
# compare staggering
490490
if self.expr.staggered == func.staggered and self.expr.is_Function:
491491
return self
492+
# Time derivatives are not affected by space staggering
493+
if all(d.is_Time for d in self.dims):
494+
return self
492495

493496
# Check if x0's keys come from a DerivedDimension
494497
x0 = func.indices_ref.getters

devito/operations/interpolators.py

Lines changed: 110 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from abc import ABC, abstractmethod
2+
from contextlib import suppress
23
from functools import cached_property, wraps
4+
from itertools import groupby
35

46
import numpy as np
57
import sympy
@@ -67,12 +69,9 @@ def _extract_subdomain(variables):
6769
"""
6870
sdms = set()
6971
for v in variables:
70-
try:
72+
with suppress(AttributeError):
7173
if v.grid.is_SubDomain:
7274
sdms.add(v.grid)
73-
except AttributeError:
74-
# Variable not on a grid (Indexed for example)
75-
pass
7675

7776
if len(sdms) > 1:
7877
raise NotImplementedError("Sparse operation on multiple Functions defined on"
@@ -230,7 +229,7 @@ def r(self):
230229
return self.sfunction.r
231230

232231
@memoized_meth
233-
def _weights(self, subdomain=None):
232+
def _weights(self, subdomain=None, shifts=None):
234233
raise NotImplementedError
235234

236235
@property
@@ -243,8 +242,22 @@ def _cdim(self):
243242
dims = [self.sfunction._crdim(d) for d in self._gdims]
244243
return dims
245244

245+
def _field_shifts(self, field):
246+
"""
247+
Per-grid-Dimension half-cell shift induced by `field`'s staggering
248+
(e.g. `h_x/2` for a field staggered in `x`). Returns None for
249+
unstaggered fields. SubDomain-induced origin offsets are deliberately
250+
ignored — they are not staggering.
251+
"""
252+
staggered = field.staggered
253+
if not staggered or staggered.on_node:
254+
return ()
255+
return tuple((d.spacing / 2) if s else 0
256+
for d, s in zip(field.dimensions, staggered, strict=True)
257+
if d.is_Space)
258+
246259
@memoized_meth
247-
def _rdim(self, subdomain=None):
260+
def _rdim(self, subdomain=None, shifts=None):
248261
# If the interpolation operation is limited to a SubDomain,
249262
# use the SubDimensions of that SubDomain
250263
if subdomain:
@@ -254,7 +267,7 @@ def _rdim(self, subdomain=None):
254267

255268
# Make radius dimension conditional to avoid OOB
256269
rdims = []
257-
pos = self.sfunction._position_map.values()
270+
pos = self.sfunction._position_map(shifts=shifts).values()
258271

259272
for (d, rd, p) in zip(gdims, self._cdim, pos, strict=True):
260273
# Add conditional to avoid OOB
@@ -279,12 +292,10 @@ def _augment_implicit_dims(self, implicit_dims, extras=None):
279292
# dimensions of that SubDomain from any extra dimensions found
280293
edims = []
281294
for v in extras:
282-
try:
295+
with suppress(AttributeError):
283296
if v.grid.is_SubDomain:
284297
edims.extend([d for d in v.grid.dimensions
285298
if d.is_Sub and d.root in self._gdims])
286-
except AttributeError:
287-
pass
288299

289300
gdims = filter_ordered(edims + list(self._gdims))
290301
extra = filter_ordered([i for v in extras for i in v.dimensions
@@ -300,27 +311,34 @@ def _augment_implicit_dims(self, implicit_dims, extras=None):
300311
idims = extra + as_tuple(implicit_dims) + self.sfunction.dimensions
301312
return tuple(idims)
302313

303-
def _coeff_temps(self, implicit_dims):
314+
def _coeff_temps(self, implicit_dims, shifts=None):
304315
return []
305316

306-
def _positions(self, implicit_dims):
317+
def _positions(self, implicit_dims, shifts=None):
307318
return [Eq(v, INT(floor(k)), implicit_dims=implicit_dims)
308-
for k, v in self.sfunction._position_map.items()]
319+
for k, v in self.sfunction._position_map(shifts=shifts).items()]
309320

310-
def _interp_idx(self, variables, implicit_dims=None, subdomain=None):
321+
def _interp_idx(self, variables, implicit_dims=None, subdomain=None,
322+
shifts=None):
311323
"""
312-
Generate interpolation indices for the DiscreteFunctions in ``variables``.
324+
Generate interpolation indices for the DiscreteFunctions in `variables`.
325+
326+
`shifts` is a per-Dimension physical offset for the target field's
327+
origin: it only affects the integer position symbol via the position
328+
map (`pos = floor((c - o - shift)/h)`). The index substitution itself
329+
is unchanged — any staggered offset in a field's own symbolic access is
330+
absorbed by Devito's normal indexing.
313331
"""
314-
pos = self.sfunction._position_map.values()
332+
pos = self.sfunction._position_map(shifts=shifts).values()
315333

316334
# Temporaries for the position
317-
temps = self._positions(implicit_dims)
335+
temps = self._positions(implicit_dims, shifts=shifts)
318336

319337
# Coefficient symbol expression
320-
temps.extend(self._coeff_temps(implicit_dims))
338+
temps.extend(self._coeff_temps(implicit_dims, shifts=shifts))
321339

322340
# Substitution mapper for variables
323-
mapper = self._rdim(subdomain=subdomain).getters
341+
mapper = self._rdim(subdomain=subdomain, shifts=shifts).getters
324342

325343
# Index substitution to make in variables
326344
subs = {
@@ -337,7 +355,7 @@ def _interp_idx(self, variables, implicit_dims=None, subdomain=None):
337355
@check_coords
338356
def interpolate(self, expr, increment=False, self_subs=None, implicit_dims=None):
339357
"""
340-
Generate equations interpolating an arbitrary expression into ``self``.
358+
Generate equations interpolating an arbitrary expression into `self`.
341359
342360
Parameters
343361
----------
@@ -375,7 +393,7 @@ def inject(self, field, expr, implicit_dims=None):
375393

376394
def _interpolate(self, expr, increment=False, self_subs=None, implicit_dims=None):
377395
"""
378-
Generate equations interpolating an arbitrary expression into ``self``.
396+
Generate equations interpolating an arbitrary expression into `self`.
379397
380398
Parameters
381399
----------
@@ -389,16 +407,13 @@ def _interpolate(self, expr, increment=False, self_subs=None, implicit_dims=None
389407
the operator.
390408
"""
391409
# Derivatives must be evaluated before the introduction of indirect accesses
392-
try:
393-
_expr = expr.evaluate
394-
except AttributeError:
395-
# E.g., a generic SymPy expression or a number
396-
_expr = expr
410+
with suppress(AttributeError):
411+
expr = expr._eval_at(self.sfunction).evaluate
397412

398413
if self_subs is None:
399414
self_subs = {}
400415

401-
variables = list(retrieve_function_carriers(_expr))
416+
variables = list(retrieve_function_carriers(expr))
402417
subdomain = _extract_subdomain(variables)
403418

404419
# Implicit dimensions
@@ -413,7 +428,7 @@ def _interpolate(self, expr, increment=False, self_subs=None, implicit_dims=None
413428
summands = [Eq(rhs, 0., implicit_dims=implicit_dims)]
414429
# Substitute coordinate base symbols into the interpolation coefficients
415430
weights = self._weights(subdomain=subdomain)
416-
summands.extend([Inc(rhs, (weights * _expr).xreplace(idx_subs),
431+
summands.extend([Inc(rhs, (weights * expr).xreplace(idx_subs),
417432
implicit_dims=implicit_dims)])
418433

419434
# Write/Incr `self`
@@ -451,35 +466,48 @@ def _inject(self, field, expr, implicit_dims=None):
451466

452467
subdomain = _extract_subdomain(fields)
453468

454-
# Derivatives must be evaluated before the introduction of indirect accesses
455-
try:
456-
_exprs = tuple(e.evaluate for e in exprs)
457-
except AttributeError:
458-
# E.g., a generic SymPy expression or a number
459-
_exprs = exprs
460-
461-
variables = list(v for e in _exprs for v in retrieve_function_carriers(e))
462-
463-
# Implicit dimensions
464-
implicit_dims = self._augment_implicit_dims(implicit_dims, variables)
465-
# Move all temporaries inside inner loop to improve parallelism
466-
# Can only be done for inject as interpolation need a temporary
467-
# summing temp that wouldn't allow collapsing
468-
implicit_dims = implicit_dims + tuple(r.parent for r in
469-
self._rdim(subdomain=subdomain))
470-
471-
# List of indirection indices for all adjacent grid points
472-
finterp = fields + as_tuple(variables)
473-
idx_subs, temps = self._interp_idx(finterp, implicit_dims=implicit_dims,
474-
subdomain=subdomain)
475-
476-
# Substitute coordinate base symbols into the interpolation coefficients
477-
eqns = [Inc(_field.xreplace(idx_subs),
478-
(self._weights(subdomain=subdomain) * _expr).xreplace(idx_subs),
479-
implicit_dims=implicit_dims)
480-
for (_field, _expr) in zip(fields, _exprs, strict=True)]
481-
482-
return temps + eqns
469+
# Derivatives must be evaluated before the introduction of indirect
470+
# accesses. Variables are sampled at their own grid location; the
471+
# position map for the target field carries the staggering so the
472+
# field's stencil neighbors land on the right indices.
473+
with suppress(AttributeError):
474+
exprs = tuple(e._eval_at(f).evaluate
475+
for e, f in zip(exprs, fields, strict=True))
476+
477+
eqns = []
478+
temps = []
479+
# We need to create one set of equations (temps and and coeffs) per staggering
480+
# field in which we inject as the reference index depends on the field's origin
481+
for _, g in groupby(zip(fields, exprs, strict=True), lambda f: f[0].staggered):
482+
g_fields, g_exprs = zip(*g, strict=True)
483+
variables = list(v for e in g_exprs for v in retrieve_function_carriers(e))
484+
485+
implicit_dims = self._augment_implicit_dims(implicit_dims, variables)
486+
487+
# All fields in a single injection share the same staggering by
488+
# construction (they are written together at the same indices), so
489+
# derive shifts from the first field.
490+
shifts = self._field_shifts(g_fields[0])
491+
492+
# Move all temporaries inside inner loop to improve parallelism
493+
# Can only be done for inject as interpolation needs a summing temp
494+
# that wouldn't allow collapsing
495+
implicit_dims = implicit_dims + tuple(r.parent for r in
496+
self._rdim(subdomain=subdomain,
497+
shifts=shifts))
498+
499+
# List of indirection indices for all adjacent grid points
500+
idx_subs, _temps = self._interp_idx(list(g_fields) + variables,
501+
implicit_dims=implicit_dims,
502+
subdomain=subdomain, shifts=shifts)
503+
504+
w = self._weights(subdomain=subdomain, shifts=shifts)
505+
temps.extend(_temps)
506+
eqns.extend([Inc(f.xreplace(idx_subs), (w * e).xreplace(idx_subs),
507+
implicit_dims=implicit_dims)
508+
for f, e in zip(g_fields, g_exprs, strict=True)])
509+
510+
return filter_ordered(temps) + eqns
483511

484512

485513
class LinearInterpolator(WeightedInterpolator):
@@ -495,24 +523,30 @@ class LinearInterpolator(WeightedInterpolator):
495523
_name = 'linear'
496524

497525
@memoized_meth
498-
def _weights(self, subdomain=None):
499-
rdim = self._rdim(subdomain=subdomain)
526+
def _weights(self, subdomain=None, shifts=None):
527+
rdim = self._rdim(subdomain=subdomain, shifts=shifts)
500528
c = [(1 - p) * (1 - r) + p * r
501-
for (p, d, r) in zip(self._point_symbols, self._gdims, rdim, strict=True)]
529+
for (p, d, r) in zip(self._point_symbols(shifts), self._gdims, rdim,
530+
strict=True)]
502531
return Mul(*c)
503532

504-
@cached_property
505-
def _point_symbols(self):
533+
@memoized_meth
534+
def _point_symbols(self, shifts=None):
506535
"""Symbol for coordinate value in each Dimension of the point."""
507536
dtype = self.sfunction.coordinates.dtype
508-
return DimensionTuple(*(Symbol(name=f'p{d}', dtype=dtype)
509-
for d in self.grid.dimensions),
510-
getters=self.grid.dimensions)
537+
symbols = []
538+
for d in self.grid.dimensions:
539+
if shifts and shifts[self.grid.dimensions.index(d)] != 0:
540+
symbols.append(Symbol(name=f'p{d}_s1', dtype=dtype))
541+
else:
542+
symbols.append(Symbol(name=f'p{d}', dtype=dtype))
543+
return DimensionTuple(*symbols, getters=self.grid.dimensions)
511544

512-
def _coeff_temps(self, implicit_dims):
545+
def _coeff_temps(self, implicit_dims, shifts=None):
513546
# Positions
514-
pmap = self.sfunction._position_map
515-
poseq = [Eq(self._point_symbols[d], pos - floor(pos),
547+
pmap = self.sfunction._position_map(shifts=shifts)
548+
psyms = self._point_symbols(shifts)
549+
poseq = [Eq(psyms[d], pos - floor(pos),
516550
implicit_dims=implicit_dims)
517551
for (d, pos) in zip(self._gdims, pmap.keys(), strict=True)]
518552
return poseq
@@ -531,23 +565,24 @@ class PrecomputedInterpolator(WeightedInterpolator):
531565

532566
_name = 'precomp'
533567

534-
def _positions(self, implicit_dims):
568+
def _positions(self, implicit_dims, shifts=None):
535569
if self.sfunction.gridpoints_data is None:
536-
return super()._positions(implicit_dims)
570+
return super()._positions(implicit_dims, shifts=shifts)
537571
else:
538572
# No position temp as we have directly the gridpoints
539573
return[Eq(p, k, implicit_dims=implicit_dims)
540-
for (k, p) in self.sfunction._position_map.items()]
574+
for (k, p) in self.sfunction._position_map(shifts=shifts).items()]
541575

542576
@property
543577
def interpolation_coeffs(self):
544578
return self.sfunction.interpolation_coeffs
545579

546580
@memoized_meth
547-
def _weights(self, subdomain=None):
581+
def _weights(self, subdomain=None, shifts=None):
548582
ddim, cdim = self.interpolation_coeffs.dimensions[1:]
549583
mappers = [{ddim: ri, cdim: rd-rd.parent.symbolic_min}
550-
for (ri, rd) in enumerate(self._rdim(subdomain=subdomain))]
584+
for (ri, rd) in enumerate(self._rdim(subdomain=subdomain,
585+
shifts=shifts))]
551586
return Mul(*[self.interpolation_coeffs.subs(mapper)
552587
for mapper in mappers])
553588

@@ -592,8 +627,8 @@ def interpolation_coeffs(self):
592627
return tuple(coeffs)
593628

594629
@memoized_meth
595-
def _weights(self, subdomain=None):
596-
rdims = self._rdim(subdomain=subdomain)
630+
def _weights(self, subdomain=None, shifts=None):
631+
rdims = self._rdim(subdomain=subdomain, shifts=shifts)
597632
return Mul(*[
598633
w._subs(rd, rd-rd.parent.symbolic_min)
599634
for (rd, w) in zip(rdims, self.interpolation_coeffs, strict=True)

devito/types/basic.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -960,12 +960,12 @@ def indices(self):
960960
"""The indices of the object."""
961961
return DimensionTuple(*self.args, getters=self.dimensions)
962962

963-
@property
963+
@cached_property
964964
def indices_ref(self):
965965
"""The reference indices of the object (indices at first creation)."""
966966
return DimensionTuple(*self.function.indices, getters=self.dimensions)
967967

968-
@property
968+
@cached_property
969969
def origin(self):
970970
"""
971971
Origin of the AbstractFunction in term of Dimension

devito/types/dense.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1126,7 +1126,8 @@ def _eval_at(self, func):
11261126
for d in self.dimensions:
11271127
try:
11281128
if self.indices_ref[d] is not func.indices_ref[d]:
1129-
f_idx = func.indices_ref[d]._subs(func.dimensions[d], d)
1129+
d0 = func.dimensions.get(d, d)
1130+
f_idx = func.indices_ref[d]._subs(d0, d)
11301131
mapper[self.indices_ref[d]] = f_idx
11311132
except KeyError:
11321133
pass

0 commit comments

Comments
 (0)