Skip to content

Commit f1b29be

Browse files
committed
compiler: fix for kernel
1 parent ff53f3b commit f1b29be

7 files changed

Lines changed: 189 additions & 41 deletions

File tree

devito/ir/clusters/algorithms.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -477,7 +477,15 @@ def callback(self, clusters, prefix, seen=None):
477477
key = lambda i: key0(i) and key1(i) # noqa: B023
478478
ispace = c.ispace.project(key)
479479

480-
properties = c.properties.sequentialize()
480+
# Only sequentialize the Dimensions that actually drive the
481+
# halo (its ``loc_indices`` and the Dimensions they define).
482+
# Sequentialising *every* outer Dimension would also bring
483+
# along independent iterators like ``p_rec`` for an
484+
# interpolation, vetoing downstream blocking of the real
485+
# cluster the HaloTouch sits alongside.
486+
relevant = set().union(*(i._defines for i in hs.loc_indices))
487+
seq_dims = [d for d in c.properties if d in relevant]
488+
properties = c.properties.sequentialize(seq_dims)
481489

482490
halo_touch = c.rebuild(exprs=expr, ispace=ispace, properties=properties)
483491

devito/ir/clusters/cluster.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ class EqBlock(CacheInstances):
3434
@classmethod
3535
def _preprocess_args(cls, exprs, ispace=null_ispace, guards=None,
3636
properties=None, syncs=None, halo_scheme=None):
37-
exprs = tuple(ClusterizedEq(e, ispace=ispace) for e in as_tuple(exprs))
37+
exprs = tuple(clusterize_eq(e, ispace=ispace) for e in as_tuple(exprs))
3838
guards = Guards(guards or {})
3939
properties = Properties(properties or {})
4040
syncs = normalize_syncs(syncs or {})

devito/ir/equations/equation.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,6 @@
1313
from devito.tools import (
1414
Pickable, Tag, as_hashable, filter_sorted, frozendict, reuse_if_unchanged
1515
)
16-
from devito.symbolics import IntDiv, limits_mapper, uxreplace
17-
from devito.tools import Pickable, Tag, frozendict
1816
from devito.types import (
1917
Eq, Inc, IncrInterpolation, Injection, InjectionMixin, Interpolation,
2018
InterpolationMixin, ReduceMax, ReduceMin, ReduceMinMax, SparseEq, SparseOpMixin,
@@ -208,6 +206,19 @@ def detect(cls, expr):
208206
(ReduceMin, OpMin),
209207
(ReduceMax, OpMax),
210208
(Inc, OpInc),
209+
# An ``Interpolation`` looks like a plain ``Eq`` -- ``sf[p_*] =
210+
# expr[rp_*]`` -- but the cluster scheduler iterates the rhs
211+
# over the radius dims, so values are implicitly summed across
212+
# ``rp_*``. Tagging it as ``OpInc`` makes the dependence
213+
# analysis treat ``rp_*`` as reduction dims
214+
# (``parallel_if_atomic``), which matches the lowered code
215+
# (``sumrec += ... ; sf[p_*] = sumrec``) and stops the
216+
# blocking heuristic from turning the tiny radius loops into
217+
# thread blocks. The actual write-back flavour at ``sf[p_*]``
218+
# is keyed off the Eq's class (``is_increment_writeback``) in
219+
# ``lower_sparse_ops`` so this tag doesn't accidentally turn
220+
# ``Interpolation`` assignments into increments.
221+
(InterpolationMixin, OpInc),
211222
)
212223
for kls, op in reduction_mapper:
213224
if isinstance(expr, kls):
@@ -366,12 +377,17 @@ class LoweredSparseEq(SparseOpMixin, LoweredEq):
366377

367378
class LoweredInterpolation(InterpolationMixin, LoweredSparseEq):
368379
"""IR counterpart of ``Interpolation``."""
369-
pass
380+
# ``sf[p_*] = ...``: the write-back at the sparse position is an
381+
# assignment. The Eq is still tagged as a reduction
382+
# (``OpInc``/``is_Reduction``) because the rhs is summed over the
383+
# radius dims; only the *outer* write-back to ``sf[p_*]`` is plain.
384+
is_increment_writeback = False
370385

