Skip to content

Commit 551d050

Browse files
committed
api: fix inject handling of different staggering
1 parent 6552d82 commit 551d050

7 files changed

Lines changed: 115 additions & 96 deletions

File tree

devito/finite_differences/derivative.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -489,6 +489,9 @@ def _eval_at(self, func):
489489
# compare staggering
490490
if self.expr.staggered == func.staggered and self.expr.is_Function:
491491
return self
492+
# Time derivatives are not affected by space staggering
493+
if all(d.is_Time for d in self.dims):
494+
return self
492495

493496
# Check if x0's keys come from a DerivedDimension
494497
x0 = func.indices_ref.getters

devito/operations/interpolators.py

Lines changed: 65 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from abc import ABC, abstractmethod
2+
from contextlib import suppress
23
from functools import cached_property, wraps
4+
from itertools import groupby
35

46
import numpy as np
57
import sympy
@@ -67,12 +69,9 @@ def _extract_subdomain(variables):
6769
"""
6870
sdms = set()
6971
for v in variables:
70-
try:
72+
with suppress(AttributeError):
7173
if v.grid.is_SubDomain:
7274
sdms.add(v.grid)
73-
except AttributeError:
74-
# Variable not on a grid (Indexed for example)
75-
pass
7675

7776
if len(sdms) > 1:
7877
raise NotImplementedError("Sparse operation on multiple Functions defined on"
@@ -245,19 +244,17 @@ def _cdim(self):
245244

246245
def _field_shifts(self, field):
247246
"""
248-
Per-grid-Dimension half-cell shift induced by ``field``'s staggering
249-
(e.g. ``h_x/2`` for a field staggered in ``x``). Returns None for
247+
Per-grid-Dimension half-cell shift induced by `field`'s staggering
248+
(e.g. `h_x/2` for a field staggered in `x`). Returns None for
250249
unstaggered fields. SubDomain-induced origin offsets are deliberately
251250
ignored — they are not staggering.
252251
"""
253-
try:
254-
staggered = field.staggered
255-
except AttributeError:
256-
return None
257-
if not staggered or all(s == 0 for s in staggered):
258-
return None
252+
staggered = field.staggered
253+
if not staggered or staggered.on_node:
254+
return ()
259255
return tuple((d.spacing / 2) if s else 0
260-
for d, s in zip(self._gdims, staggered, strict=True))
256+
for d, s in zip(field.dimensions, staggered, strict=True)
257+
if d.is_Space)
261258

262259
@memoized_meth
263260
def _rdim(self, subdomain=None, shifts=None):
@@ -295,12 +292,10 @@ def _augment_implicit_dims(self, implicit_dims, extras=None):
295292
# dimensions of that SubDomain from any extra dimensions found
296293
edims = []
297294
for v in extras:
298-
try:
295+
with suppress(AttributeError):
299296
if v.grid.is_SubDomain:
300297
edims.extend([d for d in v.grid.dimensions
301298
if d.is_Sub and d.root in self._gdims])
302-
except AttributeError:
303-
pass
304299

305300
gdims = filter_ordered(edims + list(self._gdims))
306301
extra = filter_ordered([i for v in extras for i in v.dimensions
@@ -326,11 +321,11 @@ def _positions(self, implicit_dims, shifts=None):
326321
def _interp_idx(self, variables, implicit_dims=None, subdomain=None,
327322
shifts=None):
328323
"""
329-
Generate interpolation indices for the DiscreteFunctions in ``variables``.
324+
Generate interpolation indices for the DiscreteFunctions in `variables`.
330325
331-
``shifts`` is a per-Dimension physical offset for the target field's
326+
`shifts` is a per-Dimension physical offset for the target field's
332327
origin: it only affects the integer position symbol via the position
333-
map (``pos = floor((c - o - shift)/h)``). The index substitution itself
328+
map (`pos = floor((c - o - shift)/h)`). The index substitution itself
334329
is unchanged — any staggered offset in a field's own symbolic access is
335330
absorbed by Devito's normal indexing.
336331
"""
@@ -360,7 +355,7 @@ def _interp_idx(self, variables, implicit_dims=None, subdomain=None,
360355
@check_coords
361356
def interpolate(self, expr, increment=False, self_subs=None, implicit_dims=None):
362357
"""
363-
Generate equations interpolating an arbitrary expression into ``self``.
358+
Generate equations interpolating an arbitrary expression into `self`.
364359
365360
Parameters
366361
----------
@@ -398,7 +393,7 @@ def inject(self, field, expr, implicit_dims=None):
398393

