|
14 | 14 | from devito.finite_differences.elementary import floor |
15 | 15 | from devito.logger import warning |
16 | 16 | from devito.symbolics import INT, retrieve_function_carriers, retrieve_functions |
| 17 | +from devito.symbolics.extended_dtypes import DOUBLE |
17 | 18 | from devito.tools import as_tuple, filter_ordered, memoized_meth |
18 | 19 | from devito.types import ( |
19 | 20 | Eq, Inc, IncrInterpolation, Injection, Interpolation, SubFunction, Symbol |
@@ -257,9 +258,44 @@ def _augment_implicit_dims(self, implicit_dims, extras=None): |
257 | 258 | def _coeff_temps(self, implicit_dims, shifts=None): |
258 | 259 | return [] |
259 | 260 |
|
@memoized_meth
def _raw_pos_symbols(self, shifts=None):
    """
    Build one Symbol per grid Dimension for the unrounded, grid-relative
    position ``(coord - origin - shift)/h`` of a sparse point.

    The integer position (``floor(...)``) and the fractional part used by
    linear interpolation (``... - floor(...)``) are both meant to be
    written in terms of these Symbols, so that the divide-and-shift
    expression is emitted only once per sparse point.

    Parameters
    ----------
    shifts : sequence, optional
        One shift per grid Dimension. A falsy value (``None``, empty) is
        treated as a zero shift along every Dimension.

    Returns
    -------
    DimensionTuple
        Symbols named ``rpos<dim>``, with an ``_s1`` suffix along every
        Dimension whose shift is non-zero, typed like the sparse
        function's coordinates and retrievable by grid Dimension.
    """
    gdims = self.grid.dimensions
    if not shifts:
        shifts = (0,) * len(gdims)
    coord_dtype = self.sfunction.coordinates.dtype

    # ``strict=True`` makes a shifts/dimensions length mismatch an error
    names = [f'rpos{dim}' + ('_s1' if shift != 0 else '')
             for dim, shift in zip(gdims, shifts, strict=True)]
    rsyms = [Symbol(name=name, dtype=coord_dtype) for name in names]

    return DimensionTuple(*rsyms, getters=gdims)
| 278 | + |
def _positions(self, implicit_dims, shifts=None):
    """
    Build the equations computing the position of each sparse point.

    Two groups of equations are returned, in this order:

    1. ``rpos* = <unrounded position>``, one per grid Dimension, with the
       grid origin and spacing symbols cast to double precision;
    2. ``<position symbol> = INT(floor(rpos*))``, one per grid Dimension.

    Parameters
    ----------
    implicit_dims : tuple
        Implicit Dimensions attached to every generated equation.
    shifts : sequence, optional
        Per-Dimension shifts, forwarded to the position map and to
        ``_raw_pos_symbols``.

    Returns
    -------
    list of Eq
    """
    # The ``(coord - origin)/h`` subtract is the only step that can lose
    # precision to catastrophic cancellation when ``coord`` and ``origin``
    # are large and close to each other (e.g. an origin-shifted survey).
    # Promote ``origin`` and ``h`` to float64 so the subtract and divide
    # happen in double precision in C (one cast operand promotes the
    # whole expression); the result narrows to the dtype of ``rpos*`` on
    # store, so the downstream ``floor`` / fractional math stays in that
    # dtype.
    rposs = self._raw_pos_symbols(shifts=shifts)

    subs = {o: DOUBLE(o) for o in self.grid.origin_symbols}
    subs.update({d.spacing: DOUBLE(d.spacing) for d in self._gdims})

    # Build the position map once; its keys are the unrounded position
    # expressions and its values the integer position symbols
    pmap = self.sfunction._position_map(shifts=shifts)

    raw_eqs = [Eq(rposs[d], k.xreplace(subs), implicit_dims=implicit_dims)
               for d, k in zip(self._gdims, pmap, strict=True)]
    int_eqs = [Eq(v, INT(floor(rposs[d])), implicit_dims=implicit_dims)
               for d, v in zip(self._gdims, pmap.values(), strict=True)]

    # The raw positions must be defined before the integer positions
    # that read them
    return raw_eqs + int_eqs
263 | 299 |
|
264 | 300 | def sparse_temps(self, rhs, implicit_dims, field=None): |
265 | 301 | """ |
@@ -458,13 +494,14 @@ def _point_symbols(self, shifts=None): |
458 | 494 | return DimensionTuple(*symbols, getters=self.grid.dimensions) |
459 | 495 |
|
def _coeff_temps(self, implicit_dims, shifts=None):
    """
    Build the equations for the fractional part of the unrounded position
    along each grid Dimension.

    Parameters
    ----------
    implicit_dims : tuple
        Implicit Dimensions attached to every generated equation.
    shifts : sequence, optional
        Per-Dimension shifts, used to select the matching ``rpos*`` and
        point Symbols.

    Returns
    -------
    list of Eq
        One ``p = rpos - floor(rpos)`` equation per grid Dimension.
    """
    # Reuse the ``rpos*`` Symbols rather than the full ``(c - o)/h``
    # expression, so that the divide is computed only once. This assumes
    # the equations defining ``rpos*`` (see ``_positions``) are scheduled
    # before the ones built here.
    raw = self._raw_pos_symbols(shifts=shifts)
    frac = self._point_symbols(shifts)

    eqs = []
    for dim in self._gdims:
        unrounded = raw[dim]
        eqs.append(Eq(frac[dim], unrounded - floor(unrounded),
                      implicit_dims=implicit_dims))
    return eqs
468 | 505 |
|
469 | 506 |
|
470 | 507 | class PrecomputedInterpolator(WeightedInterpolator): |
|
0 commit comments