371386

372387
class LoweredIncrInterpolation(InterpolationMixin, LoweredSparseEq):
373388
"""IR counterpart of ``IncrInterpolation``."""
374-
pass
389+
# ``sf[p_*] += ...``: the user asked for ``interpolate(..., increment=True)``.
390+
is_increment_writeback = True
375391

376392

377393
class LoweredInjection(InjectionMixin, LoweredSparseEq):
@@ -458,12 +474,12 @@ class ClusterizedSparseEq(SparseOpMixin, ClusterizedEq):
458474

459475
class ClusterizedInterpolation(InterpolationMixin, ClusterizedSparseEq):
460476
"""Frozen counterpart of ``LoweredInterpolation``."""
461-
pass
477+
is_increment_writeback = False
462478

463479

464480
class ClusterizedIncrInterpolation(InterpolationMixin, ClusterizedSparseEq):
465481
"""Frozen counterpart of ``LoweredIncrInterpolation``."""
466-
pass
482+
is_increment_writeback = True
467483

468484

469485
class ClusterizedInjection(InjectionMixin, ClusterizedSparseEq):

devito/ir/stree/algorithms.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,12 @@ def preprocess(clusters, options=None, **kwargs):
182182
continue
183183

184184
else:
185-
dims = set(c.ispace.promote(lambda d: d.is_Block).itdims)
185+
itdims = set(c.ispace.promote(lambda d: d.is_Block).itdims)
186+
# Expand the iteration dims by ``_defines`` so that derived
187+
# Dimensions (e.g. ``rp_*`` radius dims with the grid
188+
# Dimension as parent) match the halo's distributed indices
189+
# of the Dimensions they ultimately iterate over
190+
dims = set().union(*[d._defines for d in itdims])
186191

187192
found = []
188193
for c1 in list(queue):

devito/passes/iet/sparse.py

Lines changed: 125 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@
2828
from devito.ir.equations import DummyEq
2929
from devito.ir.equations.algorithms import lower_exprs
3030
from devito.ir.iet import (
31-
Call, EntryFunction, Expression, FindNodes, HaloSpot, Increment, Iteration, List,
32-
Transformer, make_callable
31+
Call, Conditional, EntryFunction, Expression, ExpressionBundle, FindNodes, HaloSpot,
32+
Increment, Iteration, List, Transformer, make_callable
3333
)
3434
from devito.passes.iet.engine import iet_pass
3535
from devito.types import Eq, InjectionMixin, InterpolationMixin, Symbol
@@ -91,9 +91,22 @@ def lower_sparse_ops(iet, sregistry=None, **kwargs):
9191
efuncs.append(efunc)
9292

9393
call = Call(efunc.name, list(efunc.parameters))
94+
95+
# Any HaloSpot living inside the nest (e.g. around the radius
96+
# loops reading a grid Function) becomes invisible once the
97+
# nest collapses into an opaque Call. Hoist its halo_scheme out
98+
# so the halo exchange still happens at the right point in the
99+
# parent IET, sitting next to the Call rather than buried in it.
100+
# An injection's lhs is *written* (not read) so its inner halo
101+
# entry for that Function is a reduction-only halo that the
102+
# caller has nothing to read back -- drop it before hoisting.
103+
reduced = {e.expr.lhs.function for e in exprs
104+
if isinstance(e.expr, InjectionMixin)}
105+
prelude = _hoisted_halo_prelude(nest, reduced)
106+
94107
parent = parents[nest]
95108
if parent is None:
96-
mapper[nest] = call
109+
mapper[nest] = List(body=prelude + (call,)) if prelude else call
97110
continue
98111

99112
# Drop fields that the (now-opaque) Call only writes/increments,
@@ -120,7 +133,10 @@ def lower_sparse_ops(iet, sregistry=None, **kwargs):
120133
def _find_outer_iteration(iet, expr):
121134
"""
122135
Walk up the IET from ``expr`` and return the outermost Iteration
123-
whose ``dim.root`` is the SparseFunction's sparse Dimension.
136+
whose ``dim.root`` is the SparseFunction's sparse Dimension. This
137+
is the entry point of the sparse-op nest in the parent IET; the
138+
full nest (including any block Iterations the cluster pipeline
139+
wrapped around the sparse loop) gets extracted into the efunc.
124140
"""
125141
sparse_dim = expr.expr.interpolator.sfunction._sparse_dim
126142
for it in FindNodes(Iteration).visit(iet):
@@ -139,13 +155,60 @@ def _enclosing_halospot(iet, nest):
139155
return None
140156

