Skip to content

Commit 0efaf37

Browse files
committed
compiler: Shed reduction-only halos when lowering injections
After lifting lower_sparse_ops out of mpiize, an injection nest is turned into a Call before optimize_halospots runs, so _drop_reduction_halospots can no longer detect that the wrapping HaloSpot's entry for the injected field is reduction-only. The stale entry was left in place, and on save=True (no modulo buffering) the hoist pass propagated the loop iteration variable out of the time loop, producing an undeclared 'time' reference. Drop those entries at lowering time so the resulting HaloSpot only carries entries with a genuine read at the IET level.
1 parent 7e07f72 commit 0efaf37

10 files changed

Lines changed: 326 additions & 113 deletions

File tree

devito/ir/cgen/printer.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,12 @@ def _print_Pow(self, expr):
212212
suffix = self.func_literal(expr)
213213
base = self._print(expr.base)
214214
if equal_valued(expr.exp, -1):
215-
return self._print_Float(Float(1.0)) + '/' + \
215+
# Pick the literal precision from this Pow's dtype rather than
216+
# the printer default, so e.g. ``Pow(DOUBLE(h), -1)`` emits
217+
# ``1.0/(double)h`` not ``1.0F/(double)h``. This branch only
218+
# fires when the surrounding Mul printer chose not to group
219+
# the Pow into a denominator (e.g. lone reciprocal).
220+
return f'1.0{self.prec_literal(expr)}/' + \
216221
self.parenthesize(expr.base, PREC)
217222
elif equal_valued(expr.exp, 0.5):
218223
return f'{self.ns(expr)}sqrt{suffix}({base})'

devito/operations/interpolators.py

