Skip to content

Commit 7a1a6aa

Browse files
committed
compiler: support multi-buffering
1 parent c8ed5e1 commit 7a1a6aa

2 files changed

Lines changed: 105 additions & 86 deletions

File tree

devito/passes/clusters/asynchrony.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from collections import defaultdict
22

3-
from sympy import true
3+
from sympy import Mod, true
44

55
from devito.ir import (
66
Backward, Forward, GuardBoundNext, PrefetchUpdate, Queue, ReleaseLock, SyncArray,
@@ -78,7 +78,9 @@ def callback(self, clusters, prefix):
7878
d = self.key0(c0)
7979
if d is not dim:
8080
continue
81-
81+
print(c0.guards)
82+
if d in c0.guards and not c0.guards[d].has(Mod):
83+
continue
8284
protected = self._schedule_waitlocks(c0, d, clusters, locks, syncs)
8385
self._schedule_withlocks(c0, d, protected, locks, syncs)
8486

devito/passes/clusters/buffering.py

Lines changed: 101 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ def key(f):
116116
# Then we inject them into the Clusters. This involves creating the
117117
# initializing Clusters, and replacing the buffered Functions with the buffers
118118
clusters = InjectBuffers(mapper, sregistry, options).process(clusters)
119-
print(clusters)
119+
120120
return clusters
121121

122122

@@ -142,22 +142,20 @@ def callback(self, clusters, prefix):
142142
return clusters
143143
d = prefix[-1].dim
144144

145-
def key(f, *args):
146-
for (ff, _) in self.mapper:
147-
if f == ff:
148-
return True
149-
return False
145+
key = lambda f, *args: any(f == ff for ff, _ in self.mapper)
150146
bfmap = map_buffered_functions(clusters, key)
151147

152148
# A BufferDescriptor is a simple data structure storing additional
153149
# information about a buffer, harvested from the subset of `clusters`
154150
# that access it
155-
descriptors = {b: BufferDescriptor(f, b, bfmap[f], g)
156-
for (f, g), b in self.mapper.items()
157-
if f in bfmap}
151+
descriptors = {}
152+
for (f, g), b in self.mapper.items():
153+
if f in bfmap:
154+
descriptors.setdefault(b, []).append(BufferDescriptor(f, b, bfmap[f], g))
158155

159156
# Are we inside the right `d`?
160-
descriptors = {b: v for b, v in descriptors.items() if d in v.itdims}
157+
descriptors = {b: [vi for vi in v if d in vi.itdims]
158+
for b, v in descriptors.items()}
161159

162160
if not descriptors:
163161
return clusters
@@ -172,23 +170,28 @@ def key(f, *args):
172170
# Substitution rules to replace buffered Functions with buffers
173171
# E.g., `usave[time+1, x+1, y+1] -> ub0[t1, x+1, y+1]`
174172
subs = {}
175-
for b, v in descriptors.items():
176-
accesses = chain(*[c.scope[v.f] for c in v.clusters])
177-
index_mapper = {i: mds[(v.xd, i)] for i in v.indices}
178-
for a in accesses:
179-
subs[a.access] = b.indexed[[index_mapper.get(i, i) for i in a]]
173+
for b, vb in descriptors.items():
174+
for v in vb:
175+
for c in v.clusters:
176+
if c.guards.get(d) != v.guards.get(d):
177+
continue
178+
subs.setdefault(c, {})
179+
accesses = c.scope[v.f]
180+
index_mapper = {i: mds[(v.xd, i)] for i in v.indices}
181+
for a in accesses:
182+
subs[c][a.access] = b.indexed[[index_mapper.get(i, i) for i in a]]
180183

181184
processed = []
182185
for c in clusters:
183186
# If a buffer is read but never written, then we need to add
184187
# an Eq to step through the next slot
185188
# E.g., `ub[0, x] = usave[time+2, x]`
186-
for _, v in descriptors.items():
189+
for v in chain.from_iterable(descriptors.values()):
187190
if not v.is_readonly:
188191
continue
189192
if c not in v.firstread:
190193
continue
191-
if not c.guards.get(d) == v.guards.get(d):
194+
if c.guards.get(d) != v.guards.get(d):
192195
continue
193196

194197
idxf = v.last_idx[c]
@@ -207,9 +210,8 @@ def key(f, *args):
207210
guards = c.guards.xandg(v.xd, GuardBound(0, v.first_idx.f))
208211
else:
209212
guards = c.guards
210-
211213
properties = c.properties.sequentialize(d)
212-
if not isinstance(d, BufferDimension) and c.guards[d].has(Mod):
214+
if not isinstance(d, BufferDimension) and c.guards.get(d).has(Mod):
213215
properties = properties.prefetchable(d)
214216
# `c` may be a HaloTouch Cluster, so with no vision of the `bdims`
215217
properties = properties.parallelize(v.bdims).affine(v.bdims)
@@ -219,7 +221,7 @@ def key(f, *args):
219221
processed.append(Cluster(expr, ispace, guards, properties, syncs))
220222

221223
# Substitute the buffered Functions with the buffers
222-
exprs = [uxreplace(e, subs) for e in c.exprs]
224+
exprs = [uxreplace(e, subs.get(c, {})) for e in c.exprs]
223225
ispace = c.ispace.augment(subiters)
224226
properties = c.properties.sequentialize(d)
225227
processed.append(
@@ -228,12 +230,12 @@ def key(f, *args):
228230

229231
# Append the copy-back if `c` is the last-write of some buffers
230232
# E.g., `usave[time+1, x] = ub[t1, x]`
231-
for _, v in descriptors.items():
233+
for v in chain.from_iterable(descriptors.values()):
232234
if v.is_readonly:
233235
continue
234236
if c not in v.lastwrite:
235237
continue
236-
if not c.guards.get(d) == v.guards.get(d):
238+
if c.guards.get(d) != v.guards.get(d):
237239
continue
238240

239241
idxf = v.last_idx[c]
@@ -269,36 +271,37 @@ def key(f, *args):
269271
return init + processed
270272

271273
def _optimize(self, clusters, descriptors):
272-
for b, v in descriptors.items():
273-
if v.is_writeonly:
274-
# `b` might be written by multiple, potentially mutually
275-
# exclusive, equations. For example, two equations that have or
276-
# will have complementary guards, hence only one will be
277-
# executed. In such a case, we can split the equations over
278-
# separate IterationSpaces
279-
key0 = lambda: Stamp()
280-
elif v.is_readonly:
281-
# `b` is read multiple times -- this could just be the case of
282-
# coupled equations, so we more cautiously perform a
283-
# "buffer-wise" splitting of the IterationSpaces (i.e., only
284-
# relevant if there are at least two read-only buffers)
285-
stamp = Stamp()
286-
key0 = lambda: stamp # noqa: B023
287-
else:
288-
continue
289-
290-
processed = []
291-
for c in clusters:
292-
if b not in c.functions:
293-
processed.append(c)
274+
for b, vb in descriptors.items():
275+
for v in vb:
276+
if v.is_writeonly:
277+
# `b` might be written by multiple, potentially mutually
278+
# exclusive, equations. For example, two equations that have or
279+
# will have complementary guards, hence only one will be
280+
# executed. In such a case, we can split the equations over
281+
# separate IterationSpaces
282+
key0 = lambda: Stamp()
283+
elif v.is_readonly:
284+
# `b` is read multiple times -- this could just be the case of
285+
# coupled equations, so we more cautiously perform a
286+
# "buffer-wise" splitting of the IterationSpaces (i.e., only
287+
# relevant if there are at least two read-only buffers)
288+
stamp = Stamp()
289+
key0 = lambda: stamp # noqa: B023
290+
else:
294291
continue
295292

296-
key1 = lambda d: not d._defines & v.dim._defines # noqa: B023
297-
dims = c.ispace.project(key1).itdims
298-
ispace = c.ispace.lift(dims, key0())
299-
processed.append(c.rebuild(ispace=ispace))
293+
processed = []
294+
for c in clusters:
295+
if b not in c.functions:
296+
processed.append(c)
297+
continue
300298

301-
clusters = processed
299+
key1 = lambda d: not d._defines & v.dim._defines # noqa: B023
300+
dims = c.ispace.project(key1).itdims
301+
ispace = c.ispace.lift(dims, key0())
302+
processed.append(c.rebuild(ispace=ispace))
303+
304+
clusters = processed
302305

303306
return clusters
304307

@@ -314,11 +317,11 @@ def _reuse(self, init, clusters, descriptors):
314317
cbk = lambda v: v
315318

316319
mapper = as_mapper(descriptors, key=lambda b: b._signature)
317-
mapper = {k: cbk(v) for k, v in mapper.items() if cbk(v)}
320+
mapper = {k: [cbk(v) for v in vb if cbk(v)] for k, vb in mapper.items()}
318321

319322
subs = {}
320323
drop = set()
321-
for reusable in mapper.values():
324+
for reusable in chain.from_iterable(mapper.values()):
322325
retain = reusable.pop(0)
323326
drop.update(reusable)
324327

@@ -365,17 +368,22 @@ def generate_buffers(clusters, key, sregistry, options, **kwargs):
365368
# {buffered Function -> Buffer}
366369
xds = {}
367370
mapper = {}
371+
extras = {}
368372
for f, clusters in bfmap.items():
369373
for k, ck in groupby(clusters, key=lambda c: c.guards):
374+
ck = list(ck)
370375
exprs = flatten(c.exprs for c in ck)
371376

372377
bdims = key(f, exprs)
373378

374379
dims = [d for d in f.dimensions if d not in bdims]
375380
if len(dims) != 1:
376381
raise CompilationError(f"Unsupported multi-dimensional `buffering` "
377-
f"required by `{f}`")
382+
f"required by `{f}`")
378383
dim = dims.pop()
384+
if k and not dim._defines & k.keys():
385+
extras.setdefault(f, []).append(k)
386+
continue
379387

380388
if is_buffering(exprs):
381389
# Multi-level buffering
@@ -386,13 +394,15 @@ def generate_buffers(clusters, key, sregistry, options, **kwargs):
386394
buffer, = buffers
387395
xd = buffer.indices[dim]
388396
else:
389-
size = infer_buffer_size(f, dim, clusters)
397+
398+
size = infer_buffer_size(f, dim, ck)
390399

391400
if async_degree is not None:
392401
if async_degree < size:
393402
warning(
394403
'Ignoring provided asynchronous degree as it would be '
395-
f'too small for the required buffer (provided {async_degree}, '
404+
'too small for the required buffer'
405+
f' (provided {async_degree}, '
396406
f'but need at least {size} for `{f.name}`)'
397407
)
398408
else:
@@ -421,6 +431,13 @@ def generate_buffers(clusters, key, sregistry, options, **kwargs):
421431
padding=padding, grid=f.grid, halo=f.halo,
422432
space='mapped', mapped=f, f=f)
423433

