11from collections import defaultdict , namedtuple
22from functools import cached_property
3- from itertools import chain
3+ from itertools import chain , groupby
44
55import numpy as np
6- from sympy import S , simplify
6+ from sympy import Mod , S , simplify
77
88from devito .exceptions import CompilationError
99from devito .ir import (
@@ -116,7 +116,7 @@ def key(f):
116116 # Then we inject them into the Clusters. This involves creating the
117117 # initializing Clusters, and replacing the buffered Functions with the buffers
118118 clusters = InjectBuffers (mapper , sregistry , options ).process (clusters )
119-
119+ print ( clusters )
120120 return clusters
121121
122122
@@ -142,14 +142,18 @@ def callback(self, clusters, prefix):
142142 return clusters
143143 d = prefix [- 1 ].dim
144144
145- key = lambda f , * args : f in self .mapper
145+ def key (f , * args ):
146+ for (ff , _ ) in self .mapper :
147+ if f == ff :
148+ return True
149+ return False
146150 bfmap = map_buffered_functions (clusters , key )
147151
148152 # A BufferDescriptor is a simple data structure storing additional
149153 # information about a buffer, harvested from the subset of `clusters`
150154 # that access it
151- descriptors = {b : BufferDescriptor (f , b , bfmap [f ])
152- for f , b in self .mapper .items ()
155+ descriptors = {b : BufferDescriptor (f , b , bfmap [f ], g )
156+ for ( f , g ) , b in self .mapper .items ()
153157 if f in bfmap }
154158
155159 # Are we inside the right `d`?
@@ -184,6 +188,8 @@ def callback(self, clusters, prefix):
184188 continue
185189 if c not in v .firstread :
186190 continue
191+ if not c .guards .get (d ) == v .guards .get (d ):
192+ continue
187193
188194 idxf = v .last_idx [c ]
189195 idxb = mds [(v .xd , idxf )]
@@ -203,7 +209,7 @@ def callback(self, clusters, prefix):
203209 guards = c .guards
204210
205211 properties = c .properties .sequentialize (d )
206- if not isinstance (d , BufferDimension ):
212+ if not isinstance (d , BufferDimension ) and c . guards [ d ]. has ( Mod ) :
207213 properties = properties .prefetchable (d )
208214 # `c` may be a HaloTouch Cluster, so with no vision of the `bdims`
209215 properties = properties .parallelize (v .bdims ).affine (v .bdims )
@@ -227,6 +233,8 @@ def callback(self, clusters, prefix):
227233 continue
228234 if c not in v .lastwrite :
229235 continue
236+ if not c .guards .get (d ) == v .guards .get (d ):
237+ continue
230238
231239 idxf = v .last_idx [c ]
232240 idxb = mds [(v .xd , idxf )]
@@ -358,15 +366,16 @@ def generate_buffers(clusters, key, sregistry, options, **kwargs):
358366 xds = {}
359367 mapper = {}
360368 for f , clusters in bfmap .items ():
361- exprs = flatten (c .exprs for c in clusters )
369+ for k , ck in groupby (clusters , key = lambda c : c .guards ):
370+ exprs = flatten (c .exprs for c in ck )
362371
363- bdims = key (f , exprs )
372+ bdims = key (f , exprs )
364373
365- dims = [d for d in f .dimensions if d not in bdims ]
366- if len (dims ) != 1 :
367- raise CompilationError (f"Unsupported multi-dimensional `buffering` "
368- f"required by `{ f } `" )
369- dim = dims .pop ()
374+ dims = [d for d in f .dimensions if d not in bdims ]
375+ if len (dims ) != 1 :
376+ raise CompilationError (f"Unsupported multi-dimensional `buffering` "
377+ f"required by `{ f } `" )
378+ dim = dims .pop ()
370379
371380 if is_buffering (exprs ):
372381 # Multi-level buffering
@@ -391,25 +400,25 @@ def generate_buffers(clusters, key, sregistry, options, **kwargs):
391400 else :
392401 size = async_degree
393402
394- # A special CustomDimension to use in place of `dim` in the buffer
395- try :
396- xd = xds [(dim , size )]
397- except KeyError :
398- name = sregistry .make_name (prefix = 'db' )
399- xd = xds [(dim , size )] = BufferDimension (name , 0 , size - 1 , size , dim )
400- extra_kwargs = {}
401-
402- # The buffer dimensions
403- dimensions = list (f .dimensions )
404- assert dim in f .dimensions
405- dimensions [dimensions .index (dim )] = xd
406-
407- # Finally create the actual buffer
408- cls = callback or Array
409- name = sregistry .make_name (prefix = f'{ f .name } b' )
410- mapper [f ] = cls (name = name , dimensions = dimensions , dtype = f .dtype ,
411- grid = f .grid , halo = f .halo ,
412- space = 'mapped' , mapped = f , f = f , ** extra_kwargs )
403+ # A special CustomDimension to use in place of `dim` in the buffer
404+ try :
405+ xd = xds [(dim , size )]
406+ except KeyError :
407+ name = sregistry .make_name (prefix = 'db' )
408+ xd = xds [(dim , size )] = BufferDimension (name , 0 , size - 1 , size , dim )
409+ extra_kwargs = {}
410+
411+ # The buffer dimensions
412+ dimensions = list (f .dimensions )
413+ assert dim in f .dimensions
414+ dimensions [dimensions .index (dim )] = xd
415+
416+ # Finally create the actual buffer
417+ cls = callback or Array
418+ name = sregistry .make_name (prefix = f'{ f .name } b' )
419+ mapper [( f , k ) ] = cls (name = name , dimensions = dimensions , dtype = f .dtype ,
420+ grid = f .grid , halo = f .halo ,
421+ space = 'mapped' , mapped = f , f = f , ** extra_kwargs )
413422
414423 return mapper
415424
@@ -429,10 +438,11 @@ def map_buffered_functions(clusters, key):
429438
430439class BufferDescriptor :
431440
432- def __init__ (self , f , b , clusters ):
441+ def __init__ (self , f , b , clusters , guards ):
433442 self .f = f
434443 self .b = b
435444 self .clusters = clusters
445+ self .guards = guards
436446
437447 self .xd , = b .find (BufferDimension )
438448 self .bdims = tuple (d for d in b .dimensions if d is not self .xd )
@@ -673,8 +683,9 @@ def make_mds(descriptors, prefix, sregistry):
673683 # same strategy is also applied in clusters/algorithms/Stepper
674684 key = lambda i : - np .inf if i - p == 0 else (i - p ) # noqa: B023
675685 indices = sorted (v .indices , key = key )
686+ v_mds = None
676687
677- for i in indices :
688+ for k , i in enumerate ( indices ) :
678689 k = (v .xd , i )
679690 if k in mds :
680691 continue
0 commit comments