Skip to content

Commit 00a97bb

Browse files
committed
compiler: support mutli-buffering
1 parent c8ed5e1 commit 00a97bb

2 files changed

Lines changed: 105 additions & 84 deletions

File tree

devito/passes/clusters/asynchrony.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from collections import defaultdict
22

3-
from sympy import true
3+
from sympy import Mod, true
44

55
from devito.ir import (
66
Backward, Forward, GuardBoundNext, PrefetchUpdate, Queue, ReleaseLock, SyncArray,
@@ -78,7 +78,9 @@ def callback(self, clusters, prefix):
7878
d = self.key0(c0)
7979
if d is not dim:
8080
continue
81-
81+
print(c0.guards)
82+
if d in c0.guards and not c0.guards[d].has(Mod):
83+
continue
8284
protected = self._schedule_waitlocks(c0, d, clusters, locks, syncs)
8385
self._schedule_withlocks(c0, d, protected, locks, syncs)
8486

devito/passes/clusters/buffering.py

Lines changed: 101 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ def key(f):
116116
# Then we inject them into the Clusters. This involves creating the
117117
# initializing Clusters, and replacing the buffered Functions with the buffers
118118
clusters = InjectBuffers(mapper, sregistry, options).process(clusters)
119-
print(clusters)
119+
120120
return clusters
121121

122122

@@ -142,22 +142,20 @@ def callback(self, clusters, prefix):
142142
return clusters
143143
d = prefix[-1].dim
144144

145-
def key(f, *args):
146-
for (ff, _) in self.mapper:
147-
if f == ff:
148-
return True
149-
return False
145+
key = lambda f, *args: any(f == ff for ff, _ in self.mapper)
150146
bfmap = map_buffered_functions(clusters, key)
151147

152148
# A BufferDescriptor is a simple data structure storing additional
153149
# information about a buffer, harvested from the subset of `clusters`
154150
# that access it
155-
descriptors = {b: BufferDescriptor(f, b, bfmap[f], g)
156-
for (f, g), b in self.mapper.items()
157-
if f in bfmap}
151+
descriptors = {}
152+
for (f, g), b in self.mapper.items():
153+
if f in bfmap:
154+
descriptors.setdefault(b, []).append(BufferDescriptor(f, b, bfmap[f], g))
158155

159156
# Are we inside the right `d`?
160-
descriptors = {b: v for b, v in descriptors.items() if d in v.itdims}
157+
descriptors = {b: [vi for vi in v if d in vi.itdims]
158+
for b, v in descriptors.items()}
161159

162160
if not descriptors:
163161
return clusters
@@ -172,23 +170,28 @@ def key(f, *args):
172170
# Substitution rules to replace buffered Functions with buffers
173171
# E.g., `usave[time+1, x+1, y+1] -> ub0[t1, x+1, y+1]`
174172
subs = {}
175-
for b, v in descriptors.items():
176-
accesses = chain(*[c.scope[v.f] for c in v.clusters])
177-
index_mapper = {i: mds[(v.xd, i)] for i in v.indices}
178-
for a in accesses:
179-
subs[a.access] = b.indexed[[index_mapper.get(i, i) for i in a]]
173+
for b, vb in descriptors.items():
174+
for v in vb:
175+
for c in v.clusters:
176+
if c.guards.get(d) != v.guards.get(d):
177+
continue
178+
subs.setdefault(c, {})
179+
accesses = c.scope[v.f]
180+
index_mapper = {i: mds[(v.xd, i)] for i in v.indices}
181+
for a in accesses:
182+
subs[c][a.access] = b.indexed[[index_mapper.get(i, i) for i in a]]
180183

181184
processed = []
182185
for c in clusters:
183186
# If a buffer is read but never written, then we need to add
184187
# an Eq to step through the next slot
185188
# E.g., `ub[0, x] = usave[time+2, x]`
186-
for _, v in descriptors.items():
189+
for v in chain.from_iterable(descriptors.values()):
187190
if not v.is_readonly:
188191
continue
189192
if c not in v.firstread:
190193
continue
191-
if not c.guards.get(d) == v.guards.get(d):
194+
if c.guards.get(d) != v.guards.get(d):
192195
continue
193196

194197
idxf = v.last_idx[c]
@@ -219,7 +222,7 @@ def key(f, *args):
219222
processed.append(Cluster(expr, ispace, guards, properties, syncs))
220223

221224
# Substitute the buffered Functions with the buffers
222-
exprs = [uxreplace(e, subs) for e in c.exprs]
225+
exprs = [uxreplace(e, subs.get(c, {})) for e in c.exprs]
223226
ispace = c.ispace.augment(subiters)
224227
properties = c.properties.sequentialize(d)
225228
processed.append(
@@ -228,12 +231,12 @@ def key(f, *args):
228231

229232
# Append the copy-back if `c` is the last-write of some buffers
230233
# E.g., `usave[time+1, x] = ub[t1, x]`
231-
for _, v in descriptors.items():
234+
for v in chain.from_iterable(descriptors.values()):
232235
if v.is_readonly:
233236
continue
234237
if c not in v.lastwrite:
235238
continue
236-
if not c.guards.get(d) == v.guards.get(d):
239+
if c.guards.get(d) != v.guards.get(d):
237240
continue
238241

239242
idxf = v.last_idx[c]
@@ -269,36 +272,37 @@ def key(f, *args):
269272
return init + processed
270273

271274
def _optimize(self, clusters, descriptors):
272-
for b, v in descriptors.items():
273-
if v.is_writeonly:
274-
# `b` might be written by multiple, potentially mutually
275-
# exclusive, equations. For example, two equations that have or
276-
# will have complementary guards, hence only one will be
277-
# executed. In such a case, we can split the equations over
278-
# separate IterationSpaces
279-
key0 = lambda: Stamp()
280-
elif v.is_readonly:
281-
# `b` is read multiple times -- this could just be the case of
282-
# coupled equations, so we more cautiously perform a
283-
# "buffer-wise" splitting of the IterationSpaces (i.e., only
284-
# relevant if there are at least two read-only buffers)
285-
stamp = Stamp()
286-
key0 = lambda: stamp # noqa: B023
287-
else:
288-
continue
289-
290-
processed = []
291-
for c in clusters:
292-
if b not in c.functions:
293-
processed.append(c)
275+
for b, vb in descriptors.items():
276+
for v in vb:
277+
if v.is_writeonly:
278+
# `b` might be written by multiple, potentially mutually
279+
# exclusive, equations. For example, two equations that have or
280+
# will have complementary guards, hence only one will be
281+
# executed. In such a case, we can split the equations over
282+
# separate IterationSpaces
283+
key0 = lambda: Stamp()
284+
elif v.is_readonly:
285+
# `b` is read multiple times -- this could just be the case of
286+
# coupled equations, so we more cautiously perform a
287+
# "buffer-wise" splitting of the IterationSpaces (i.e., only
288+
# relevant if there are at least two read-only buffers)
289+
stamp = Stamp()
290+
key0 = lambda: stamp # noqa: B023
291+
else:
294292
continue
295293

296-
key1 = lambda d: not d._defines & v.dim._defines # noqa: B023
297-
dims = c.ispace.project(key1).itdims
298-
ispace = c.ispace.lift(dims, key0())
299-
processed.append(c.rebuild(ispace=ispace))
294+
processed = []
295+
for c in clusters:
296+
if b not in c.functions:
297+
processed.append(c)
298+
continue
299+
300+
key1 = lambda d: not d._defines & v.dim._defines # noqa: B023
301+
dims = c.ispace.project(key1).itdims
302+
ispace = c.ispace.lift(dims, key0())
303+
processed.append(c.rebuild(ispace=ispace))
300304

301-
clusters = processed
305+
clusters = processed
302306

303307
return clusters
304308

@@ -314,11 +318,11 @@ def _reuse(self, init, clusters, descriptors):
314318
cbk = lambda v: v
315319

316320
mapper = as_mapper(descriptors, key=lambda b: b._signature)
317-
mapper = {k: cbk(v) for k, v in mapper.items() if cbk(v)}
321+
mapper = {k: [cbk(v) for v in vb if cbk(v)] for k, vb in mapper.items()}
318322

319323
subs = {}
320324
drop = set()
321-
for reusable in mapper.values():
325+
for reusable in chain.from_iterable(mapper.values()):
322326
retain = reusable.pop(0)
323327
drop.update(reusable)
324328

@@ -365,18 +369,24 @@ def generate_buffers(clusters, key, sregistry, options, **kwargs):
365369
# {buffered Function -> Buffer}
366370
xds = {}
367371
mapper = {}
372+
extras = {}
368373
for f, clusters in bfmap.items():
369374
for k, ck in groupby(clusters, key=lambda c: c.guards):
375+
ck = list(ck)
370376
exprs = flatten(c.exprs for c in ck)
371377

372378
bdims = key(f, exprs)
373379

374380
dims = [d for d in f.dimensions if d not in bdims]
375381
if len(dims) != 1:
376382
raise CompilationError(f"Unsupported multi-dimensional `buffering` "
377-
f"required by `{f}`")
383+
f"required by `{f}`")
378384
dim = dims.pop()
379385

386+
if not dim._defines & k.keys():
387+
extras.setdefault(f, []).append(k)
388+
continue
389+
380390
if is_buffering(exprs):
381391
# Multi-level buffering
382392
# NOTE: a bit rudimentary (we could go through the exprs one by one
@@ -386,13 +396,15 @@ def generate_buffers(clusters, key, sregistry, options, **kwargs):
386396
buffer, = buffers
387397
xd = buffer.indices[dim]
388398
else:
389-
size = infer_buffer_size(f, dim, clusters)
399+
400+
size = infer_buffer_size(f, dim, ck)
390401

391402
if async_degree is not None:
392403
if async_degree < size:
393404
warning(
394405
'Ignoring provided asynchronous degree as it would be '
395-
f'too small for the required buffer (provided {async_degree}, '
406+
'too small for the required buffer'
407+
f' (provided {async_degree}, '
396408
f'but need at least {size} for `{f.name}`)'
397409
)
398410
else:
@@ -421,6 +433,13 @@ def generate_buffers(clusters, key, sregistry, options, **kwargs):
421433
padding=padding, grid=f.grid, halo=f.halo,
422434
space='mapped', mapped=f, f=f)
423435