Lines changed: 45 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from devito.finite_differences.elementary import floor
1515
from devito.logger import warning
1616
from devito.symbolics import INT, retrieve_function_carriers, retrieve_functions
17+
from devito.symbolics.extended_dtypes import DOUBLE
1718
from devito.tools import as_tuple, filter_ordered, memoized_meth
1819
from devito.types import (
1920
Eq, Inc, IncrInterpolation, Injection, Interpolation, SubFunction, Symbol
@@ -257,9 +258,44 @@ def _augment_implicit_dims(self, implicit_dims, extras=None):
257258
def _coeff_temps(self, implicit_dims, shifts=None):
258259
return []
259260

261+
@memoized_meth
262+
def _raw_pos_symbols(self, shifts=None):
263+
"""
264+
Per-Dimension Symbol holding the unrounded grid-relative position
265+
``(coord - origin - shift)/h``. Both the integer position
266+
(``floor(...)``) and the linear-interp fractional part
267+
(``... - floor(...)``) reuse this Symbol so the divide-and-shift
268+
expression is emitted only once per sparse point.
269+
"""
270+
dtype = self.sfunction.coordinates.dtype
271+
symbols = []
272+
for d, s in zip(self.grid.dimensions,
273+
shifts or (0,) * len(self.grid.dimensions),
274+
strict=True):
275+
suffix = '_s1' if s != 0 else ''
276+
symbols.append(Symbol(name=f'rpos{d}{suffix}', dtype=dtype))
277+
return DimensionTuple(*symbols, getters=self.grid.dimensions)
278+
260279
def _positions(self, implicit_dims, shifts=None):
261-
return [Eq(v, INT(floor(k)), implicit_dims=implicit_dims)
262-
for k, v in self.sfunction._position_map(shifts=shifts).items()]
280+
# The ``(coord - origin)/h`` subtract is the only step that can lose
281+
# precision to catastrophic cancellation when ``coord`` and ``origin``
282+
# are large and close to each other (e.g. an origin-shifted survey).
283+
# Promote ``origin`` and ``h`` to float64 so the subtract and divide
284+
# happen in double precision in C (one cast operand promotes the
285+
# whole expression); the result narrows to the field dtype on store
286+
# to ``rpos*`` so downstream ``floor`` / fractional math stays in
287+
# the field dtype.
288+
rposs = self._raw_pos_symbols(shifts=shifts)
289+
subs = {o: DOUBLE(o) for o in self.grid.origin_symbols}
290+
subs.update({d.spacing: DOUBLE(d.spacing) for d in self._gdims})
291+
return [Eq(rposs[d], k.xreplace(subs), implicit_dims=implicit_dims)
292+
for d, k in zip(self._gdims,
293+
self.sfunction._position_map(shifts=shifts),
294+
strict=True)] + \
295+
[Eq(v, INT(floor(rposs[d])), implicit_dims=implicit_dims)
296+
for d, v in zip(self._gdims,
297+
self.sfunction._position_map(shifts=shifts).values(),
298+
strict=True)]
263299

264300
def sparse_temps(self, rhs, implicit_dims, field=None):
265301
"""
@@ -458,13 +494,14 @@ def _point_symbols(self, shifts=None):
458494
return DimensionTuple(*symbols, getters=self.grid.dimensions)
459495

460496
def _coeff_temps(self, implicit_dims, shifts=None):
461-
# Positions
462-
pmap = self.sfunction._position_map(shifts=shifts)
497+
# The fractional part of the unrounded position; reuse the
498+
# ``rpos*`` Symbols emitted by ``_positions`` rather than the full
499+
# ``(c - o)/h`` expression so the divide is computed only once.
500+
rposs = self._raw_pos_symbols(shifts=shifts)
463501
psyms = self._point_symbols(shifts)
464-
poseq = [Eq(psyms[d], pos - floor(pos),
465-
implicit_dims=implicit_dims)
466-
for (d, pos) in zip(self._gdims, pmap.keys(), strict=True)]
467-
return poseq
502+
return [Eq(psyms[d], rposs[d] - floor(rposs[d]),
503+
implicit_dims=implicit_dims)
504+
for d in self._gdims]
468505

469506

470507
class PrecomputedInterpolator(WeightedInterpolator):

devito/passes/iet/parpragma.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,16 @@ def _make_simd(self, iet):
130130
if any(i.is_Indexed for i in reductions):
131131
continue
132132

133+
# A scalar reduction in the loop body (e.g. ``acc += ...`` with
134+
# ``acc`` a Symbol, as emitted by the interpolation accumulator
135+
# pattern) is a cross-iteration dependence the SIMD pragma alone
136+
# can't express -- it would need a ``reduction(+:acc)`` clause,
137+
# which we don't emit here. Without it, ``_Complex`` accumulators
138+
# are miscompiled by some gcc releases.
139+
exprs = FindNodes(Expression).visit(candidate)
140+
if any(e.is_reduction and not e.output.is_Indexed for e in exprs):
141+
continue
142+
133143
# Add SIMD pragma
134144
simd = self._make_simd_pragma(candidate)
135145
pragmas = candidate.pragmas + simd

devito/passes/iet/sparse.py

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -69,13 +69,16 @@ def lower_sparse_ops(iet, sregistry=None, **kwargs):
6969
continue
7070
groups.setdefault(nest, []).append(expr)
7171

72-
# If a sparse-op nest sits inside a HaloSpot whose halo scheme is
73-
# void (the reduction-only halo got dropped by
74-
# ``_drop_reduction_halospots``), replace the HaloSpot rather than
75-
# just the nest so we don't leave behind an empty HaloSpot — the
76-
# MPI overlap machinery would otherwise try to wrap our Call with
77-
# its own dynamic-args plumbing.
78-
parents = {nest: _enclosing_void_halospot(iet, nest) for nest in groups}
72+
# ``lower_sparse_ops`` runs before ``optimize_halospots``, so the
73+
# halo-exchange optimiser hasn't yet had a chance to drop the
74+
# reduction-only halo entries that the IR scheduler put around an
75+
# injection nest (e.g. an entry for ``u`` at ``loc_indices={time:
76+
# time+1}`` wrapping ``u[time+1] += ...``). Once the nest becomes a
77+
# Call those expressions are no longer visible to
78+
# ``_drop_reduction_halospots``, so we shed those entries here -- and
79+
# if that empties the HaloSpot, replace it whole so the MPI overlap
80+
# machinery doesn't wrap our Call with stale dynamic-args plumbing.
81+
parents = {nest: _enclosing_halospot(iet, nest) for nest in groups}
7982

8083
mapper = {}
8184
efuncs = []
@@ -87,7 +90,26 @@ def lower_sparse_ops(iet, sregistry=None, **kwargs):
8790
efunc = make_callable(sregistry.make_name(prefix=prefix), new_nest)
8891
efuncs.append(efunc)
8992

90-
mapper[parents[nest] or nest] = Call(efunc.name, list(efunc.parameters))
93+
call = Call(efunc.name, list(efunc.parameters))
94+
parent = parents[nest]
95+
if parent is None:
96+
mapper[nest] = call
97+
continue
98+
99+
# Drop fields that the (now-opaque) Call only writes/increments,
100+
# since the wrapping HaloSpot's purpose was to ensure read-side
101+
# coherency for them and the read no longer exists at the IET
102+
# level. Interpolation reads its target field, so its entries
103+
# stay.
104+
reduced = {e.expr.lhs.function for e in exprs
105+
if isinstance(e.expr, InjectionMixin)}
106+
hs = parent.halo_scheme.drop(reduced) if reduced else parent.halo_scheme
107+
if hs.is_void:
108+
mapper[parent] = call
109+
elif hs is parent.halo_scheme:
110+
mapper[nest] = call
111+
else:
112+
mapper[parent] = parent._rebuild(halo_scheme=hs, body=call)
91113

92114
if not mapper:
93115
return iet, {}
@@ -107,14 +129,12 @@ def _find_outer_iteration(iet, expr):
107129
return None
108130

109131

110-
def _enclosing_void_halospot(iet, nest):
132+
def _enclosing_halospot(iet, nest):
111133
"""
112-
Return the HaloSpot directly wrapping ``nest`` if it carries an
113-
empty (void) HaloScheme, otherwise None. Such HaloSpots are leftover
114-
after ``_drop_reduction_halospots`` cleared all entries.
134+
Return the HaloSpot directly wrapping ``nest``, if any.
115135
"""
116136
for hs in FindNodes(HaloSpot).visit(iet):
117-
if hs.is_void and nest in FindNodes(Iteration).visit(hs):
137+
if nest in FindNodes(Iteration).visit(hs):
118138
return hs
119139
return None
120140

devito/symbolics/extended_sympy.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -380,6 +380,8 @@ class UnaryOp(Expr, Pickable, BasicWrapperMixin):
380380

381381
_op = ''
382382

383+
is_commutative = True
384+
383385
__rargs__ = ('base',)
384386

385387
def __new__(cls, base, **kwargs):

devito/symbolics/inspection.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,14 @@ def sympy_dtype(expr, base=None, default=None, smin=None):
321321
with suppress(AttributeError):
322322
dtypes.add(i.dtype)
323323

324+
# A ``Cast`` overrides the dtype of the symbol(s) it wraps -- e.g.
325+
# ``DOUBLE(o_x)`` is observably double-typed even though ``o_x``
326+
# itself is float-typed. Without this, the C printer would pick
327+
# float-precision literals (``1.0F``) when emitting a ``Pow(DOUBLE(h),
328+
# -1)`` inside a wider expression.
329+
for c in expr.atoms(Cast):
330+
dtypes.add(c.dtype)
331+
324332
if not dtypes or not np.issubdtype(base, np.complexfloating):
325333
dtypes.update({base} - {None})
326334

devito/types/dense.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1633,8 +1633,11 @@ def _arg_values(self, **kwargs):
16331633
return self._arg_defaults(alias=self)
16341634

16351635
def _arg_apply(self, *args, **kwargs):
1636+
# Parent-owned SubFunction data is computed once and read by the
1637+
# Operator; the parent gathers its own data via its own parameter
1638+
# entry, so there's nothing for the SubFunction to do here.
16361639
if self._parent is not None:
1637-
return self._parent._arg_apply(*args, **kwargs)
1640+
return
16381641
return super()._arg_apply(*args, **kwargs)
16391642

16401643
@property

0 commit comments

Comments
 (0)