Skip to content

Commit 74db52a

Browse files
committed
api: fix inject handling of different staggering
1 parent 6552d82 commit 74db52a

6 files changed

Lines changed: 92 additions & 56 deletions

File tree

devito/finite_differences/derivative.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -489,6 +489,9 @@ def _eval_at(self, func):
489489
# compare staggering
490490
if self.expr.staggered == func.staggered and self.expr.is_Function:
491491
return self
492+
# Time derivatives are not affected by space staggering
493+
if all(d.is_Time for d in self.dims):
494+
return self
492495

493496
# Check if x0's keys come from a DerivedDimension
494497
x0 = func.indices_ref.getters

devito/operations/interpolators.py

Lines changed: 47 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from abc import ABC, abstractmethod
22
from functools import cached_property, wraps
3+
from itertools import groupby
34

45
import numpy as np
56
import sympy
@@ -257,7 +258,8 @@ def _field_shifts(self, field):
257258
if not staggered or all(s == 0 for s in staggered):
258259
return None
259260
return tuple((d.spacing / 2) if s else 0
260-
for d, s in zip(self._gdims, staggered, strict=True))
261+
for d, s in zip(field.dimensions, staggered, strict=True)
262+
if d.is_Space)
261263

262264
@memoized_meth
263265
def _rdim(self, subdomain=None, shifts=None):
@@ -485,35 +487,43 @@ def _inject(self, field, expr, implicit_dims=None):
485487
# E.g., a generic SymPy expression or a number
486488
_exprs = exprs
487489

488-
variables = list(v for e in _exprs for v in retrieve_function_carriers(e))
489-
490-
# Implicit dimensions
491-
implicit_dims = self._augment_implicit_dims(implicit_dims, variables)
492-
493-
# All fields in a single injection share the same staggering by
494-
# construction (they are written together at the same indices), so
495-
# derive shifts from the first field.
496-
shifts = self._field_shifts(fields[0])
497-
498-
# Move all temporaries inside inner loop to improve parallelism
499-
# Can only be done for inject as interpolation needs a summing temp
500-
# that wouldn't allow collapsing
501-
implicit_dims = implicit_dims + tuple(r.parent for r in
502-
self._rdim(subdomain=subdomain,
503-
shifts=shifts))
504-
505-
# List of indirection indices for all adjacent grid points
506-
idx_subs, temps = self._interp_idx(list(fields) + variables,
507-
implicit_dims=implicit_dims,
508-
subdomain=subdomain, shifts=shifts)
509-
510-
eqns = [Inc(_field.xreplace(idx_subs),
511-
(self._weights(subdomain=subdomain, shifts=shifts)
512-
* _expr).xreplace(idx_subs),
513-
implicit_dims=implicit_dims)
514-
for _field, _expr in zip(fields, _exprs, strict=True)]
515-
516-
return temps + eqns
490+
eqns = []
491+
temps = []
492+
pairs = zip(fields, _exprs, strict=True)
493+
# We need to create one set of equations (temps and and coeffs) per staggering
494+
# field in which we inject as the reference index depends on the field's origin
495+
for _, g in groupby(pairs, lambda f: f[0].staggered):
496+
g = list(g)
497+
g_fields = [f for f, _ in g]
498+
g_exprs = [e for _, e in g]
499+
variables = list(v for e in g_exprs for v in retrieve_function_carriers(e))
500+
501+
# Implicit dimensions
502+
implicit_dims = self._augment_implicit_dims(implicit_dims, variables)
503+
504+
# All fields in a single injection share the same staggering by
505+
# construction (they are written together at the same indices), so
506+
# derive shifts from the first field.
507+
shifts = self._field_shifts(g_fields[0])
508+
509+
# Move all temporaries inside inner loop to improve parallelism
510+
# Can only be done for inject as interpolation needs a summing temp
511+
# that wouldn't allow collapsing
512+
implicit_dims = implicit_dims + tuple(r.parent for r in
513+
self._rdim(subdomain=subdomain,
514+
shifts=shifts))
515+
516+
# List of indirection indices for all adjacent grid points
517+
idx_subs, _temps = self._interp_idx(g_fields + variables,
518+
implicit_dims=implicit_dims,
519+
subdomain=subdomain, shifts=shifts)
520+
w = self._weights(subdomain=subdomain, shifts=shifts)
521+
temps.extend(_temps)
522+
eqns.extend([Inc(_field.xreplace(idx_subs), (w * _expr).xreplace(idx_subs),
523+
implicit_dims=implicit_dims)
524+
for _field, _expr in zip(g_fields, g_exprs, strict=True)])
525+
526+
return filter_ordered(temps) + eqns
517527

