@@ -116,7 +116,7 @@ def key(f):
116116 # Then we inject them into the Clusters. This involves creating the
117117 # initializing Clusters, and replacing the buffered Functions with the buffers
118118 clusters = InjectBuffers (mapper , sregistry , options ).process (clusters )
119- print ( clusters )
119+
120120 return clusters
121121
122122
@@ -142,22 +142,20 @@ def callback(self, clusters, prefix):
142142 return clusters
143143 d = prefix [- 1 ].dim
144144
145- def key (f , * args ):
146- for (ff , _ ) in self .mapper :
147- if f == ff :
148- return True
149- return False
145+ key = lambda f , * args : any (f == ff for ff , _ in self .mapper )
150146 bfmap = map_buffered_functions (clusters , key )
151147
152148 # A BufferDescriptor is a simple data structure storing additional
153149 # information about a buffer, harvested from the subset of `clusters`
154150 # that access it
155- descriptors = {b : BufferDescriptor (f , b , bfmap [f ], g )
156- for (f , g ), b in self .mapper .items ()
157- if f in bfmap }
151+ descriptors = {}
152+ for (f , g ), b in self .mapper .items ():
153+ if f in bfmap :
154+ descriptors .setdefault (b , []).append (BufferDescriptor (f , b , bfmap [f ], g ))
158155
159156 # Are we inside the right `d`?
160- descriptors = {b : v for b , v in descriptors .items () if d in v .itdims }
157+ descriptors = {b : [vi for vi in v if d in vi .itdims ]
158+ for b , v in descriptors .items ()}
161159
162160 if not descriptors :
163161 return clusters
@@ -172,23 +170,28 @@ def key(f, *args):
172170 # Substitution rules to replace buffered Functions with buffers
173171 # E.g., `usave[time+1, x+1, y+1] -> ub0[t1, x+1, y+1]`
174172 subs = {}
175- for b , v in descriptors .items ():
176- accesses = chain (* [c .scope [v .f ] for c in v .clusters ])
177- index_mapper = {i : mds [(v .xd , i )] for i in v .indices }
178- for a in accesses :
179- subs [a .access ] = b .indexed [[index_mapper .get (i , i ) for i in a ]]
173+ for b , vb in descriptors .items ():
174+ for v in vb :
175+ for c in v .clusters :
176+ if c .guards .get (d ) != v .guards .get (d ):
177+ continue
178+ subs .setdefault (c , {})
179+ accesses = c .scope [v .f ]
180+ index_mapper = {i : mds [(v .xd , i )] for i in v .indices }
181+ for a in accesses :
182+ subs [c ][a .access ] = b .indexed [[index_mapper .get (i , i ) for i in a ]]
180183
181184 processed = []
182185 for c in clusters :
183186 # If a buffer is read but never written, then we need to add
184187 # an Eq to step through the next slot
185188 # E.g., `ub[0, x] = usave[time+2, x]`
186- for _ , v in descriptors .items ( ):
189+ for v in chain . from_iterable ( descriptors .values () ):
187190 if not v .is_readonly :
188191 continue
189192 if c not in v .firstread :
190193 continue
191- if not c .guards .get (d ) = = v .guards .get (d ):
194+ if c .guards .get (d ) ! = v .guards .get (d ):
192195 continue
193196
194197 idxf = v .last_idx [c ]
@@ -219,7 +222,7 @@ def key(f, *args):
219222 processed .append (Cluster (expr , ispace , guards , properties , syncs ))
220223
221224 # Substitute the buffered Functions with the buffers
222- exprs = [uxreplace (e , subs ) for e in c .exprs ]
225+ exprs = [uxreplace (e , subs . get ( c , {}) ) for e in c .exprs ]
223226 ispace = c .ispace .augment (subiters )
224227 properties = c .properties .sequentialize (d )
225228 processed .append (
@@ -228,12 +231,12 @@ def key(f, *args):
228231
229232 # Append the copy-back if `c` is the last-write of some buffers
230233 # E.g., `usave[time+1, x] = ub[t1, x]`
231- for _ , v in descriptors .items ( ):
234+ for v in chain . from_iterable ( descriptors .values () ):
232235 if v .is_readonly :
233236 continue
234237 if c not in v .lastwrite :
235238 continue
236- if not c .guards .get (d ) = = v .guards .get (d ):
239+ if c .guards .get (d ) ! = v .guards .get (d ):
237240 continue
238241
239242 idxf = v .last_idx [c ]
@@ -269,36 +272,37 @@ def key(f, *args):
269272 return init + processed
270273
271274 def _optimize (self , clusters , descriptors ):
272- for b , v in descriptors .items ():
273- if v .is_writeonly :
274- # `b` might be written by multiple, potentially mutually
275- # exclusive, equations. For example, two equations that have or
276- # will have complementary guards, hence only one will be
277- # executed. In such a case, we can split the equations over
278- # separate IterationSpaces
279- key0 = lambda : Stamp ()
280- elif v .is_readonly :
281- # `b` is read multiple times -- this could just be the case of
282- # coupled equations, so we more cautiously perform a
283- # "buffer-wise" splitting of the IterationSpaces (i.e., only
284- # relevant if there are at least two read-only buffers)
285- stamp = Stamp ()
286- key0 = lambda : stamp # noqa: B023
287- else :
288- continue
289-
290- processed = []
291- for c in clusters :
292- if b not in c .functions :
293- processed .append (c )
275+ for b , vb in descriptors .items ():
276+ for v in vb :
277+ if v .is_writeonly :
278+ # `b` might be written by multiple, potentially mutually
279+ # exclusive, equations. For example, two equations that have or
280+ # will have complementary guards, hence only one will be
281+ # executed. In such a case, we can split the equations over
282+ # separate IterationSpaces
283+ key0 = lambda : Stamp ()
284+ elif v .is_readonly :
285+ # `b` is read multiple times -- this could just be the case of
286+ # coupled equations, so we more cautiously perform a
287+ # "buffer-wise" splitting of the IterationSpaces (i.e., only
288+ # relevant if there are at least two read-only buffers)
289+ stamp = Stamp ()
290+ key0 = lambda : stamp # noqa: B023
291+ else :
294292 continue
295293
296- key1 = lambda d : not d ._defines & v .dim ._defines # noqa: B023
297- dims = c .ispace .project (key1 ).itdims
298- ispace = c .ispace .lift (dims , key0 ())
299- processed .append (c .rebuild (ispace = ispace ))
294+ processed = []
295+ for c in clusters :
296+ if b not in c .functions :
297+ processed .append (c )
298+ continue
299+
300+ key1 = lambda d : not d ._defines & v .dim ._defines # noqa: B023
301+ dims = c .ispace .project (key1 ).itdims
302+ ispace = c .ispace .lift (dims , key0 ())
303+ processed .append (c .rebuild (ispace = ispace ))
300304
301- clusters = processed
305+ clusters = processed
302306
303307 return clusters
304308
@@ -314,11 +318,11 @@ def _reuse(self, init, clusters, descriptors):
314318 cbk = lambda v : v
315319
316320 mapper = as_mapper (descriptors , key = lambda b : b ._signature )
317- mapper = {k : cbk (v ) for k , v in mapper . items () if cbk (v )}
321+ mapper = {k : [ cbk (v ) for v in vb if cbk (v )] for k , vb in mapper . items ( )}
318322
319323 subs = {}
320324 drop = set ()
321- for reusable in mapper .values ():
325+ for reusable in chain . from_iterable ( mapper .values () ):
322326 retain = reusable .pop (0 )
323327 drop .update (reusable )
324328
@@ -365,18 +369,24 @@ def generate_buffers(clusters, key, sregistry, options, **kwargs):
365369 # {buffered Function -> Buffer}
366370 xds = {}
367371 mapper = {}
372+ extras = {}
368373 for f , clusters in bfmap .items ():
369374 for k , ck in groupby (clusters , key = lambda c : c .guards ):
375+ ck = list (ck )
370376 exprs = flatten (c .exprs for c in ck )
371377
372378 bdims = key (f , exprs )
373379
374380 dims = [d for d in f .dimensions if d not in bdims ]
375381 if len (dims ) != 1 :
376382 raise CompilationError (f"Unsupported multi-dimensional `buffering` "
377- f"required by `{ f } `" )
383+ f"required by `{ f } `" )
378384 dim = dims .pop ()
379385
386+ if not dim ._defines & k .keys ():
387+ extras .setdefault (f , []).append (k )
388+ continue
389+
380390 if is_buffering (exprs ):
381391 # Multi-level buffering
382392 # NOTE: a bit rudimentary (we could go through the exprs one by one
@@ -386,13 +396,15 @@ def generate_buffers(clusters, key, sregistry, options, **kwargs):
386396 buffer , = buffers
387397 xd = buffer .indices [dim ]
388398 else :
389- size = infer_buffer_size (f , dim , clusters )
399+
400+ size = infer_buffer_size (f , dim , ck )
390401
391402 if async_degree is not None :
392403 if async_degree < size :
393404 warning (
394405 'Ignoring provided asynchronous degree as it would be '
395- f'too small for the required buffer (provided { async_degree } , '
406+ 'too small for the required buffer'
407+ f' (provided { async_degree } , '
396408 f'but need at least { size } for `{ f .name } `)'
397409 )
398410 else :
@@ -421,6 +433,13 @@ def generate_buffers(clusters, key, sregistry, options, **kwargs):
421433 padding = padding , grid = f .grid , halo = f .halo ,
422434 space = 'mapped' , mapped = f , f = f )
423435
436+ for f , k in extras .items ():
437+ for (ff , kk ) in dict (mapper ):
438+ if f == ff :
439+ for ki in k :
440+ if ki .keys () & set (mapper [(ff , kk )].dimensions ):
441+ mapper [(f , ki )] = mapper [(ff , kk )]
442+
424443 return mapper
425444
426445
@@ -453,7 +472,7 @@ def __init__(self, f, b, clusters, guards):
453472 self .indices = extract_indices (f , self .dim , clusters )
454473
455474 def __repr__ (self ):
456- return f"Descriptor[{ self .f } -> { self .b } ]"
475+ return f"Descriptor[{ self .f } -> { self .b } ], { self . guards } "
457476
458477 @property
459478 def size (self ):
@@ -668,7 +687,7 @@ def make_mds(descriptors, prefix, sregistry):
668687 inspecting all buffers so that ModuloDimensions are reused when possible.
669688 """
670689 mds = defaultdict (int )
671- for v in descriptors .values ():
690+ for v in chain . from_iterable ( descriptors .values () ):
672691 size = v .xd .symbolic_size
673692
674693 if size == 1 :
@@ -684,7 +703,6 @@ def make_mds(descriptors, prefix, sregistry):
684703 # same strategy is also applied in clusters/algorithms/Stepper
685704 key = lambda i : - np .inf if i - p == 0 else (i - p ) # noqa: B023
686705 indices = sorted (v .indices , key = key )
687- v_mds = None
688706
689707 for k , i in enumerate (indices ):
690708 k = (v .xd , i )
@@ -711,42 +729,43 @@ def init_buffers(descriptors, options):
711729 init_onwrite = options ['buf-init-onwrite' ]
712730
713731 init = []
714- for b , v in descriptors .items ():
715- f = v .f
716-
717- if v .is_read :
718- # Special case: avoid initialization in the case of double (or
719- # multiple) buffering because it's completely unnecessary
720- if v .is_double_buffering :
721- continue
732+ for b , vb in descriptors .items ():
733+ for v in vb :
734+ f = v .f
735+
736+ if v .is_read :
737+ # Special case: avoid initialization in the case of double (or
738+ # multiple) buffering because it's completely unnecessary
739+ if v .is_double_buffering :
740+ continue
722741
723- lhs = b .indexify ()._subs (v .xd , v .first_idx .b )
724- rhs = f .indexify ()._subs (v .dim , v .first_idx .f )
742+ lhs = b .indexify ()._subs (v .xd , v .first_idx .b )
743+ rhs = f .indexify ()._subs (v .dim , v .first_idx .f )
725744
726- elif v .is_write and init_onwrite (f ):
727- lhs = b .indexify ()
728- rhs = S .Zero
745+ elif v .is_write and init_onwrite (f ):
746+ lhs = b .indexify ()
747+ rhs = S .Zero
729748
730- else :
731- continue
749+ else :
750+ continue
732751
733- expr = Eq (lhs , rhs )
734- expr = lower_exprs (expr )
752+ expr = Eq (lhs , rhs )
753+ expr = lower_exprs (expr )
735754
736- ispace = v .write_to
755+ ispace = v .write_to
737756
738- guards = {}
739- guards [None ] = GuardBound (v .dim .root .symbolic_min , v .dim .root .symbolic_max )
740- if v .is_read :
741- guards [v .xd ] = GuardBound (0 , v .first_idx .f )
757+ guards = {}
758+ guards [None ] = GuardBound (v .dim .root .symbolic_min , v .dim .root .symbolic_max )
759+ if v .is_read :
760+ guards [v .xd ] = GuardBound (0 , v .first_idx .f )
742761
743- properties = Properties ()
744- properties = properties .affine (ispace .itdims )
745- properties = properties .parallelize (ispace .itdims )
762+ properties = Properties ()
763+ properties = properties .affine (ispace .itdims )
764+ properties = properties .parallelize (ispace .itdims )
746765
747- syncs = {None : [InitArray (None , b )]}
766+ syncs = {None : [InitArray (None , b )]}
748767
749- init .append (Cluster (expr , ispace , guards , properties , syncs ))
768+ init .append (Cluster (expr , ispace , guards , properties , syncs ))
750769
751770 return init
752771
0 commit comments