Skip to content

Commit 7209a62

Browse files
committed
compiler: fix for kernel
1 parent ff53f3b commit 7209a62

5 files changed

Lines changed: 116 additions & 27 deletions

File tree

devito/ir/clusters/algorithms.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -472,9 +472,19 @@ def callback(self, clusters, prefix, seen=None):
472472
# Construct the HaloTouch Cluster
473473
expr = Eq(self.B, HaloTouch(*points, halo_scheme=hs))
474474

475-
key0 = lambda i: i in prefix[:-1] or i in hs.loc_indices # noqa: B023
475+
# The HaloTouch only needs to be scheduled at the outermost
476+
# level the halo'd data depends on -- typically the time loop
477+
# and the sub-iterators (``loc_indices``) that index it. Any
478+
# outer Dimension whose iteration is *independent* of the
479+
# halo (e.g. ``p_rec`` for an interpolation reading ``u``
480+
# along the radius nest) shouldn't be in the HaloTouch's
481+
# ispace, otherwise its sequentialized properties veto
482+
# blocking on the real clusters it sits alongside.
483+
relevant = (set(hs.loc_indices) |
484+
set().union(*(i._defines for i in hs.loc_indices)))
485+
key0 = lambda i: i in prefix[:-1] and i._defines & relevant # noqa: B023
476486
key1 = lambda i: not i._defines & set(hs.distributed_defined) # noqa: B023
477-
key = lambda i: key0(i) and key1(i) # noqa: B023
487+
key = lambda i: (key0(i) or i in hs.loc_indices) and key1(i) # noqa: B023
478488
ispace = c.ispace.project(key)
479489

480490
properties = c.properties.sequentialize()

devito/ir/clusters/cluster.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ class EqBlock(CacheInstances):
3434
@classmethod
3535
def _preprocess_args(cls, exprs, ispace=null_ispace, guards=None,
3636
properties=None, syncs=None, halo_scheme=None):
37-
exprs = tuple(ClusterizedEq(e, ispace=ispace) for e in as_tuple(exprs))
37+
exprs = tuple(clusterize_eq(e, ispace=ispace) for e in as_tuple(exprs))
3838
guards = Guards(guards or {})
3939
properties = Properties(properties or {})
4040
syncs = normalize_syncs(syncs or {})

devito/ir/equations/equation.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,19 @@ def detect(cls, expr):
208208
(ReduceMin, OpMin),
209209
(ReduceMax, OpMax),
210210
(Inc, OpInc),
211+
# An ``Interpolation`` looks like a plain ``Eq`` -- ``sf[p_*] =
212+
# expr[rp_*]`` -- but the cluster scheduler iterates the rhs
213+
# over the radius dims, so values are implicitly summed across
214+
# ``rp_*``. Tagging it as ``OpInc`` makes the dependence
215+
# analysis treat ``rp_*`` as reduction dims
216+
# (``parallel_if_atomic``), which matches the lowered code
217+
# (``sumrec += ... ; sf[p_*] = sumrec``) and stops the
218+
# blocking heuristic from turning the tiny radius loops into
219+
# thread blocks. The actual write-back flavour at ``sf[p_*]``
220+
# is keyed off the Eq's class (``is_increment_writeback``) in
221+
# ``lower_sparse_ops`` so this tag doesn't accidentally turn
222+
# ``Interpolation`` assignments into increments.
223+
(InterpolationMixin, OpInc),
211224
)
212225
for kls, op in reduction_mapper:
213226
if isinstance(expr, kls):
@@ -366,12 +379,17 @@ class LoweredSparseEq(SparseOpMixin, LoweredEq):
366379

367380
class LoweredInterpolation(InterpolationMixin, LoweredSparseEq):
368381
"""IR counterpart of ``Interpolation``."""
369-
pass
382+
# ``sf[p_*] = ...``: the write-back at the sparse position is an
383+
# assignment. The Eq is still tagged as a reduction
384+
# (``OpInc``/``is_Reduction``) because the rhs is summed over the
385+
# radius dims; only the *outer* write-back to ``sf[p_*]`` is plain.
386+
is_increment_writeback = False
370387

371388

372389
class LoweredIncrInterpolation(InterpolationMixin, LoweredSparseEq):
373390
"""IR counterpart of ``IncrInterpolation``."""
374-
pass
391+
# ``sf[p_*] += ...``: the user asked for ``interpolate(..., increment=True)``.
392+
is_increment_writeback = True
375393

376394

377395
class LoweredInjection(InjectionMixin, LoweredSparseEq):
@@ -458,12 +476,12 @@ class ClusterizedSparseEq(SparseOpMixin, ClusterizedEq):
458476