518528

519529
class LinearInterpolator(WeightedInterpolator):
@@ -540,10 +550,13 @@ def _weights(self, subdomain=None, shifts=None):
540550
def _point_symbols(self, shifts=None):
541551
"""Symbol for coordinate value in each Dimension of the point."""
542552
dtype = self.sfunction.coordinates.dtype
543-
suffix = self.sfunction._shifts_suffix(shifts)
544-
return DimensionTuple(*(Symbol(name=f'p{d}{suffix}', dtype=dtype)
545-
for d in self.grid.dimensions),
546-
getters=self.grid.dimensions)
553+
symbols = []
554+
for d in self.grid.dimensions:
555+
if shifts and shifts[self.grid.dimensions.index(d)] != 0:
556+
symbols.append(Symbol(name=f'p{d}_s1', dtype=dtype))
557+
else:
558+
symbols.append(Symbol(name=f'p{d}', dtype=dtype))
559+
return DimensionTuple(*symbols, getters=self.grid.dimensions)
547560

548561
def _coeff_temps(self, implicit_dims, shifts=None):
549562
# Positions

devito/types/sparse.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -259,8 +259,9 @@ def _sparse_dim(self):
259259

260260
@property
261261
def indices_ref(self):
262-
return DimensionTuple(*self.grid.dimensions,
263-
getters=self.grid.dimensions)
262+
getters = (*self.dimensions, *self.grid.dimensions)
263+
indices = (*self.dimensions, *self.grid.dimensions)
264+
return DimensionTuple(*indices, getters=getters)
264265

265266
@property
266267
def _grid_map(self):
@@ -368,15 +369,13 @@ def coordinates_data(self):
368369

369370
@memoized_meth
370371
def _pos_symbols(self, shifts=None):
371-
suffix = self._shifts_suffix(shifts)
372-
return [Symbol(name=f'pos{d}{suffix}', dtype=np.int32)
373-
for d in self.grid.dimensions]
374-
375-
@staticmethod
376-
def _shifts_suffix(shifts):
377-
if not shifts or all(s == 0 for s in shifts):
378-
return ''
379-
return '_s' + ''.join('1' if s != 0 else '0' for s in shifts)
372+
symbols = []
373+
for d in self.grid.dimensions:
374+
if shifts and shifts[self.grid.dimensions.index(d)] != 0:
375+
symbols.append(Symbol(name=f'pos{d}_s1', dtype=np.int32))
376+
else:
377+
symbols.append(Symbol(name=f'pos{d}', dtype=np.int32))
378+
return symbols
380379

381380
@cached_property
382381
def _point_increments(self):

examples/seismic/tutorials/06_elastic_varying_parameters.ipynb

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -604,8 +604,8 @@
604604
"outputs": [],
605605
"source": [
606606
"assert np.isclose(norm(rec), 23.504, atol=0, rtol=1e-3)\n",
607-
"assert np.isclose(norm(rec2), 2.425, atol=0, rtol=1e-3)\n",
608-
"assert np.isclose(norm(rec3), 2.889, atol=0, rtol=1e-3)"
607+
"assert np.isclose(norm(rec2), 2.4298, atol=0, rtol=1e-3)\n",
608+
"assert np.isclose(norm(rec3), 2.7481, atol=0, rtol=1e-3)"
609609
]
610610
},
611611
{
@@ -838,8 +838,8 @@
838838
"metadata": {},
839839
"outputs": [],
840840
"source": [
841-
"assert np.isclose(norm(rec2), .3250, atol=0, rtol=1e-3)\n",
842-
"assert np.isclose(norm(rec3), .26745, atol=0, rtol=1e-3)"
841+
"assert np.isclose(norm(rec2), .30388, atol=0, rtol=1e-3)\n",
842+
"assert np.isclose(norm(rec3), .26633, atol=0, rtol=1e-3)"
843843
]
844844
},
845845
{
@@ -1102,9 +1102,9 @@
11021102
"metadata": {},
11031103
"outputs": [],
11041104
"source": [
1105-
"assert np.isclose(norm(rec), 31.23, atol=0, rtol=1e-3)\n",
1106-
"assert np.isclose(norm(rec2), 3.5482, atol=0, rtol=1e-3)\n",
1107-
"assert np.isclose(norm(rec3), 4.7007, atol=0, rtol=1e-3)"
1105+
"assert np.isclose(norm(rec), 29.538, atol=0, rtol=1e-3)\n",
1106+
"assert np.isclose(norm(rec2), 1.9116, atol=0, rtol=1e-3)\n",
1107+
"assert np.isclose(norm(rec3), 3.4919, atol=0, rtol=1e-3)"
11081108
]
11091109
}
11101110
],