436+
for f, k in extras.items():
437+
for (ff, kk) in dict(mapper):
438+
if f == ff:
439+
for ki in k:
440+
if ki.keys() & set(mapper[(ff, kk)].dimensions):
441+
mapper[(f, ki)] = mapper[(ff, kk)]
442+
424443
return mapper
425444

426445

@@ -453,7 +472,7 @@ def __init__(self, f, b, clusters, guards):
453472
self.indices = extract_indices(f, self.dim, clusters)
454473

455474
def __repr__(self):
456-
return f"Descriptor[{self.f} -> {self.b}]"
475+
return f"Descriptor[{self.f} -> {self.b}], {self.guards}"
457476

458477
@property
459478
def size(self):
@@ -668,7 +687,7 @@ def make_mds(descriptors, prefix, sregistry):
668687
inspecting all buffers so that ModuloDimensions are reused when possible.
669688
"""
670689
mds = defaultdict(int)
671-
for v in descriptors.values():
690+
for v in chain.from_iterable(descriptors.values()):
672691
size = v.xd.symbolic_size
673692

674693
if size == 1:
@@ -684,7 +703,6 @@ def make_mds(descriptors, prefix, sregistry):
684703
# same strategy is also applied in clusters/algorithms/Stepper
685704
key = lambda i: -np.inf if i - p == 0 else (i - p) # noqa: B023
686705
indices = sorted(v.indices, key=key)
687-
v_mds = None
688706

689707
for k, i in enumerate(indices):
690708
k = (v.xd, i)
@@ -711,42 +729,43 @@ def init_buffers(descriptors, options):
711729
init_onwrite = options['buf-init-onwrite']
712730

713731
init = []
714-
for b, v in descriptors.items():
715-
f = v.f
716-
717-
if v.is_read:
718-
# Special case: avoid initialization in the case of double (or
719-
# multiple) buffering because it's completely unnecessary
720-
if v.is_double_buffering:
721-
continue
732+
for b, vb in descriptors.items():
733+
for v in vb:
734+
f = v.f
735+
736+
if v.is_read:
737+
# Special case: avoid initialization in the case of double (or
738+
# multiple) buffering because it's completely unnecessary
739+
if v.is_double_buffering:
740+
continue
722741

723-
lhs = b.indexify()._subs(v.xd, v.first_idx.b)
724-
rhs = f.indexify()._subs(v.dim, v.first_idx.f)
742+
lhs = b.indexify()._subs(v.xd, v.first_idx.b)
743+
rhs = f.indexify()._subs(v.dim, v.first_idx.f)
725744

726-
elif v.is_write and init_onwrite(f):
727-
lhs = b.indexify()
728-
rhs = S.Zero
745+
elif v.is_write and init_onwrite(f):
746+
lhs = b.indexify()
747+
rhs = S.Zero
729748

730-
else:
731-
continue
749+
else:
750+
continue
732751

733-
expr = Eq(lhs, rhs)
734-
expr = lower_exprs(expr)
752+
expr = Eq(lhs, rhs)
753+
expr = lower_exprs(expr)
735754

736-
ispace = v.write_to
755+
ispace = v.write_to
737756

738-
guards = {}
739-
guards[None] = GuardBound(v.dim.root.symbolic_min, v.dim.root.symbolic_max)
740-
if v.is_read:
741-
guards[v.xd] = GuardBound(0, v.first_idx.f)
757+
guards = {}
758+
guards[None] = GuardBound(v.dim.root.symbolic_min, v.dim.root.symbolic_max)
759+
if v.is_read:
760+
guards[v.xd] = GuardBound(0, v.first_idx.f)
742761

743-
properties = Properties()
744-
properties = properties.affine(ispace.itdims)
745-
properties = properties.parallelize(ispace.itdims)
762+
properties = Properties()
763+
properties = properties.affine(ispace.itdims)
764+
properties = properties.parallelize(ispace.itdims)
746765

747-
syncs = {None: [InitArray(None, b)]}
766+
syncs = {None: [InitArray(None, b)]}
748767

749-
init.append(Cluster(expr, ispace, guards, properties, syncs))
768+
init.append(Cluster(expr, ispace, guards, properties, syncs))
750769

751770
return init
752771

0 commit comments

Comments
 (0)