141157

158+
def _hoisted_halo_prelude(nest, reduced=None):
159+
"""
160+
Build a tuple of nodes that recreate, outside the sparse nest, any
161+
HaloSpots that live inside it. Each hoisted HaloSpot is wrapped in
162+
the Iterations from ``nest`` whose dim is referenced by the halo's
163+
``loc_indices`` (e.g. the ``ps`` loop wrapping a halo whose
164+
``loc_indices = {ps}``), so the lowering pipeline that turns the
165+
HaloSpot into a halo-exchange Call still sees the indices it needs
166+
in scope.
167+
168+
``reduced``, when provided, lists Functions whose halo entries
169+
should be dropped from the hoisted scheme -- these are the
170+
Functions the sparse Call only writes/increments (e.g. an
171+
injection's lhs), so the parent never reads them back and the
172+
halo update would be redundant.
173+
"""
174+
reduced = reduced or set()
175+
176+
inner = []
177+
for hs in FindNodes(HaloSpot).visit(nest):
178+
scheme = hs.halo_scheme.drop(reduced) if reduced else hs.halo_scheme
179+
if not scheme.is_void:
180+
inner.append(scheme)
181+
if not inner:
182+
return ()
183+
184+
iters = FindNodes(Iteration).visit(nest)
185+
prelude = []
186+
for scheme in inner:
187+
loc_dims = set().union(*(d._defines for d in scheme.loc_indices))
188+
wrappers = [it for it in iters if it.dim in loc_dims]
189+
body = HaloSpot(List(body=[]), scheme)
190+
for it in reversed(wrappers):
191+
body = it._rebuild(nodes=body)
192+
prelude.append(body)
193+
return tuple(prelude)
194+
195+
142196
def _materialise_nest(nest, exprs):
143197
"""
144198
Rewrite the sparse Dimension's Iteration body to compute the
145199
position/coefficient temps once per sparse point, then for any
146200
interpolation Expression wrap it with the scalar accumulator
147201
pattern. Multiple sparse-op Expressions sharing the same outer
148202
Iteration are materialised in one pass and reuse the same temps.
203+
204+
``nest`` is the *outermost* sparse-Dimension Iteration, so that the
205+
whole block-Iteration hierarchy (e.g. ``p_rec0_blk0`` -> ``p_rec``
206+
on the GPU pipeline) is extracted into the efunc and downstream GPU
207+
kernel synthesis can fold the block loops into a thread-grid
208+
wrapping the kernel body. The temps and the accumulator pattern,
209+
however, must live *inside* the innermost sparse Iteration -- one
210+
set per sparse point, sitting beneath any thread-index/OOB-guard
211+
prelude that the GPU kernel prep may have inserted.
149212
"""
150213
# Position + coefficient temporaries as IET Expressions. These are
151214
# the same for every Expression in the group, so we emit them once.
@@ -155,24 +218,63 @@ def _materialise_nest(nest, exprs):
155218
temp_exprs = tuple(Expression(DummyEq(e.lhs, e.rhs))
156219
for e in lower_exprs(sample.sparse_temps()))
157220

