2828from devito .ir .equations import DummyEq
2929from devito .ir .equations .algorithms import lower_exprs
3030from devito .ir .iet import (
31- Call , EntryFunction , Expression , FindNodes , HaloSpot , Increment , Iteration , List ,
32- Transformer , make_callable
31+ Call , Conditional , EntryFunction , Expression , ExpressionBundle , FindNodes , HaloSpot ,
32+ Increment , Iteration , List , Transformer , make_callable
3333)
3434from devito .passes .iet .engine import iet_pass
3535from devito .types import Eq , InjectionMixin , InterpolationMixin , Symbol
@@ -120,7 +120,10 @@ def lower_sparse_ops(iet, sregistry=None, **kwargs):
120120def _find_outer_iteration (iet , expr ):
121121 """
122122 Walk up the IET from ``expr`` and return the outermost Iteration
123- whose ``dim.root`` is the SparseFunction's sparse Dimension.
123+ whose ``dim.root`` is the SparseFunction's sparse Dimension. This
124+ is the entry point of the sparse-op nest in the parent IET; the
125+ full nest (including any block Iterations the cluster pipeline
126+ wrapped around the sparse loop) gets extracted into the efunc.
124127 """
125128 sparse_dim = expr .expr .interpolator .sfunction ._sparse_dim
126129 for it in FindNodes (Iteration ).visit (iet ):
@@ -146,6 +149,15 @@ def _materialise_nest(nest, exprs):
146149 interpolation Expression wrap it with the scalar accumulator
147150 pattern. Multiple sparse-op Expressions sharing the same outer
148151 Iteration are materialised in one pass and reuse the same temps.
152+
153+ ``nest`` is the *outermost* sparse-Dimension Iteration, so that the
154+ whole block-Iteration hierarchy (e.g. ``p_rec0_blk0`` -> ``p_rec``
155+ on the GPU pipeline) is extracted into the efunc and downstream GPU
156+ kernel synthesis can fold the block loops into a thread-grid
157+ wrapping the kernel body. The temps and the accumulator pattern,
158+ however, must live *inside* the innermost sparse Iteration -- one
159+ set per sparse point, sitting beneath any thread-index/OOB-guard
160+ prelude that the GPU kernel prep may have inserted.
149161 """
150162 # Position + coefficient temporaries as IET Expressions. These are
151163 # the same for every Expression in the group, so we emit them once.
@@ -155,24 +167,63 @@ def _materialise_nest(nest, exprs):
155167 temp_exprs = tuple (Expression (DummyEq (e .lhs , e .rhs ))
156168 for e in lower_exprs (sample .sparse_temps ()))
157169
158- # The radius nest is what runs once per sparse point. For each
159- # interpolation Expression in the group, build its
160- # accumulator-wrapped copy of the radius nest. Injection Exprs
161- # share a single copy of the radius nest (their ``Inc`` already
162- # carries the right ``weights * rhs`` form).
163- inner = nest .nodes [0 ] if len (nest .nodes ) == 1 else List (body = nest .nodes )
170+ # Find the innermost sparse-Dimension Iteration within ``nest`` --
171+ # that's where the head Expressions actually live, beneath any block
172+ # Iterations that the cluster pipeline wrapped around the sparse
173+ # loop.
174+ sparse_dim = sample .interpolator .sfunction ._sparse_dim
175+ inner_iter = nest
176+ for it in FindNodes (Iteration ).visit (nest ):
177+ if it .dim .root is sparse_dim and \
178+ any (e in FindNodes (Expression ).visit (it ) for e in exprs ):
179+ inner_iter = it
180+
181+ # ``inner_iter`` may carry a GPU kernel prelude (thread-index
182+ # ``ExpressionBundle`` and OOB ``Conditional``) that downstream
183+ # kernel synthesis expects to find at the top of the block dim's
184+ # body. The temps and the accumulator pattern go *after* that
185+ # prelude.
186+ head , body_nodes = _split_kernel_prelude (inner_iter .nodes )
187+
188+ radius_nest = body_nodes [0 ] if len (body_nodes ) == 1 else List (body = body_nodes )
164189 interp_exprs = [e for e in exprs if isinstance (e .expr , InterpolationMixin )]
165190 inject_exprs = [e for e in exprs if isinstance (e .expr , InjectionMixin )]
166191
167- body = []
192+ new_body = []
168193 for expr in interp_exprs :
169194 siblings = [e for e in exprs if e is not expr ]
170- body .append (_interp_inner_block (inner , expr , expr .expr .interpolator , siblings ))
195+ new_body .append (_interp_inner_block (
196+ radius_nest , expr , expr .expr .interpolator , siblings ))
171197 if inject_exprs :
172198 drop = {e : None for e in interp_exprs }
173- body .append (Transformer (drop , nested = True ).visit (inner ) if drop else inner )
199+ new_body .append (Transformer (drop , nested = True ).visit (radius_nest )
200+ if drop else radius_nest )
201+
202+ new_inner_iter = inner_iter ._rebuild (
203+ nodes = head + temp_exprs + tuple (new_body )
204+ )
205+ if new_inner_iter is inner_iter :
206+ return nest
207+ return Transformer ({inner_iter : new_inner_iter }, nested = True ).visit (nest )
208+
174209
175- return nest ._rebuild (nodes = temp_exprs + tuple (body ))
210+ def _split_kernel_prelude (nodes ):
211+ """
212+ Split the contents of a sparse-Dimension Iteration into the GPU
213+ kernel prelude (the thread-index ``ExpressionBundle`` and the
214+ optional OOB ``Conditional``) and the remaining body. On non-cuda
215+ pipelines the prelude is empty and the full ``nodes`` tuple is the
216+ body.
217+ """
218+ head = ()
219+ body = tuple (nodes )
220+ if body and isinstance (body [0 ], ExpressionBundle ):
221+ head += (body [0 ],)
222+ body = body [1 :]
223+ if body and isinstance (body [0 ], Conditional ):
224+ head += (body [0 ],)
225+ body = body [1 :]
226+ return head , body
176227
177228
178229def _interp_inner_block (inner , expr , interp , siblings ):
@@ -219,10 +270,15 @@ def _interp_inner_block(inner, expr, interp, siblings):
219270
220271 init = Expression (DummyEq (acc , 0 ))
221272 inc = Increment (DummyEq (acc , weighted_rhs ))
222- # Honour the synthetic Eq's flavour: a SparseInc means the user
223- # asked for ``sf[p_*] += sum`` (interpolation with ``increment=True``);
224- # otherwise the standard write is ``sf[p_*] = sum``.
225- write_back_cls = Increment if eq .is_Reduction else Expression
273+ # Honour the synthetic Eq's flavour: an ``IncrInterpolation`` means
274+ # the user asked for ``sf[p_*] += sum`` (interpolation with
275+ # ``increment=True``); a plain ``Interpolation`` is just ``sf[p_*] =
276+ # sum``. We key off the leaf class' ``is_increment_writeback`` flag
277+ # rather than ``is_Reduction`` because both flavours are tagged as
278+ # reductions (``OpInc``) for dependence-analysis purposes -- the rhs
279+ # is implicitly summed over the radius dims -- but only the
280+ # ``IncrInterpolation`` flavour writes back with ``+=``.
281+ write_back_cls = Increment if eq .is_increment_writeback else Expression
226282 write_back = write_back_cls (DummyEq (sf_lhs , acc ))
227283
228284 # Single substitution: drop siblings, replace ``expr`` with ``inc``.
0 commit comments