Skip to content

Commit 5c9a027

Browse files
committed
api: add option to inject without increment
1 parent 39f2b57 commit 5c9a027

4 files changed

Lines changed: 35 additions & 11 deletions

File tree

devito/operations/interpolators.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -166,18 +166,20 @@ class Injection(UnevaluatedSparseOperation):
166166

167167
__rargs__ = ('field', 'expr', 'implicit_dims') + UnevaluatedSparseOperation.__rargs__
168168

169-
def __new__(cls, field, expr, implicit_dims, interpolator):
169+
def __new__(cls, field, expr, increment, implicit_dims, interpolator):
170170
obj = super().__new__(cls, interpolator)
171171

172172
# TODO: unused now, but will be necessary to compute the adjoint
173173
obj.field = field
174174
obj.expr = expr
175+
obj.increment = increment
175176
obj.implicit_dims = implicit_dims
176177

177178
return obj
178179

179180
def operation(self, **kwargs):
180181
return self.interpolator._inject(expr=self.expr, field=self.field,
182+
increment=self.increment,
181183
implicit_dims=self.implicit_dims)
182184

183185
def __repr__(self):
@@ -372,7 +374,7 @@ def interpolate(self, expr, increment=False, self_subs=None, implicit_dims=None)
372374

373375
@check_radius
374376
@check_coords
375-
def inject(self, field, expr, implicit_dims=None):
377+
def inject(self, field, expr, increment=True, implicit_dims=None):
376378
"""
377379
Generate equations injecting an arbitrary expression into a field.
378380
@@ -387,7 +389,7 @@ def inject(self, field, expr, implicit_dims=None):
387389
injection expression, but that should be honored when constructing
388390
the operator.
389391
"""
390-
return Injection(field, expr, implicit_dims, self)
392+
return Injection(field, expr, increment, implicit_dims, self)
391393

392394
def _interpolate(self, expr, increment=False, self_subs=None, implicit_dims=None):
393395
"""
@@ -439,7 +441,7 @@ def _interpolate(self, expr, increment=False, self_subs=None, implicit_dims=None
439441

440442
return temps + summands + last
441443

442-
def _inject(self, field, expr, implicit_dims=None):
444+
def _inject(self, field, expr, increment=True, implicit_dims=None):
443445
"""
444446
Generate equations injecting an arbitrary expression into a field.
445447
@@ -489,9 +491,10 @@ def _inject(self, field, expr, implicit_dims=None):
489491
pos_only=variables, subdomain=subdomain)
490492

491493
# Substitute coordinate base symbols into the interpolation coefficients
492-
eqns = [Inc(_field.xreplace(idx_subs),
493-
(self._weights(subdomain=subdomain) * _expr).xreplace(idx_subs),
494-
implicit_dims=implicit_dims)
494+
ecls = Inc if increment else Eq
495+
eqns = [ecls(_field.xreplace(idx_subs),
496+
(self._weights(subdomain=subdomain) * _expr).xreplace(idx_subs),
497+
implicit_dims=implicit_dims)
495498
for (_field, _expr) in zip(fields, _exprs, strict=True)]
496499

497500
return temps + eqns

examples/seismic/self_adjoint/sa_03_iso_correctness.ipynb

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -578,9 +578,9 @@
578578
"output_type": "stream",
579579
"text": [
580580
"Operator `IsoFwdOperator` ran in 0.04 s\n",
581-
"No source type defined, returning uninitiallized (zero) source\n",
581+
"No source type defined, returning uninitialized (zero) source\n",
582582
"Operator `IsoAdjOperator` ran in 0.03 s\n",
583-
"No source type defined, returning uninitiallized (zero) source\n",
583+
"No source type defined, returning uninitialized (zero) source\n",
584584
"Operator `IsoAdjOperator` ran in 0.03 s\n"
585585
]
586586
},
@@ -639,7 +639,7 @@
639639
"output_type": "stream",
640640
"text": [
641641
"Operator `IsoFwdOperator` ran in 0.03 s\n",
642-
"No source type defined, returning uninitiallized (zero) source\n",
642+
"No source type defined, returning uninitialized (zero) source\n",
643643
"Operator `IsoAdjOperator` ran in 0.03 s\n"
644644
]
645645
},

examples/seismic/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ def src(self):
194194
def new_src(self, name='src', src_type='self', coordinates=None):
195195
coords = coordinates or self.src_positions
196196
if self.src_type is None or src_type is None:
197-
warning("No source type defined, returning uninitiallized (zero) source")
197+
warning("No source type defined, returning uninitialized (zero) source")
198198
src = PointSource(name=name, grid=self.grid,
199199
time_range=self.time_axis, npoint=self.nsrc,
200200
coordinates=coords,

tests/test_interpolation.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -440,6 +440,27 @@ def test_inject(shape, coords, result, npoints=19):
440440
assert np.allclose(a.data[indices], result, rtol=1.e-5)
441441

442442

443+
@pytest.mark.parametrize('shape, coords', [
444+
((11, 11), [(.1, .9), (.4, .4)]),
445+
((11, 11, 11), [(.1, .9), (.4, .4), (.4, .4)])
446+
])
447+
def test_inject_no_incr(shape, coords, npoints=9):
448+
a = unit_box(shape=shape)
449+
a.data[:] = 2.
450+
p = points(a.grid, coords, npoints=npoints)
451+
452+
p.data[:] = 3.
453+
expr = p.inject(a, p, increment=False)
454+
op = Operator(expr, subs=a.grid.spacing_map)
455+
456+
op(a=a)
457+
458+
indices = [slice(4, 5, 1) for _ in coords]
459+
indices[0] = slice(1, -1, 1)
460+
# Should be 3 at the points
461+
assert np.allclose(a.data[indices], 3, rtol=1.e-5)
462+
463+
443464
@pytest.mark.parametrize('shape, coords, nexpr, result', [
444465
((11, 11), [(.05, .95), (.45, .45)], 1, 1.),
445466
((11, 11), [(.05, .95), (.45, .45)], 2, 1.),

0 commit comments

Comments
 (0)