158-
# The radius nest is what runs once per sparse point. For each
159-
# interpolation Expression in the group, build its
160-
# accumulator-wrapped copy of the radius nest. Injection Exprs
161-
# share a single copy of the radius nest (their ``Inc`` already
162-
# carries the right ``weights * rhs`` form).
163-
inner = nest.nodes[0] if len(nest.nodes) == 1 else List(body=nest.nodes)
221+
# Find the innermost sparse-Dimension Iteration within ``nest`` --
222+
# that's where the head Expressions actually live, beneath any block
223+
# Iterations that the cluster pipeline wrapped around the sparse
224+
# loop.
225+
sparse_dim = sample.interpolator.sfunction._sparse_dim
226+
inner_iter = nest
227+
for it in FindNodes(Iteration).visit(nest):
228+
if it.dim.root is sparse_dim and \
229+
any(e in FindNodes(Expression).visit(it) for e in exprs):
230+
inner_iter = it
231+
232+
# ``inner_iter`` may carry a GPU kernel prelude (thread-index
233+
# ``ExpressionBundle`` and OOB ``Conditional``) that downstream
234+
# kernel synthesis expects to find at the top of the block dim's
235+
# body. The temps and the accumulator pattern go *after* that
236+
# prelude.
237+
head, body_nodes = _split_kernel_prelude(inner_iter.nodes)
238+
239+
radius_nest = body_nodes[0] if len(body_nodes) == 1 else List(body=body_nodes)
164240
interp_exprs = [e for e in exprs if isinstance(e.expr, InterpolationMixin)]
165241
inject_exprs = [e for e in exprs if isinstance(e.expr, InjectionMixin)]
166242

167-
body = []
243+
new_body = []
168244
for expr in interp_exprs:
169245
siblings = [e for e in exprs if e is not expr]
170-
body.append(_interp_inner_block(inner, expr, expr.expr.interpolator, siblings))
246+
new_body.append(_interp_inner_block(
247+
radius_nest, expr, expr.expr.interpolator, siblings))
171248
if inject_exprs:
172249
drop = {e: None for e in interp_exprs}
173-
body.append(Transformer(drop, nested=True).visit(inner) if drop else inner)
250+
new_body.append(Transformer(drop, nested=True).visit(radius_nest)
251+
if drop else radius_nest)
252+
253+
new_inner_iter = inner_iter._rebuild(
254+
nodes=head + temp_exprs + tuple(new_body)
255+
)
256+
if new_inner_iter is inner_iter:
257+
return nest
258+
return Transformer({inner_iter: new_inner_iter}, nested=True).visit(nest)
174259

175-
return nest._rebuild(nodes=temp_exprs + tuple(body))
260+
261+
def _split_kernel_prelude(nodes):
262+
"""
263+
Split the contents of a sparse-Dimension Iteration into the GPU
264+
kernel prelude (the thread-index ``ExpressionBundle`` and the
265+
optional OOB ``Conditional``) and the remaining body. On non-cuda
266+
pipelines the prelude is empty and the full ``nodes`` tuple is the
267+
body.
268+
"""
269+
head = ()
270+
body = tuple(nodes)
271+
if body and isinstance(body[0], ExpressionBundle):
272+
head += (body[0],)
273+
body = body[1:]
274+
if body and isinstance(body[0], Conditional):
275+
head += (body[0],)
276+
body = body[1:]
277+
return head, body
176278

177279

178280
def _interp_inner_block(inner, expr, interp, siblings):
@@ -219,10 +321,15 @@ def _interp_inner_block(inner, expr, interp, siblings):
219321

220322
init = Expression(DummyEq(acc, 0))
221323
inc = Increment(DummyEq(acc, weighted_rhs))
222-
# Honour the synthetic Eq's flavour: a SparseInc means the user
223-
# asked for ``sf[p_*] += sum`` (interpolation with ``increment=True``);
224-
# otherwise the standard write is ``sf[p_*] = sum``.
225-
write_back_cls = Increment if eq.is_Reduction else Expression
324+
# Honour the synthetic Eq's flavour: an ``IncrInterpolation`` means
325+
# the user asked for ``sf[p_*] += sum`` (interpolation with
326+
# ``increment=True``); a plain ``Interpolation`` is just ``sf[p_*] =
327+
# sum``. We key off the leaf class' ``is_increment_writeback`` flag
328+
# rather than ``is_Reduction`` because both flavours are tagged as
329+
# reductions (``OpInc``) for dependence-analysis purposes -- the rhs
330+
# is implicitly summed over the radius dims -- but only the
331+
# ``IncrInterpolation`` flavour writes back with ``+=``.
332+
write_back_cls = Increment if eq.is_increment_writeback else Expression
226333
write_back = write_back_cls(DummyEq(sf_lhs, acc))
227334

228335
# Single substitution: drop siblings, replace ``expr`` with ``inc``.

