Skip to content

Commit 6c6d45f

Browse files
committed
api: prevent evaluated derivatives to be re-evaluted
1 parent c8816fd commit 6c6d45f

9 files changed

Lines changed: 54 additions & 17 deletions

File tree

devito/finite_differences/differentiable.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -952,6 +952,11 @@ def _evaluate(self, **kwargs):
952952
class DiffDerivative(IndexDerivative, DifferentiableOp):
953953
pass
954954

955+
def _eval_at(self, func):
956+
# Like EvalDerivative, a DiffDerivative must have already been evaluated
957+
# at a valid x0 and should not be re-evaluated at a different location
958+
return self
959+
955960

956961
# SymPy args ordering is the same for Derivatives and IndexDerivatives
957962
for i in ('DiffDerivative', 'IndexDerivative'):
@@ -998,6 +1003,11 @@ def _new_rawargs(self, *args, **kwargs):
9981003
kwargs.pop('is_commutative', None)
9991004
return self.func(*args, **kwargs)
10001005

1006+
def _eval_at(self, func):
1007+
# An EvalDerivative must have already been evaluated at a valid x0
1008+
# and should not be re-evaluated at a different location
1009+
return self
1010+
10011011

10021012
class diffify:
10031013

devito/types/basic.py

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -995,9 +995,23 @@ def _grid_map(self):
995995
Mapper of off-grid interpolation points indices for each dimension.
996996
"""
997997
mapper = {}
998+
subs = {}
998999
for i, j, d in zip(self.indices, self.indices_ref, self.dimensions):
9991000
# Two indices are aligned if they differ by an Integer*spacing.
1000-
v = (i - j)/d.spacing
1001+
if not i.has(d):
1002+
# Maybe a subdimension
1003+
dims = {sd for sd in i.free_symbols if getattr(sd, 'is_Dimension', False)
1004+
and d in sd._defines}
1005+
# More than one dimension, cannot handle
1006+
if len(dims) != 1:
1007+
continue
1008+
sd = dims.pop()
1009+
v = (i - j._subs(d, sd))/d.spacing
1010+
i = i._subs(sd, d)
1011+
subs[d] = sd
1012+
else:
1013+
v = (i - j)/d.spacing
1014+
10011015
try:
10021016
if not isinstance(v, sympy.Number) or int(v) == v:
10031017
continue
@@ -1008,6 +1022,11 @@ def _grid_map(self):
10081022
mapper.update({d: i})
10091023
except (AttributeError, TypeError):
10101024
mapper.update({d: i})
1025+
1026+
# Substitutions for self.function
1027+
if mapper:
1028+
mapper['subs'] = subs
1029+
10111030
return mapper
10121031

10131032
def _evaluate(self, **kwargs):
@@ -1019,8 +1038,10 @@ def _evaluate(self, **kwargs):
10191038
This allow to evaluate off grid points as EvalDerivative that are better
10201039
for the compiler.
10211040
"""
1041+
mapper = self._grid_map
1042+
subs = mapper.pop('subs', {})
10221043
# Average values if at a location not on the Function's grid
1023-
if not self._grid_map:
1044+
if not mapper:
10241045
return self
10251046

10261047
io = self.interp_order
@@ -1031,17 +1052,19 @@ def _evaluate(self, **kwargs):
10311052
retval = self.function
10321053

10331054
# Apply interpolation from inner most dim
1034-
for d, i in self._grid_map.items():
1055+
for d, i in mapper.items():
10351056
retval = retval.diff(d, deriv_order=0, fd_order=io, x0={d: i})
10361057

10371058
# Evaluate. Since we used `self.function` it will be on the grid when
10381059
# evaluate is called again within FD
10391060
retval = retval._evaluate(**kwargs)
1061+
if subs:
1062+
retval = retval.subs(subs)
10401063

10411064
# If harmonic averaging, invert at the end
10421065
if self._avg_mode == 'harmonic':
10431066
from devito.finite_differences.differentiable import SafeInv
1044-
retval = SafeInv(retval, self.function)
1067+
retval = SafeInv(retval, self.function.subs(subs))
10451068

10461069
return retval
10471070