399394
def _interpolate(self, expr, increment=False, self_subs=None, implicit_dims=None):
400395
"""
401-
Generate equations interpolating an arbitrary expression into ``self``.
396+
Generate equations interpolating an arbitrary expression into `self`.
402397
403398
Parameters
404399
----------
@@ -412,16 +407,13 @@ def _interpolate(self, expr, increment=False, self_subs=None, implicit_dims=None
412407
the operator.
413408
"""
414409
# Derivatives must be evaluated before the introduction of indirect accesses
415-
try:
416-
_expr = expr._eval_at(self.sfunction).evaluate
417-
except AttributeError:
418-
# E.g., a generic SymPy expression or a number
419-
_expr = expr
410+
with suppress(AttributeError):
411+
expr = expr._eval_at(self.sfunction).evaluate
420412

421413
if self_subs is None:
422414
self_subs = {}
423415

424-
variables = list(retrieve_function_carriers(_expr))
416+
variables = list(retrieve_function_carriers(expr))
425417
subdomain = _extract_subdomain(variables)
426418

427419
# Implicit dimensions
@@ -436,7 +428,7 @@ def _interpolate(self, expr, increment=False, self_subs=None, implicit_dims=None
436428
summands = [Eq(rhs, 0., implicit_dims=implicit_dims)]
437429
# Substitute coordinate base symbols into the interpolation coefficients
438430
weights = self._weights(subdomain=subdomain)
439-
summands.extend([Inc(rhs, (weights * _expr).xreplace(idx_subs),
431+
summands.extend([Inc(rhs, (weights * expr).xreplace(idx_subs),
440432
implicit_dims=implicit_dims)])
441433

442434
# Write/Incr `self`
@@ -478,42 +470,44 @@ def _inject(self, field, expr, implicit_dims=None):
478470
# accesses. Variables are sampled at their own grid location; the
479471
# position map for the target field carries the staggering so the
480472
# field's stencil neighbors land on the right indices.
481-
try:
482-
_exprs = tuple(e._eval_at(f).evaluate
483-
for e, f in zip(exprs, fields, strict=True))
484-
except AttributeError:
485-
# E.g., a generic SymPy expression or a number
486-
_exprs = exprs
487-
488-
variables = list(v for e in _exprs for v in retrieve_function_carriers(e))
489-
490-
# Implicit dimensions
491-
implicit_dims = self._augment_implicit_dims(implicit_dims, variables)
492-
493-
# All fields in a single injection share the same staggering by
494-
# construction (they are written together at the same indices), so
495-
# derive shifts from the first field.
496-
shifts = self._field_shifts(fields[0])
497-
498-
# Move all temporaries inside inner loop to improve parallelism
499-
# Can only be done for inject as interpolation needs a summing temp
500-
# that wouldn't allow collapsing
501-
implicit_dims = implicit_dims + tuple(r.parent for r in
502-
self._rdim(subdomain=subdomain,
503-
shifts=shifts))
504-
505-
# List of indirection indices for all adjacent grid points
506-
idx_subs, temps = self._interp_idx(list(fields) + variables,
507-
implicit_dims=implicit_dims,
508-
subdomain=subdomain, shifts=shifts)
509-
510-
eqns = [Inc(_field.xreplace(idx_subs),
511-
(self._weights(subdomain=subdomain, shifts=shifts)
512-
* _expr).xreplace(idx_subs),
513-
implicit_dims=implicit_dims)
514-
for _field, _expr in zip(fields, _exprs, strict=True)]
515-
516-
return temps + eqns
473+
with suppress(AttributeError):
474+
exprs = tuple(e._eval_at(f).evaluate
475+
for e, f in zip(exprs, fields, strict=True))
476+
477+
eqns = []
478+
temps = []
479+
# We need to create one set of equations (temps and and coeffs) per staggering
480+
# field in which we inject as the reference index depends on the field's origin
481+
for _, g in groupby(zip(fields, exprs, strict=True), lambda f: f[0].staggered):
482+
g_fields, g_exprs = zip(*g, strict=True)
483+
variables = list(v for e in g_exprs for v in retrieve_function_carriers(e))
484+
485+
implicit_dims = self._augment_implicit_dims(implicit_dims, variables)
486+
487+
# All fields in a single injection share the same staggering by
488+
# construction (they are written together at the same indices), so
489+
# derive shifts from the first field.
490+
shifts = self._field_shifts(g_fields[0])
491+
492+
# Move all temporaries inside inner loop to improve parallelism
493+
# Can only be done for inject as interpolation needs a summing temp
494+
# that wouldn't allow collapsing
495+
implicit_dims = implicit_dims + tuple(r.parent for r in
496+
self._rdim(subdomain=subdomain,
497+
shifts=shifts))
498+
499+
# List of indirection indices for all adjacent grid points
500+
idx_subs, _temps = self._interp_idx(list(g_fields) + variables,
501+
implicit_dims=implicit_dims,
502+
subdomain=subdomain, shifts=shifts)
503+
504+
w = self._weights(subdomain=subdomain, shifts=shifts)
505+
temps.extend(_temps)
506+
eqns.extend([Inc(f.xreplace(idx_subs), (w * e).xreplace(idx_subs),
507+
implicit_dims=implicit_dims)
508+
for f, e in zip(g_fields, g_exprs, strict=True)])
509+
510+
return filter_ordered(temps) + eqns
517511

518512

519513
class LinearInterpolator(WeightedInterpolator):
@@ -540,10 +534,13 @@ def _weights(self, subdomain=None, shifts=None):
540534
def _point_symbols(self, shifts=None):
541535
"""Symbol for coordinate value in each Dimension of the point."""
542536
dtype = self.sfunction.coordinates.dtype
543-
suffix = self.sfunction._shifts_suffix(shifts)
544-
return DimensionTuple(*(Symbol(name=f'p{d}{suffix}', dtype=dtype)
545-
for d in self.grid.dimensions),
546-
getters=self.grid.dimensions)
537+
symbols = []
538+
for d in self.grid.dimensions:
539+
if shifts and shifts[self.grid.dimensions.index(d)] != 0:
540+
symbols.append(Symbol(name=f'p{d}_s1', dtype=dtype))
541+
else:
542+
symbols.append(Symbol(name=f'p{d}', dtype=dtype))
543+
return DimensionTuple(*symbols, getters=self.grid.dimensions)
547544

548545
def _coeff_temps(self, implicit_dims, shifts=None):
549546
# Positions

devito/types/basic.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -960,12 +960,12 @@ def indices(self):
960960
"""The indices of the object."""
961961
return DimensionTuple(*self.args, getters=self.dimensions)
962962

963-
@property
963+
@cached_property
964964
def indices_ref(self):
965965
"""The reference indices of the object (indices at first creation)."""
966966
return DimensionTuple(*self.function.indices, getters=self.dimensions)
967967

968-
@property
968+
@cached_property
969969
def origin(self):
970970
"""
971971
Origin of the AbstractFunction in term of Dimension

devito/types/sparse.py

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -257,16 +257,17 @@ def sparse_position(self):
257257
def _sparse_dim(self):
258258
return self.dimensions[self.sparse_position]
259259

260-
@property
260+
@cached_property
261261
def indices_ref(self):
262-
return DimensionTuple(*self.grid.dimensions,
263-
getters=self.grid.dimensions)
262+
getters = (*self.dimensions, *self.grid.dimensions)
263+
indices = (*self.dimensions, *self.grid.dimensions)
264+
return DimensionTuple(*indices, getters=getters)
264265

265266
@property
266267
def _grid_map(self):
267268
return {}
268269

269-
@property
270+
@cached_property
270271
def origin(self):
271272
return DimensionTuple(*[0]*len(self.dimensions),
272273
getters=self.dimensions)
@@ -368,15 +369,13 @@ def coordinates_data(self):
368369

369370
@memoized_meth
370371
def _pos_symbols(self, shifts=None):
371-
suffix = self._shifts_suffix(shifts)
372-
return [Symbol(name=f'pos{d}{suffix}', dtype=np.int32)
373-
for d in self.grid.dimensions]
374-
375-
@staticmethod
376-
def _shifts_suffix(shifts):
377-
if not shifts or all(s == 0 for s in shifts):
378-
return ''
379-
return '_s' + ''.join('1' if s != 0 else '0' for s in shifts)
372+
symbols = []
373+
for d in self.grid.dimensions:
374+
if shifts and shifts[self.grid.dimensions.index(d)] != 0:
375+
symbols.append(Symbol(name=f'pos{d}_s1', dtype=np.int32))
376+
else:
377+
symbols.append(Symbol(name=f'pos{d}', dtype=np.int32))
378+
return symbols
380379

381380
@cached_property
382381
def _point_increments(self):
@@ -396,8 +395,7 @@ def _position_map(self, shifts=None):
396395
coordinates (e.g. ``h_x/2`` for a field staggered in ``x``). If ``shifts``
397396
is None, only the grid origin is subtracted.
398397
"""
399-
if shifts is None:
400-
shifts = (0,) * len(self.grid.dimensions)
398+
shifts = shifts or (0,) * len(self.grid.dimensions)
401399
return OrderedDict([
402400
((c - o - s)/d.spacing, p)
403401
for p, c, d, o, s in zip(

examples/seismic/tutorials/06_elastic_varying_parameters.ipynb

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -604,8 +604,8 @@
604604
"outputs": [],
605605
"source": [
606606
"assert np.isclose(norm(rec), 23.504, atol=0, rtol=1e-3)\n",
607-
"assert np.isclose(norm(rec2), 2.425, atol=0, rtol=1e-3)\n",
608-
"assert np.isclose(norm(rec3), 2.889, atol=0, rtol=1e-3)"
607+
"assert np.isclose(norm(rec2), 2.4298, atol=0, rtol=1e-3)\n",
608+
"assert np.isclose(norm(rec3), 2.7481, atol=0, rtol=1e-3)"
609609
]
610610
},
611611
{
@@ -838,8 +838,8 @@
838838
"metadata": {},
839839
"outputs": [],
840840
"source": [
841-
"assert np.isclose(norm(rec2), .3250, atol=0, rtol=1e-3)\n",
842-
"assert np.isclose(norm(rec3), .26745, atol=0, rtol=1e-3)"
841+
"assert np.isclose(norm(rec2), .30388, atol=0, rtol=1e-3)\n",
842+
"assert np.isclose(norm(rec3), .26633, atol=0, rtol=1e-3)"
843843
]
844844
},
845845
{
@@ -1102,9 +1102,9 @@
11021102
"metadata": {},
11031103
"outputs": [],
11041104
"source": [
1105-
"assert np.isclose(norm(rec), 31.23, atol=0, rtol=1e-3)\n",
1106-
"assert np.isclose(norm(rec2), 3.5482, atol=0, rtol=1e-3)\n",
1107-
"assert np.isclose(norm(rec3), 4.7007, atol=0, rtol=1e-3)"
1105+
"assert np.isclose(norm(rec), 29.538, atol=0, rtol=1e-3)\n",
1106+
"assert np.isclose(norm(rec2), 1.9116, atol=0, rtol=1e-3)\n",
1107+
"assert np.isclose(norm(rec3), 3.4919, atol=0, rtol=1e-3)"
11081108
]
11091109
}
11101110
],

tests/test_dimension.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1274,10 +1274,10 @@ def test_no_index_sparse(self):
12741274

12751275
radius = 1
12761276
indices = [(INT(floor(i)), INT(floor(i))+radius)
1277-
for i in sf._position_map]
1277+
for i in sf._position_map()]
12781278
bounds = [i.symbolic_size - radius for i in grid.dimensions]
12791279

1280-
eqs = [Eq(p, v) for (v, p) in sf._position_map.items()]
1280+
eqs = [Eq(p, v) for (v, p) in sf._position_map().items()]
12811281
for e, i in enumerate(product(*indices)):
12821282
args = [j > 0 for j in i]
12831283
args.extend([j < k for j, k in zip(i, bounds, strict=True)])

tests/test_interpolation.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
from devito import (
99
NODE, DefaultDimension, Dimension, Eq, Function, Grid, MatrixSparseTimeFunction,
1010
Operator, PrecomputedSparseFunction, PrecomputedSparseTimeFunction, Real,
11-
SparseFunction, SparseTimeFunction, SubDomain, TimeFunction, switchconfig
11+
SparseFunction, SparseTimeFunction, SubDomain, TimeFunction, VectorFunction,
12+
switchconfig
1213
)
1314
from devito.operations.interpolators import LinearInterpolator, SincInterpolator
1415
from devito.tools import as_tuple
@@ -472,7 +473,7 @@ def test_inject_staggered(self, stagg):
472473
staggered=staggered)
473474
a.data.fill(0)
474475

475-
b = Function(name='b', grid=a.grid, space_order=2,
476+
b = Function(name='b', grid=a.grid, space_order=8,
476477
staggered=NODE)
477478
b.data.fill(1)
478479
b.data[5, 5, 5] = 2
@@ -508,6 +509,26 @@ def test_inject_staggered(self, stagg):
508509
# Use abs to make sure there is no +- cancellations
509510
assert np.sum(np.abs(a.data)) == interp_val * 2**(sum(np.array(a.staggered)))
510511

512+
def test_inject_staggered_mixed(self):
513+
grid = Grid((11, 11, 11))
514+
v = VectorFunction(name='v', grid=grid, space_order=2)
515+
b = Function(name='b', grid=grid, space_order=2, staggered=NODE)
516+
p = SparseFunction(name="p", grid=grid, nt=10, npoint=1)
517+
518+
eq = p.inject(v, expr=b * p).evaluate
519+
520+
# We should have
521+
# - 3 injection equations v_x, v_y, v_z
522+
# The standard 6 on node temps posx, posy, posz, px, py, pz
523+
# 2 temps for the staggered in x vx posz_s1, px_s1
524+
# 2 temps for the staggered in y vy posz_s1, py_s1
525+
# 2 temps for the staggered in z vz posz_s1, pz_s1
526+
assert len(eq) == 3 + 6 + 2 + 2 + 2
527+
528+
op = Operator(eq)
529+
# Should be a single loop nest with 3 injections
530+
assert_structure(op, ['p_p,rp_px,rp_py,rp_pz'], 'p_prp_pxrp_py,rp_pz')
531+
511532
# ---------------------------------------------------------------------------
512533
# Precomputed interpolation / injection
513534
# ---------------------------------------------------------------------------

0 commit comments

Comments
 (0)