2828from devito .ir .equations import DummyEq
2929from devito .ir .equations .algorithms import lower_exprs
3030from devito .ir .iet import (
31- Call , EntryFunction , Expression , FindNodes , HaloSpot , Increment , Iteration , List ,
32- Transformer , make_callable
31+ Call , Conditional , EntryFunction , Expression , ExpressionBundle , FindNodes , HaloSpot ,
32+ Increment , Iteration , List , Transformer , make_callable
3333)
3434from devito .passes .iet .engine import iet_pass
3535from devito .types import Eq , InjectionMixin , InterpolationMixin , Symbol
@@ -91,9 +91,22 @@ def lower_sparse_ops(iet, sregistry=None, **kwargs):
9191 efuncs .append (efunc )
9292
9393 call = Call (efunc .name , list (efunc .parameters ))
94+
95+ # Any HaloSpot living inside the nest (e.g. around the radius
96+ # loops reading a grid Function) becomes invisible once the
97+ # nest collapses into an opaque Call. Hoist its halo_scheme out
98+ # so the halo exchange still happens at the right point in the
99+ # parent IET, sitting next to the Call rather than buried in it.
100+ # An injection's lhs is *written* (not read) so its inner halo
101+ # entry for that Function is a reduction-only halo that the
102+ # caller has nothing to read back -- drop it before hoisting.
103+ reduced = {e .expr .lhs .function for e in exprs
104+ if isinstance (e .expr , InjectionMixin )}
105+ prelude = _hoisted_halo_prelude (nest , reduced )
106+
94107 parent = parents [nest ]
95108 if parent is None :
96- mapper [nest ] = call
109+ mapper [nest ] = List ( body = prelude + ( call ,)) if prelude else call
97110 continue
98111
99112 # Drop fields that the (now-opaque) Call only writes/increments,
@@ -120,7 +133,10 @@ def lower_sparse_ops(iet, sregistry=None, **kwargs):
120133def _find_outer_iteration (iet , expr ):
121134 """
122135 Walk up the IET from ``expr`` and return the outermost Iteration
123- whose ``dim.root`` is the SparseFunction's sparse Dimension.
136+ whose ``dim.root`` is the SparseFunction's sparse Dimension. This
137+ is the entry point of the sparse-op nest in the parent IET; the
138+ full nest (including any block Iterations the cluster pipeline
139+ wrapped around the sparse loop) gets extracted into the efunc.
124140 """
125141 sparse_dim = expr .expr .interpolator .sfunction ._sparse_dim
126142 for it in FindNodes (Iteration ).visit (iet ):
@@ -139,13 +155,60 @@ def _enclosing_halospot(iet, nest):
139155 return None
140156
141157
158+ def _hoisted_halo_prelude (nest , reduced = None ):
159+ """
160+ Build a tuple of nodes that recreate, outside the sparse nest, any
161+ HaloSpots that live inside it. Each hoisted HaloSpot is wrapped in
162+ the Iterations from ``nest`` whose dim is referenced by the halo's
163+ ``loc_indices`` (e.g. the ``ps`` loop wrapping a halo whose
164+ ``loc_indices = {ps}``), so the lowering pipeline that turns the
165+ HaloSpot into a halo-exchange Call still sees the indices it needs
166+ in scope.
167+
168+ ``reduced``, when provided, lists Functions whose halo entries
169+ should be dropped from the hoisted scheme -- these are the
170+ Functions the sparse Call only writes/increments (e.g. an
171+ injection's lhs), so the parent never reads them back and the
172+ halo update would be redundant.
173+ """
174+ reduced = reduced or set ()
175+
176+ inner = []
177+ for hs in FindNodes (HaloSpot ).visit (nest ):
178+ scheme = hs .halo_scheme .drop (reduced ) if reduced else hs .halo_scheme
179+ if not scheme .is_void :
180+ inner .append (scheme )
181+ if not inner :
182+ return ()
183+
184+ iters = FindNodes (Iteration ).visit (nest )
185+ prelude = []
186+ for scheme in inner :
187+ loc_dims = set ().union (* (d ._defines for d in scheme .loc_indices ))
188+ wrappers = [it for it in iters if it .dim in loc_dims ]
189+ body = HaloSpot (List (body = []), scheme )
190+ for it in reversed (wrappers ):
191+ body = it ._rebuild (nodes = body )
192+ prelude .append (body )
193+ return tuple (prelude )
194+
195+
142196def _materialise_nest (nest , exprs ):
143197 """
144198 Rewrite the sparse Dimension's Iteration body to compute the
145199 position/coefficient temps once per sparse point, then for any
146200 interpolation Expression wrap it with the scalar accumulator
147201 pattern. Multiple sparse-op Expressions sharing the same outer
148202 Iteration are materialised in one pass and reuse the same temps.
203+
204+ ``nest`` is the *outermost* sparse-Dimension Iteration, so that the
205+ whole block-Iteration hierarchy (e.g. ``p_rec0_blk0`` -> ``p_rec``
206+ on the GPU pipeline) is extracted into the efunc and downstream GPU
207+ kernel synthesis can fold the block loops into a thread-grid
208+ wrapping the kernel body. The temps and the accumulator pattern,
209+ however, must live *inside* the innermost sparse Iteration -- one
210+ set per sparse point, sitting beneath any thread-index/OOB-guard
211+ prelude that the GPU kernel prep may have inserted.
149212 """
150213 # Position + coefficient temporaries as IET Expressions. These are
151214 # the same for every Expression in the group, so we emit them once.
@@ -155,24 +218,63 @@ def _materialise_nest(nest, exprs):
155218 temp_exprs = tuple (Expression (DummyEq (e .lhs , e .rhs ))
156219 for e in lower_exprs (sample .sparse_temps ()))
157220
158- # The radius nest is what runs once per sparse point. For each
159- # interpolation Expression in the group, build its
160- # accumulator-wrapped copy of the radius nest. Injection Exprs
161- # share a single copy of the radius nest (their ``Inc`` already
162- # carries the right ``weights * rhs`` form).
163- inner = nest .nodes [0 ] if len (nest .nodes ) == 1 else List (body = nest .nodes )
221+ # Find the innermost sparse-Dimension Iteration within ``nest`` --
222+ # that's where the head Expressions actually live, beneath any block
223+ # Iterations that the cluster pipeline wrapped around the sparse
224+ # loop.
225+ sparse_dim = sample .interpolator .sfunction ._sparse_dim
226+ inner_iter = nest
227+ for it in FindNodes (Iteration ).visit (nest ):
228+ if it .dim .root is sparse_dim and \
229+ any (e in FindNodes (Expression ).visit (it ) for e in exprs ):
230+ inner_iter = it
231+
232+ # ``inner_iter`` may carry a GPU kernel prelude (thread-index
233+ # ``ExpressionBundle`` and OOB ``Conditional``) that downstream
234+ # kernel synthesis expects to find at the top of the block dim's
235+ # body. The temps and the accumulator pattern go *after* that
236+ # prelude.
237+ head , body_nodes = _split_kernel_prelude (inner_iter .nodes )
238+
239+ radius_nest = body_nodes [0 ] if len (body_nodes ) == 1 else List (body = body_nodes )
164240 interp_exprs = [e for e in exprs if isinstance (e .expr , InterpolationMixin )]
165241 inject_exprs = [e for e in exprs if isinstance (e .expr , InjectionMixin )]
166242
167- body = []
243+ new_body = []
168244 for expr in interp_exprs :
169245 siblings = [e for e in exprs if e is not expr ]
170- body .append (_interp_inner_block (inner , expr , expr .expr .interpolator , siblings ))
246+ new_body .append (_interp_inner_block (
247+ radius_nest , expr , expr .expr .interpolator , siblings ))
171248 if inject_exprs :
172249 drop = {e : None for e in interp_exprs }
173- body .append (Transformer (drop , nested = True ).visit (inner ) if drop else inner )
250+ new_body .append (Transformer (drop , nested = True ).visit (radius_nest )
251+ if drop else radius_nest )
252+
253+ new_inner_iter = inner_iter ._rebuild (
254+ nodes = head + temp_exprs + tuple (new_body )
255+ )
256+ if new_inner_iter is inner_iter :
257+ return nest
258+ return Transformer ({inner_iter : new_inner_iter }, nested = True ).visit (nest )
174259
175- return nest ._rebuild (nodes = temp_exprs + tuple (body ))
260+
261+ def _split_kernel_prelude (nodes ):
262+ """
263+ Split the contents of a sparse-Dimension Iteration into the GPU
264+ kernel prelude (the thread-index ``ExpressionBundle`` and the
265+ optional OOB ``Conditional``) and the remaining body. On non-cuda
266+ pipelines the prelude is empty and the full ``nodes`` tuple is the
267+ body.
268+ """
269+ head = ()
270+ body = tuple (nodes )
271+ if body and isinstance (body [0 ], ExpressionBundle ):
272+ head += (body [0 ],)
273+ body = body [1 :]
274+ if body and isinstance (body [0 ], Conditional ):
275+ head += (body [0 ],)
276+ body = body [1 :]
277+ return head , body
176278
177279
178280def _interp_inner_block (inner , expr , interp , siblings ):
@@ -219,10 +321,15 @@ def _interp_inner_block(inner, expr, interp, siblings):
219321
220322 init = Expression (DummyEq (acc , 0 ))
221323 inc = Increment (DummyEq (acc , weighted_rhs ))
222- # Honour the synthetic Eq's flavour: a SparseInc means the user
223- # asked for ``sf[p_*] += sum`` (interpolation with ``increment=True``);
224- # otherwise the standard write is ``sf[p_*] = sum``.
225- write_back_cls = Increment if eq .is_Reduction else Expression
324+ # Honour the synthetic Eq's flavour: an ``IncrInterpolation`` means
325+ # the user asked for ``sf[p_*] += sum`` (interpolation with
326+ # ``increment=True``); a plain ``Interpolation`` is just ``sf[p_*] =
327+ # sum``. We key off the leaf class' ``is_increment_writeback`` flag
328+ # rather than ``is_Reduction`` because both flavours are tagged as
329+ # reductions (``OpInc``) for dependence-analysis purposes -- the rhs
330+ # is implicitly summed over the radius dims -- but only the
331+ # ``IncrInterpolation`` flavour writes back with ``+=``.
332+ write_back_cls = Increment if eq .is_increment_writeback else Expression
226333 write_back = write_back_cls (DummyEq (sf_lhs , acc ))
227334
228335 # Single substitution: drop siblings, replace ``expr`` with ``inc``.
0 commit comments