@@ -116,7 +116,7 @@ def key(f):
116116 # Then we inject them into the Clusters. This involves creating the
117117 # initializing Clusters, and replacing the buffered Functions with the buffers
118118 clusters = InjectBuffers (mapper , sregistry , options ).process (clusters )
119- print ( clusters )
119+
120120 return clusters
121121
122122
@@ -142,22 +142,20 @@ def callback(self, clusters, prefix):
142142 return clusters
143143 d = prefix [- 1 ].dim
144144
145- def key (f , * args ):
146- for (ff , _ ) in self .mapper :
147- if f == ff :
148- return True
149- return False
145+ key = lambda f , * args : any (f == ff for ff , _ in self .mapper )
150146 bfmap = map_buffered_functions (clusters , key )
151147
152148 # A BufferDescriptor is a simple data structure storing additional
153149 # information about a buffer, harvested from the subset of `clusters`
154150 # that access it
155- descriptors = {b : BufferDescriptor (f , b , bfmap [f ], g )
156- for (f , g ), b in self .mapper .items ()
157- if f in bfmap }
151+ descriptors = {}
152+ for (f , g ), b in self .mapper .items ():
153+ if f in bfmap :
154+ descriptors .setdefault (b , []).append (BufferDescriptor (f , b , bfmap [f ], g ))
158155
159156 # Are we inside the right `d`?
160- descriptors = {b : v for b , v in descriptors .items () if d in v .itdims }
157+ descriptors = {b : [vi for vi in v if d in vi .itdims ]
158+ for b , v in descriptors .items ()}
161159
162160 if not descriptors :
163161 return clusters
@@ -172,23 +170,28 @@ def key(f, *args):
172170 # Substitution rules to replace buffered Functions with buffers
173171 # E.g., `usave[time+1, x+1, y+1] -> ub0[t1, x+1, y+1]`
174172 subs = {}
175- for b , v in descriptors .items ():
176- accesses = chain (* [c .scope [v .f ] for c in v .clusters ])
177- index_mapper = {i : mds [(v .xd , i )] for i in v .indices }
178- for a in accesses :
179- subs [a .access ] = b .indexed [[index_mapper .get (i , i ) for i in a ]]
173+ for b , vb in descriptors .items ():
174+ for v in vb :
175+ for c in v .clusters :
176+ if c .guards .get (d ) != v .guards .get (d ):
177+ continue
178+ subs .setdefault (c , {})
179+ accesses = c .scope [v .f ]
180+ index_mapper = {i : mds [(v .xd , i )] for i in v .indices }
181+ for a in accesses :
182+ subs [c ][a .access ] = b .indexed [[index_mapper .get (i , i ) for i in a ]]
180183
181184 processed = []
182185 for c in clusters :
183186 # If a buffer is read but never written, then we need to add
184187 # an Eq to step through the next slot
185188 # E.g., `ub[0, x] = usave[time+2, x]`
186- for _ , v in descriptors .items ( ):
189+ for v in chain . from_iterable ( descriptors .values () ):
187190 if not v .is_readonly :
188191 continue
189192 if c not in v .firstread :
190193 continue
191- if not c .guards .get (d ) = = v .guards .get (d ):
194+ if c .guards .get (d ) ! = v .guards .get (d ):
192195 continue
193196
194197 idxf = v .last_idx [c ]
@@ -207,9 +210,8 @@ def key(f, *args):
207210 guards = c .guards .xandg (v .xd , GuardBound (0 , v .first_idx .f ))
208211 else :
209212 guards = c .guards
210-
211213 properties = c .properties .sequentialize (d )
212- if not isinstance (d , BufferDimension ) and c .guards [ d ] .has (Mod ):
214+ if not isinstance (d , BufferDimension ) and c .guards . get ( d ) .has (Mod ):
213215 properties = properties .prefetchable (d )
214216 # `c` may be a HaloTouch Cluster, so with no vision of the `bdims`
215217 properties = properties .parallelize (v .bdims ).affine (v .bdims )
@@ -219,7 +221,7 @@ def key(f, *args):
219221 processed .append (Cluster (expr , ispace , guards , properties , syncs ))
220222
221223 # Substitute the buffered Functions with the buffers
222- exprs = [uxreplace (e , subs ) for e in c .exprs ]
224+ exprs = [uxreplace (e , subs . get ( c , {}) ) for e in c .exprs ]
223225 ispace = c .ispace .augment (subiters )
224226 properties = c .properties .sequentialize (d )
225227 processed .append (
@@ -228,12 +230,12 @@ def key(f, *args):
228230
229231 # Append the copy-back if `c` is the last-write of some buffers
230232 # E.g., `usave[time+1, x] = ub[t1, x]`
231- for _ , v in descriptors .items ( ):
233+ for v in chain . from_iterable ( descriptors .values () ):
232234 if v .is_readonly :
233235 continue
234236 if c not in v .lastwrite :
235237 continue
236- if not c .guards .get (d ) = = v .guards .get (d ):
238+ if c .guards .get (d ) ! = v .guards .get (d ):
237239 continue
238240
239241 idxf = v .last_idx [c ]
@@ -269,36 +271,37 @@ def key(f, *args):
269271 return init + processed
270272
271273 def _optimize (self , clusters , descriptors ):
272- for b , v in descriptors .items ():
273- if v .is_writeonly :
274- # `b` might be written by multiple, potentially mutually
275- # exclusive, equations. For example, two equations that have or
276- # will have complementary guards, hence only one will be
277- # executed. In such a case, we can split the equations over
278- # separate IterationSpaces
279- key0 = lambda : Stamp ()
280- elif v .is_readonly :
281- # `b` is read multiple times -- this could just be the case of
282- # coupled equations, so we more cautiously perform a
283- # "buffer-wise" splitting of the IterationSpaces (i.e., only
284- # relevant if there are at least two read-only buffers)
285- stamp = Stamp ()
286- key0 = lambda : stamp # noqa: B023
287- else :
288- continue
289-
290- processed = []
291- for c in clusters :
292- if b not in c .functions :
293- processed .append (c )
274+ for b , vb in descriptors .items ():
275+ for v in vb :
276+ if v .is_writeonly :
277+ # `b` might be written by multiple, potentially mutually
278+ # exclusive, equations. For example, two equations that have or
279+ # will have complementary guards, hence only one will be
280+ # executed. In such a case, we can split the equations over
281+ # separate IterationSpaces
282+ key0 = lambda : Stamp ()
283+ elif v .is_readonly :
284+ # `b` is read multiple times -- this could just be the case of
285+ # coupled equations, so we more cautiously perform a
286+ # "buffer-wise" splitting of the IterationSpaces (i.e., only
287+ # relevant if there are at least two read-only buffers)
288+ stamp = Stamp ()
289+ key0 = lambda : stamp # noqa: B023
290+ else :
294291 continue
295292
296- key1 = lambda d : not d ._defines & v .dim ._defines # noqa: B023
297- dims = c .ispace .project (key1 ).itdims
298- ispace = c .ispace .lift (dims , key0 ())
299- processed .append (c .rebuild (ispace = ispace ))
293+ processed = []
294+ for c in clusters :
295+ if b not in c .functions :
296+ processed .append (c )
297+ continue
300298
301- clusters = processed
299+ key1 = lambda d : not d ._defines & v .dim ._defines # noqa: B023
300+ dims = c .ispace .project (key1 ).itdims
301+ ispace = c .ispace .lift (dims , key0 ())
302+ processed .append (c .rebuild (ispace = ispace ))
303+
304+ clusters = processed
302305
303306 return clusters
304307
@@ -314,11 +317,11 @@ def _reuse(self, init, clusters, descriptors):
314317 cbk = lambda v : v
315318
316319 mapper = as_mapper (descriptors , key = lambda b : b ._signature )
317- mapper = {k : cbk (v ) for k , v in mapper . items () if cbk (v )}
320+ mapper = {k : [ cbk (v ) for v in vb if cbk (v )] for k , vb in mapper . items ( )}
318321
319322 subs = {}
320323 drop = set ()
321- for reusable in mapper .values ():
324+ for reusable in chain . from_iterable ( mapper .values () ):
322325 retain = reusable .pop (0 )
323326 drop .update (reusable )
324327
@@ -365,17 +368,22 @@ def generate_buffers(clusters, key, sregistry, options, **kwargs):
365368 # {buffered Function -> Buffer}
366369 xds = {}
367370 mapper = {}
371+ extras = {}
368372 for f , clusters in bfmap .items ():
369373 for k , ck in groupby (clusters , key = lambda c : c .guards ):
374+ ck = list (ck )
370375 exprs = flatten (c .exprs for c in ck )
371376
372377 bdims = key (f , exprs )
373378
374379 dims = [d for d in f .dimensions if d not in bdims ]
375380 if len (dims ) != 1 :
376381 raise CompilationError (f"Unsupported multi-dimensional `buffering` "
377- f"required by `{ f } `" )
382+ f"required by `{ f } `" )
378383 dim = dims .pop ()
384+ if k and not dim ._defines & k .keys ():
385+ extras .setdefault (f , []).append (k )
386+ continue
379387
380388 if is_buffering (exprs ):
381389 # Multi-level buffering
@@ -386,13 +394,15 @@ def generate_buffers(clusters, key, sregistry, options, **kwargs):
386394 buffer , = buffers
387395 xd = buffer .indices [dim ]
388396 else :
389- size = infer_buffer_size (f , dim , clusters )
397+
398+ size = infer_buffer_size (f , dim , ck )
390399
391400 if async_degree is not None :
392401 if async_degree < size :
393402 warning (
394403 'Ignoring provided asynchronous degree as it would be '
395- f'too small for the required buffer (provided { async_degree } , '
404+ 'too small for the required buffer'
405+ f' (provided { async_degree } , '
396406 f'but need at least { size } for `{ f .name } `)'
397407 )
398408 else :
@@ -421,6 +431,13 @@ def generate_buffers(clusters, key, sregistry, options, **kwargs):
421431 padding = padding , grid = f .grid , halo = f .halo ,
422432 space = 'mapped' , mapped = f , f = f )
423433
434+ for f , k in extras .items ():
435+ for (ff , kk ) in dict (mapper ):
436+ if f == ff :
437+ for ki in k :
438+ if ki .keys () & set (mapper [(ff , kk )].dimensions ):
439+ mapper [(f , ki )] = mapper [(ff , kk )]
440+
424441 return mapper
425442
426443
@@ -453,7 +470,7 @@ def __init__(self, f, b, clusters, guards):
453470 self .indices = extract_indices (f , self .dim , clusters )
454471
455472 def __repr__ (self ):
456- return f"Descriptor[{ self .f } -> { self .b } ]"
473+ return f"Descriptor[{ self .f } -> { self .b } ], { self . guards } "
457474
458475 @property
459476 def size (self ):
@@ -668,7 +685,7 @@ def make_mds(descriptors, prefix, sregistry):
668685 inspecting all buffers so that ModuloDimensions are reused when possible.
669686 """
670687 mds = defaultdict (int )
671- for v in descriptors .values ():
688+ for v in chain . from_iterable ( descriptors .values () ):
672689 size = v .xd .symbolic_size
673690
674691 if size == 1 :
@@ -684,7 +701,6 @@ def make_mds(descriptors, prefix, sregistry):
684701 # same strategy is also applied in clusters/algorithms/Stepper
685702 key = lambda i : - np .inf if i - p == 0 else (i - p ) # noqa: B023
686703 indices = sorted (v .indices , key = key )
687- v_mds = None
688704
689705 for k , i in enumerate (indices ):
690706 k = (v .xd , i )
@@ -711,42 +727,43 @@ def init_buffers(descriptors, options):
711727 init_onwrite = options ['buf-init-onwrite' ]
712728
713729 init = []
714- for b , v in descriptors .items ():
715- f = v .f
716-
717- if v .is_read :
718- # Special case: avoid initialization in the case of double (or
719- # multiple) buffering because it's completely unnecessary
720- if v .is_double_buffering :
721- continue
730+ for b , vb in descriptors .items ():
731+ for v in vb :
732+ f = v .f
733+
734+ if v .is_read :
735+ # Special case: avoid initialization in the case of double (or
736+ # multiple) buffering because it's completely unnecessary
737+ if v .is_double_buffering :
738+ continue
722739
723- lhs = b .indexify ()._subs (v .xd , v .first_idx .b )
724- rhs = f .indexify ()._subs (v .dim , v .first_idx .f )
740+ lhs = b .indexify ()._subs (v .xd , v .first_idx .b )
741+ rhs = f .indexify ()._subs (v .dim , v .first_idx .f )
725742
726- elif v .is_write and init_onwrite (f ):
727- lhs = b .indexify ()
728- rhs = S .Zero
743+ elif v .is_write and init_onwrite (f ):
744+ lhs = b .indexify ()
745+ rhs = S .Zero
729746
730- else :
731- continue
747+ else :
748+ continue
732749
733- expr = Eq (lhs , rhs )
734- expr = lower_exprs (expr )
750+ expr = Eq (lhs , rhs )
751+ expr = lower_exprs (expr )
735752
736- ispace = v .write_to
753+ ispace = v .write_to
737754
738- guards = {}
739- guards [None ] = GuardBound (v .dim .root .symbolic_min , v .dim .root .symbolic_max )
740- if v .is_read :
741- guards [v .xd ] = GuardBound (0 , v .first_idx .f )
755+ guards = {}
756+ guards [None ] = GuardBound (v .dim .root .symbolic_min , v .dim .root .symbolic_max )
757+ if v .is_read :
758+ guards [v .xd ] = GuardBound (0 , v .first_idx .f )
742759
743- properties = Properties ()
744- properties = properties .affine (ispace .itdims )
745- properties = properties .parallelize (ispace .itdims )
760+ properties = Properties ()
761+ properties = properties .affine (ispace .itdims )
762+ properties = properties .parallelize (ispace .itdims )
746763
747- syncs = {None : [InitArray (None , b )]}
764+ syncs = {None : [InitArray (None , b )]}
748765
749- init .append (Cluster (expr , ispace , guards , properties , syncs ))
766+ init .append (Cluster (expr , ispace , guards , properties , syncs ))
750767
751768 return init
752769
0 commit comments