devito/types/dense.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1073,9 +1073,15 @@ def _fd_priority(self):
10731073
def _eval_at(self, func):
10741074
if self.staggered == func.staggered:
10751075
return self
1076-
mapper = {self.indices_ref[d]: func.indices_ref[d]
1077-
for d in self.dimensions
1078-
if self.indices_ref[d] is not func.indices_ref[d]}
1076+
1077+
mapper = {}
1078+
for d in self.dimensions:
1079+
try:
1080+
if self.indices_ref[d] is not func.indices_ref[d]:
1081+
mapper[self.indices_ref[d]] = func.indices_ref[d]
1082+
except KeyError:
1083+
pass
1084+
10791085
if mapper:
10801086
return self.subs(mapper)
10811087
return self

examples/seismic/elastic/elastic_example.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def run(shape=(50, 50), spacing=(20.0, 20.0), tn=1000.0,
3939
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
4040
def test_elastic(dtype):
4141
_, _, _, [rec1, rec2, v, tau] = run(dtype=dtype)
42-
assert np.isclose(norm(rec1), 19.9367, atol=1e-3, rtol=0)
42+
assert np.isclose(norm(rec1), 19.9368, atol=1e-3, rtol=0)
4343
assert np.isclose(norm(rec2), 0.6512, atol=1e-3, rtol=0)
4444

4545

examples/seismic/viscoelastic/viscoelastic_example.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,8 @@ def run(shape=(50, 50), spacing=(20.0, 20.0), tn=1000.0,
4141
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
4242
def test_viscoelastic(dtype):
4343
_, _, _, [rec1, rec2, v, tau] = run(dtype=dtype)
44-
assert np.isclose(norm(rec1), 12.62339, atol=1e-3, rtol=0)
45-
assert np.isclose(norm(rec2), 0.320817, atol=1e-3, rtol=0)
44+
assert np.isclose(norm(rec1), 12.6235, atol=1e-3, rtol=0)
45+
assert np.isclose(norm(rec2), 0.32201, atol=1e-3, rtol=0)
4646

4747

4848
@pytest.mark.parametrize('shape', [(51, 51), (16, 16, 16)])

examples/userapi/07_functions_on_subdomains.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3003,7 +3003,7 @@
30033003
"metadata": {},
30043004
"outputs": [],
30053005
"source": [
3006-
"assert np.isclose(np.linalg.norm(rec.data), 4263.511, atol=0, rtol=1e-4)"
3006+
"assert np.isclose(np.linalg.norm(rec.data), 3640.584, atol=0, rtol=1e-4)"
30073007
]
30083008
}
30093009
],

tests/test_derivatives.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -805,13 +805,11 @@ def test_param_stagg_add(self):
805805
eq1 = Eq(vx, (c11 * txx).dy)
806806
eq2 = Eq(vx, (c11 * txx + c66 * txy).dy)
807807

808-
# C66 is a paramater. Expects to evaluate c66 at xp then the derivative at yp
809-
# and the derivative will interpolate txy at xp
808+
# Expects to evaluate c66 at xp then the derivative at yp
810809
expect0 = (c66.subs({x: xp, y: yp}).evaluate * txy).dy.evaluate
811810
assert simplify(eq0.evaluate.rhs - expect0) == 0
812811

813-
# C11 is a paramater and txy is staggered in x.
814-
# Expects to evaluate c11 and txy xp then the derivative at yp
812+
# Expects to evaluate c11 and txy at xp then the derivative at yp
815813
expect1 = (c11._subs(x, xp).evaluate * txx._subs(x, xp).evaluate).dy.evaluate
816814
assert simplify(eq1.evaluate.rhs - expect1) == 0
817815

tests/test_differentiable.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def test_shift():
5757
assert a.shift(x, x.spacing).shift(x, -x.spacing) == a
5858
assert a.shift(x, x.spacing).shift(x, x.spacing) == a.shift(x, 2*x.spacing)
5959
assert a.dx.evaluate.shift(x, x.spacing) == a.shift(x, x.spacing).dx.evaluate
60-
assert a.shift(x, .5 * x.spacing)._grid_map == {x: x + .5 * x.spacing}
60+
assert a.shift(x, .5 * x.spacing)._grid_map == {x: x + .5 * x.spacing, 'subs': {}}
6161

6262

6363
def test_interp():

tests/test_symbolics.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -716,7 +716,7 @@ def test_is_on_grid():
716716
u = Function(name="u", grid=grid, space_order=2)
717717

718718
assert u._grid_map == {}
719-
assert u.subs({x: x0})._grid_map == {x: x0}
719+
assert u.subs({x: x0})._grid_map == {x: x0, 'subs': {}}
720720
assert all(uu._grid_map == {} for uu in retrieve_functions(u.subs({x: x0}).evaluate))
721721

722722

0 commit comments

Comments
 (0)