Skip to content

Commit a05f954

Browse files
committed
api: add option to apply interp indices to function coeffs at inject
1 parent 91d8596 commit a05f954

3 files changed

Lines changed: 47 additions & 9 deletions

File tree

devito/operations/interpolators.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -164,19 +164,21 @@ class Injection(UnevaluatedSparseOperation):
164164

165165
__rargs__ = ('field', 'expr', 'implicit_dims') + UnevaluatedSparseOperation.__rargs__
166166

167-
def __new__(cls, field, expr, implicit_dims, interpolator):
167+
def __new__(cls, field, expr, implicit_dims, interpolator, interp_expr=False):
168168
obj = super().__new__(cls, interpolator)
169169

170170
# TODO: unused now, but will be necessary to compute the adjoint
171171
obj.field = field
172172
obj.expr = expr
173173
obj.implicit_dims = implicit_dims
174+
obj.interp_expr = interp_expr
174175

175176
return obj
176177

177178
def operation(self, **kwargs):
178179
return self.interpolator._inject(expr=self.expr, field=self.field,
179-
implicit_dims=self.implicit_dims)
180+
implicit_dims=self.implicit_dims,
181+
interp_expr=self.interp_expr)
180182

181183
def __repr__(self):
182184
return f"Injection({repr(self.expr)} into {repr(self.field)})"
@@ -366,7 +368,7 @@ def interpolate(self, expr, increment=False, self_subs=None, implicit_dims=None)
366368

367369
@check_radius
368370
@check_coords
369-
def inject(self, field, expr, implicit_dims=None):
371+
def inject(self, field, expr, implicit_dims=None, interp_expr=False):
370372
"""
371373
Generate equations injecting an arbitrary expression into a field.
372374
@@ -381,7 +383,7 @@ def inject(self, field, expr, implicit_dims=None):
381383
injection expression, but that should be honored when constructing
382384
the operator.
383385
"""
384-
return Injection(field, expr, implicit_dims, self)
386+
return Injection(field, expr, implicit_dims, self, interp_expr=interp_expr)
385387

386388
def _interpolate(self, expr, increment=False, self_subs=None, implicit_dims=None):
387389
"""
@@ -433,7 +435,7 @@ def _interpolate(self, expr, increment=False, self_subs=None, implicit_dims=None
433435

434436
return temps + summands + last
435437

436-
def _inject(self, field, expr, implicit_dims=None):
438+
def _inject(self, field, expr, implicit_dims=None, interp_expr=False):
437439
"""
438440
Generate equations injecting an arbitrary expression into a field.
439441
@@ -479,8 +481,10 @@ def _inject(self, field, expr, implicit_dims=None):
479481
self._rdim(subdomain=subdomain))
480482

481483
# List of indirection indices for all adjacent grid points
482-
idx_subs, temps = self._interp_idx(fields, implicit_dims=implicit_dims,
483-
pos_only=variables, subdomain=subdomain)
484+
finterp = fields + as_tuple(variables) if interp_expr else fields
485+
pos_only = () if interp_expr else variables
486+
idx_subs, temps = self._interp_idx(finterp, implicit_dims=implicit_dims,
487+
pos_only=pos_only, subdomain=subdomain)
484488

485489
# Substitute coordinate base symbols into the interpolation coefficients
486490
eqns = [Inc(_field.xreplace(idx_subs),

devito/types/sparse.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1089,7 +1089,8 @@ def interpolate(self, expr, u_t=None, p_t=None, increment=False, implicit_dims=N
10891089
return super().interpolate(expr, increment=increment, self_subs=subs,
10901090
implicit_dims=implicit_dims)
10911091

1092-
def inject(self, field, expr, u_t=None, p_t=None, implicit_dims=None):
1092+
def inject(self, field, expr, u_t=None, p_t=None, implicit_dims=None,
1093+
interp_expr=False):
10931094
"""
10941095
Generate equations injecting an arbitrary expression into a field.
10951096
@@ -1114,7 +1115,8 @@ def inject(self, field, expr, u_t=None, p_t=None, implicit_dims=None):
11141115
if p_t is not None:
11151116
expr = expr.subs({self.time_dim: p_t})
11161117

1117-
return super().inject(field, expr, implicit_dims=implicit_dims)
1118+
return super().inject(field, expr, implicit_dims=implicit_dims,
1119+
interp_expr=interp_expr)
11181120

11191121
@property
11201122
def forward(self):

tests/test_interpolation.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
SparseTimeFunction, SubDomain, TimeFunction, switchconfig
1212
)
1313
from devito.operations.interpolators import LinearInterpolator, SincInterpolator
14+
from devito.symbolics import retrieve_functions
1415
from examples.seismic import (
1516
AcquisitionGeometry, Receiver, RickerSource, TimeAxis, demo_model
1617
)
@@ -569,6 +570,37 @@ def test_inject_from_field(shape, coords, result, npoints=19):
569570
assert np.allclose(a.data[indices], result, rtol=1.e-5)
570571

571572

573+
def test_inject_interp_expr():
574+
"""
575+
Test that the Function coefficient gets interpolated too.
576+
"""
577+
coords = [(.05, .95), (.45, .45)]
578+
a = unit_box(shape=(11, 11))
579+
a.data[:] = 0.
580+
p = points(a.grid, ranges=coords, npoints=19)
581+
m = Function(name='m', grid=a.grid)
582+
m.data_with_halo[:] = 1.
583+
584+
expr = p.inject(a, m, interp_expr=True)
585+
op = Operator(expr)
586+
587+
op(a=a)
588+
589+
indices = [slice(4, 6, 1) for _ in coords]
590+
indices[0] = slice(1, -1, 1)
591+
assert np.allclose(a.data[indices], 1, rtol=1.e-5)
592+
593+
# Extract interp expr to check indices
594+
e_expr = expr.evaluate
595+
funcs = retrieve_functions(e_expr[-1])
596+
assert m in {f.function for f in funcs}
597+
# All funcs should have the same indices wit radius
598+
# includint the coefficient m
599+
indices = {f.indices for f in funcs}
600+
assert len(indices) == 1
601+
assert str(indices.pop()) == '(rp_pointsx + posx, rp_pointsy + posy)'
602+
603+
572604
@pytest.mark.parametrize('shape', [(50, 50, 50)])
573605
def test_position(shape):
574606
t0 = 0.0 # Start time

0 commit comments

Comments
 (0)