434+
for f, k in extras.items():
435+
for (ff, kk) in dict(mapper):
436+
if f == ff:
437+
for ki in k:
438+
if ki.keys() & set(mapper[(ff, kk)].dimensions):
439+
mapper[(f, ki)] = mapper[(ff, kk)]
440+
424441
return mapper
425442

426443

@@ -453,7 +470,7 @@ def __init__(self, f, b, clusters, guards):
453470
self.indices = extract_indices(f, self.dim, clusters)
454471

455472
def __repr__(self):
456-
return f"Descriptor[{self.f} -> {self.b}]"
473+
return f"Descriptor[{self.f} -> {self.b}], {self.guards}"
457474

458475
@property
459476
def size(self):
@@ -668,7 +685,7 @@ def make_mds(descriptors, prefix, sregistry):
668685
inspecting all buffers so that ModuloDimensions are reused when possible.
669686
"""
670687
mds = defaultdict(int)
671-
for v in descriptors.values():
688+
for v in chain.from_iterable(descriptors.values()):
672689
size = v.xd.symbolic_size
673690

674691
if size == 1:
@@ -684,7 +701,6 @@ def make_mds(descriptors, prefix, sregistry):
684701
# same strategy is also applied in clusters/algorithms/Stepper
685702
key = lambda i: -np.inf if i - p == 0 else (i - p) # noqa: B023
686703
indices = sorted(v.indices, key=key)
687-
v_mds = None
688704

689705
for k, i in enumerate(indices):
690706
k = (v.xd, i)
@@ -711,42 +727,43 @@ def init_buffers(descriptors, options):
711727
init_onwrite = options['buf-init-onwrite']
712728

713729
init = []
714-
for b, v in descriptors.items():
715-
f = v.f
716-
717-
if v.is_read:
718-
# Special case: avoid initialization in the case of double (or
719-
# multiple) buffering because it's completely unnecessary
720-
if v.is_double_buffering:
721-
continue
730+
for b, vb in descriptors.items():
731+
for v in vb:
732+
f = v.f
733+
734+
if v.is_read:
735+
# Special case: avoid initialization in the case of double (or
736+
# multiple) buffering because it's completely unnecessary
737+
if v.is_double_buffering:
738+
continue
722739

723-
lhs = b.indexify()._subs(v.xd, v.first_idx.b)
724-
rhs = f.indexify()._subs(v.dim, v.first_idx.f)
740+
lhs = b.indexify()._subs(v.xd, v.first_idx.b)
741+
rhs = f.indexify()._subs(v.dim, v.first_idx.f)
725742

726-
elif v.is_write and init_onwrite(f):
727-
lhs = b.indexify()
728-
rhs = S.Zero
743+
elif v.is_write and init_onwrite(f):
744+
lhs = b.indexify()
745+
rhs = S.Zero
729746

730-
else:
731-
continue
747+
else:
748+
continue
732749

733-
expr = Eq(lhs, rhs)
734-
expr = lower_exprs(expr)
750+
expr = Eq(lhs, rhs)
751+
expr = lower_exprs(expr)
735752

736-
ispace = v.write_to
753+
ispace = v.write_to
737754

738-
guards = {}
739-
guards[None] = GuardBound(v.dim.root.symbolic_min, v.dim.root.symbolic_max)
740-
if v.is_read:
741-
guards[v.xd] = GuardBound(0, v.first_idx.f)
755+
guards = {}
756+
guards[None] = GuardBound(v.dim.root.symbolic_min, v.dim.root.symbolic_max)
757+
if v.is_read:
758+
guards[v.xd] = GuardBound(0, v.first_idx.f)
742759

743-
properties = Properties()
744-
properties = properties.affine(ispace.itdims)
745-
properties = properties.parallelize(ispace.itdims)
760+
properties = Properties()
761+
properties = properties.affine(ispace.itdims)
762+
properties = properties.parallelize(ispace.itdims)
746763

747-
syncs = {None: [InitArray(None, b)]}
764+
syncs = {None: [InitArray(None, b)]}
748765

749-
init.append(Cluster(expr, ispace, guards, properties, syncs))
766+
init.append(Cluster(expr, ispace, guards, properties, syncs))
750767

751768
return init
752769

0 commit comments

Comments
 (0)