459477
class ClusterizedInterpolation(InterpolationMixin, ClusterizedSparseEq):
460478
"""Frozen counterpart of ``LoweredInterpolation``."""
461-
pass
479+
is_increment_writeback = False
462480

463481

464482
class ClusterizedIncrInterpolation(InterpolationMixin, ClusterizedSparseEq):
465483
"""Frozen counterpart of ``LoweredIncrInterpolation``."""
466-
pass
484+
is_increment_writeback = True
467485

468486

469487
class ClusterizedInjection(InjectionMixin, ClusterizedSparseEq):

devito/passes/iet/sparse.py

Lines changed: 73 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@
2828
from devito.ir.equations import DummyEq
2929
from devito.ir.equations.algorithms import lower_exprs
3030
from devito.ir.iet import (
31-
Call, EntryFunction, Expression, FindNodes, HaloSpot, Increment, Iteration, List,
32-
Transformer, make_callable
31+
Call, Conditional, EntryFunction, Expression, ExpressionBundle, FindNodes, HaloSpot,
32+
Increment, Iteration, List, Transformer, make_callable
3333
)
3434
from devito.passes.iet.engine import iet_pass
3535
from devito.types import Eq, InjectionMixin, InterpolationMixin, Symbol
@@ -120,7 +120,10 @@ def lower_sparse_ops(iet, sregistry=None, **kwargs):
120120
def _find_outer_iteration(iet, expr):
121121
"""
122122
Walk up the IET from ``expr`` and return the outermost Iteration
123-
whose ``dim.root`` is the SparseFunction's sparse Dimension.
123+
whose ``dim.root`` is the SparseFunction's sparse Dimension. This
124+
is the entry point of the sparse-op nest in the parent IET; the
125+
full nest (including any block Iterations the cluster pipeline
126+
wrapped around the sparse loop) gets extracted into the efunc.
124127
"""
125128
sparse_dim = expr.expr.interpolator.sfunction._sparse_dim
126129
for it in FindNodes(Iteration).visit(iet):
@@ -146,6 +149,15 @@ def _materialise_nest(nest, exprs):
146149
interpolation Expression wrap it with the scalar accumulator
147150
pattern. Multiple sparse-op Expressions sharing the same outer
148151
Iteration are materialised in one pass and reuse the same temps.
152+
153+
``nest`` is the *outermost* sparse-Dimension Iteration, so that the
154+
whole block-Iteration hierarchy (e.g. ``p_rec0_blk0`` -> ``p_rec``
155+
on the GPU pipeline) is extracted into the efunc and downstream GPU
156+
kernel synthesis can fold the block loops into a thread-grid
157+
wrapping the kernel body. The temps and the accumulator pattern,
158+
however, must live *inside* the innermost sparse Iteration -- one
159+
set per sparse point, sitting beneath any thread-index/OOB-guard
160+
prelude that the GPU kernel prep may have inserted.
149161
"""
150162
# Position + coefficient temporaries as IET Expressions. These are
151163
# the same for every Expression in the group, so we emit them once.
@@ -155,24 +167,63 @@ def _materialise_nest(nest, exprs):
155167
temp_exprs = tuple(Expression(DummyEq(e.lhs, e.rhs))
156168
for e in lower_exprs(sample.sparse_temps()))
157169

158-
# The radius nest is what runs once per sparse point. For each
159-
# interpolation Expression in the group, build its
160-
# accumulator-wrapped copy of the radius nest. Injection Exprs
161-
# share a single copy of the radius nest (their ``Inc`` already
162-
# carries the right ``weights * rhs`` form).
163-
inner = nest.nodes[0] if len(nest.nodes) == 1 else List(body=nest.nodes)
170+
# Find the innermost sparse-Dimension Iteration within ``nest`` --
171+
# that's where the head Expressions actually live, beneath any block
172+
# Iterations that the cluster pipeline wrapped around the sparse
173+
# loop.
174+
sparse_dim = sample.interpolator.sfunction._sparse_dim
175+
inner_iter = nest
176+
for it in FindNodes(Iteration).visit(nest):
177+
if it.dim.root is sparse_dim and \
178+
any(e in FindNodes(Expression).visit(it) for e in exprs):
179+
inner_iter = it
180+
181+
# ``inner_iter`` may carry a GPU kernel prelude (thread-index
182+
# ``ExpressionBundle`` and OOB ``Conditional``) that downstream
183+
# kernel synthesis expects to find at the top of the block dim's
184+
# body. The temps and the accumulator pattern go *after* that
185+
# prelude.
186+
head, body_nodes = _split_kernel_prelude(inner_iter.nodes)
187+
188+
radius_nest = body_nodes[0] if len(body_nodes) == 1 else List(body=body_nodes)
164189
interp_exprs = [e for e in exprs if isinstance(e.expr, InterpolationMixin)]
165190
inject_exprs = [e for e in exprs if isinstance(e.expr, InjectionMixin)]
166191

167-
body = []
192+
new_body = []
168193
for expr in interp_exprs:
169194
siblings = [e for e in exprs if e is not expr]
170-
body.append(_interp_inner_block(inner, expr, expr.expr.interpolator, siblings))
195+
new_body.append(_interp_inner_block(
196+
radius_nest, expr, expr.expr.interpolator, siblings))
171197
if inject_exprs:
172198
drop = {e: None for e in interp_exprs}
173-
body.append(Transformer(drop, nested=True).visit(inner) if drop else inner)
199+
new_body.append(Transformer(drop, nested=True).visit(radius_nest)
200+
if drop else radius_nest)
201+
202+
new_inner_iter = inner_iter._rebuild(
203+
nodes=head + temp_exprs + tuple(new_body)
204+
)
205+
if new_inner_iter is inner_iter:
206+
return nest
207+
return Transformer({inner_iter: new_inner_iter}, nested=True).visit(nest)
208+
174209

175-
return nest._rebuild(nodes=temp_exprs + tuple(body))
210+
def _split_kernel_prelude(nodes):
211+
"""
212+
Split the contents of a sparse-Dimension Iteration into the GPU
213+
kernel prelude (the thread-index ``ExpressionBundle`` and the
214+
optional OOB ``Conditional``) and the remaining body. On non-cuda
215+
pipelines the prelude is empty and the full ``nodes`` tuple is the
216+
body.
217+
"""
218+
head = ()
219+
body = tuple(nodes)
220+
if body and isinstance(body[0], ExpressionBundle):
221+
head += (body[0],)
222+
body = body[1:]
223+
if body and isinstance(body[0], Conditional):
224+
head += (body[0],)
225+
body = body[1:]
226+
return head, body
176227

177228

178229
def _interp_inner_block(inner, expr, interp, siblings):
@@ -219,10 +270,15 @@ def _interp_inner_block(inner, expr, interp, siblings):
219270

220271
init = Expression(DummyEq(acc, 0))
221272
inc = Increment(DummyEq(acc, weighted_rhs))
222-
# Honour the synthetic Eq's flavour: a SparseInc means the user
223-
# asked for ``sf[p_*] += sum`` (interpolation with ``increment=True``);
224-
# otherwise the standard write is ``sf[p_*] = sum``.
225-
write_back_cls = Increment if eq.is_Reduction else Expression
273+
# Honour the synthetic Eq's flavour: an ``IncrInterpolation`` means
274+
# the user asked for ``sf[p_*] += sum`` (interpolation with
275+
# ``increment=True``); a plain ``Interpolation`` is just ``sf[p_*] =
276+
# sum``. We key off the leaf class' ``is_increment_writeback`` flag
277+
# rather than ``is_Reduction`` because both flavours are tagged as
278+
# reductions (``OpInc``) for dependence-analysis purposes -- the rhs
279+
# is implicitly summed over the radius dims -- but only the
280+
# ``IncrInterpolation`` flavour writes back with ``+=``.
281+
write_back_cls = Increment if eq.is_increment_writeback else Expression
226282
write_back = write_back_cls(DummyEq(sf_lhs, acc))
227283

228284
# Single substitution: drop siblings, replace ``expr`` with ``inc``.

devito/types/sparse.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -415,12 +415,17 @@ def dist_origin(self):
415415
@memoized_meth
416416
def _crdim(self, dim):
417417
"""
418-
The CustomDimension associated with the Dimension `dim` for
419-
the radius of the interpolation/injection stencil
418+
The CustomDimension associated with the grid Dimension ``dim``
419+
for the radius of the interpolation/injection stencil. The
420+
parent is ``dim`` itself so that ``_defines`` traces back to the
421+
grid Dimension the radius slides over -- this is what dependence
422+
analysis needs to recognise the implicit reduction over ``rp_*``
423+
rather than treating ``rp_*`` as if they were derived from the
424+
SparseFunction's sparse Dimension.
420425
"""
421426
sname = self._sparse_dim.name
422427
return CustomDimension(f"r{sname}{dim.name}", -self.r+1,
423-
self.r, 2*self.r, self._sparse_dim)
428+
self.r, 2*self.r, dim)
424429

425430
@memoized_meth
426431
def _cond_rdim(self, dim, cond):

0 commit comments

Comments
 (0)