tests/test_dimension.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1274,10 +1274,10 @@ def test_no_index_sparse(self):
12741274

12751275
radius = 1
12761276
indices = [(INT(floor(i)), INT(floor(i))+radius)
1277-
for i in sf._position_map]
1277+
for i in sf._position_map()]
12781278
bounds = [i.symbolic_size - radius for i in grid.dimensions]
12791279

1280-
eqs = [Eq(p, v) for (v, p) in sf._position_map.items()]
1280+
eqs = [Eq(p, v) for (v, p) in sf._position_map().items()]
12811281
for e, i in enumerate(product(*indices)):
12821282
args = [j > 0 for j in i]
12831283
args.extend([j < k for j, k in zip(i, bounds, strict=True)])

tests/test_interpolation.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
from devito import (
99
NODE, DefaultDimension, Dimension, Eq, Function, Grid, MatrixSparseTimeFunction,
1010
Operator, PrecomputedSparseFunction, PrecomputedSparseTimeFunction, Real,
11-
SparseFunction, SparseTimeFunction, SubDomain, TimeFunction, switchconfig
11+
SparseFunction, SparseTimeFunction, SubDomain, TimeFunction, VectorFunction,
12+
switchconfig
1213
)
1314
from devito.operations.interpolators import LinearInterpolator, SincInterpolator
1415
from devito.tools import as_tuple
@@ -472,7 +473,7 @@ def test_inject_staggered(self, stagg):
472473
staggered=staggered)
473474
a.data.fill(0)
474475

475-
b = Function(name='b', grid=a.grid, space_order=2,
476+
b = Function(name='b', grid=a.grid, space_order=8,
476477
staggered=NODE)
477478
b.data.fill(1)
478479
b.data[5, 5, 5] = 2
@@ -508,6 +509,26 @@ def test_inject_staggered(self, stagg):
508509
# Use abs to make sure there is no +- cancellations
509510
assert np.sum(np.abs(a.data)) == interp_val * 2**(sum(np.array(a.staggered)))
510511

512+
def test_inject_staggered_mixed(self):
513+
grid = Grid((11, 11, 11))
514+
v = VectorFunction(name='v', grid=grid, space_order=2)
515+
b = Function(name='b', grid=grid, space_order=2, staggered=NODE)
516+
p = SparseFunction(name="p", grid=grid, nt=10, npoint=1)
517+
518+
eq = p.inject(v, expr=b * p).evaluate
519+
520+
# We should have
521+
# - 3 injection equations v_x, v_y, v_z
522+
# The standard 6 on node temps posx, posy, posz, px, py, pz
523+
# 2 temps for the staggered in x vx posz_s1, px_s1
524+
# 2 temps for the staggered in y vy posz_s1, py_s1
525+
# 2 temps for the staggered in z vz posz_s1, pz_s1
526+
assert len(eq) == 3 + 6 + 2 + 2 + 2
527+
528+
op = Operator(eq)
529+
# Should be a single loop nest with 3 injections
530+
assert_structure(op, ['p_p,rp_px,rp_py,rp_pz'], 'p_prp_pxrp_py,rp_pz')
531+
511532
# ---------------------------------------------------------------------------
512533
# Precomputed interpolation / injection
513534
# ---------------------------------------------------------------------------

0 commit comments

Comments
 (0)