devito/types/sparse.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -415,12 +415,17 @@ def dist_origin(self):
415415
@memoized_meth
416416
def _crdim(self, dim):
417417
"""
418-
The CustomDimension associated with the Dimension `dim` for
419-
the radius of the interpolation/injection stencil
418+
The CustomDimension associated with the grid Dimension ``dim``
419+
for the radius of the interpolation/injection stencil. The
420+
parent is ``dim`` itself so that ``_defines`` traces back to the
421+
grid Dimension the radius slides over -- this is what dependence
422+
analysis needs to recognise the implicit reduction over ``rp_*``
423+
rather than treating ``rp_*`` as if they were derived from the
424+
SparseFunction's sparse Dimension.
420425
"""
421426
sname = self._sparse_dim.name
422427
return CustomDimension(f"r{sname}{dim.name}", -self.r+1,
423-
self.r, 2*self.r, self._sparse_dim)
428+
self.r, 2*self.r, dim)
424429

425430
@memoized_meth
426431
def _cond_rdim(self, dim, cond):

tests/test_mpi.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3474,15 +3474,19 @@ def test_issue_2448_v2(self, mode, setup):
34743474

34753475
calls = [i for i in FindNodes(Call).visit(op2) if isinstance(i, HaloUpdateCall)]
34763476

3477-
assert len(calls) == 5
3478-
assert len(FindNodes(HaloUpdateCall).visit(body0(op2).body[1].body[0])) == 2
3479-
assert len(FindNodes(HaloUpdateCall).visit(body0(op2).body[1].body[1])) == 3
3477+
# The sparse-op efunc lowering hoists the interpolation halos so they
3478+
# share with the upstream stencil halo for the same Function. The
3479+
# initial-step halo merges ``v`` with ``v2``; the in-loop halo merges
3480+
# the same pair, leaving 3 distinct HaloUpdateCalls in total.
3481+
assert len(calls) == 3
3482+
assert len(FindNodes(HaloUpdateCall).visit(body0(op2).body[1].body[0])) == 1
3483+
assert len(FindNodes(HaloUpdateCall).visit(body0(op2).body[1].body[1])) == 2
34803484
assert calls[0].arguments[0] is v
3481-
assert calls[1].arguments[0] is v2
3482-
assert calls[2].arguments[0] is tau
3483-
assert calls[2].arguments[1] is tau2
3484-
assert calls[3].arguments[0] is v
3485-
assert calls[4].arguments[0] is v2
3485+
assert calls[0].arguments[1] is v2
3486+
assert calls[1].arguments[0] is tau
3487+
assert calls[1].arguments[1] is tau2
3488+
assert calls[2].arguments[0] is v
3489+
assert calls[2].arguments[1] is v2
34863490

34873491
@pytest.mark.parallel(mode=1)
34883492
def test_issue_2448_v3(self, mode, setup):
@@ -3509,14 +3513,17 @@ def test_issue_2448_v3(self, mode, setup):
35093513

35103514
calls = [i for i in FindNodes(Call).visit(op3) if isinstance(i, HaloUpdateCall)]
35113515

3512-
assert len(calls) == 4
3516+
# The interpolation halo for ``v2.forward`` merges with the in-loop
3517+
# halo for the same Function, so the per-loop tail collapses to a
3518+
# single merged ``v``/``v2`` call instead of two separate ones.
3519+
assert len(calls) == 3
35133520
assert len(FindNodes(HaloUpdateCall).visit(body0(op3).body[1].body[0])) == 1
3514-
assert len(FindNodes(HaloUpdateCall).visit(body0(op3).body[1].body[1])) == 3
3521+
assert len(FindNodes(HaloUpdateCall).visit(body0(op3).body[1].body[1])) == 2
35153522
assert calls[0].arguments[0] is v
35163523
assert calls[1].arguments[0] is tau
35173524
assert calls[1].arguments[1] is tau2
35183525
assert calls[2].arguments[0] is v
3519-
assert calls[3].arguments[0] is v2
3526+
assert calls[2].arguments[1] is v2
35203527

35213528
@pytest.mark.parallel(mode=1)
35223529
def test_issue_2448_backward(self, mode):

0 commit comments

Comments
 (0)