Skip to content

Commit 02b311b

Browse files
committed
compiler: fix various corner case of multi buffering
1 parent 04251ff commit 02b311b

6 files changed

Lines changed: 105 additions & 55 deletions

File tree

devito/ir/equations/equation.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,14 @@ def __new__(cls, *args, **kwargs):
237237
cond = d.relation(cond, GuardFactor(d))
238238
conditionals[d] = cond
239239

240+
# Replace the ConditionalDimensions in `expr`
241+
for d, cond in conditionals.items():
242+
# Replace dimension with index
243+
index = d.index
244+
index = index - relational_min(cond, d.parent)
245+
shift = relational_shift(cond, d.parent)
246+
expr = uxreplace(expr, {d: IntDiv(index, d.symbolic_factor) + shift})
247+
240248
# Merge conditionals when possible. E.g if we have an implicit_dim
241249
# and there is a dimension with the same parent, we ca merged
242250
# its condition
@@ -247,19 +255,13 @@ def __new__(cls, *args, **kwargs):
247255
if cd.parent == d.parent and cd != d:
248256
cond = conditionals.pop(d)
249257
mode = cd.relation and d.relation
250-
conditionals[cd] = mode(cond, conditionals[cd])
258+
if issubclass(mode, sympy.Or):
259+
conditionals[d] = cond
260+
conditionals.pop(cd)
261+
else:
262+
conditionals[cd] = mode(cond, conditionals[cd])
251263
break
252264

253-
conditionals = frozendict(conditionals)
254-
255-
# Replace the ConditionalDimensions in `expr`
256-
for d, cond in conditionals.items():
257-
# Replace dimension with index
258-
index = d.index
259-
index = index - relational_min(cond, d.parent)
260-
shift = relational_shift(cond, d.parent)
261-
expr = uxreplace(expr, {d: IntDiv(index, d.symbolic_factor) + shift})
262-
263265
# Lower all Differentiable operations into SymPy operations
264266
rhs = diff2sympy(expr.rhs)
265267

devito/passes/clusters/asynchrony.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ def memcpy_prefetch(clusters, key0, sregistry):
193193
if c.properties.is_prefetchable(d._defines):
194194
_actions_from_update_memcpy(c, d, clusters, actions, sregistry)
195195
elif d.is_Custom and is_integer(c.ispace[d].size):
196-
_actions_from_init(c, d, actions)
196+
_actions_from_init(c, d, clusters, actions)
197197

198198
# Attach the computed Actions
199199
processed = []
@@ -214,7 +214,7 @@ def memcpy_prefetch(clusters, key0, sregistry):
214214
return processed
215215

216216

217-
def _actions_from_init(c, d, actions):
217+
def _actions_from_init(c, d, clusters, actions):
218218
e = c.exprs[0]
219219
function = e.rhs.function
220220
target = e.lhs.function

devito/passes/clusters/buffering.py

Lines changed: 46 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
from collections import defaultdict, namedtuple
22
from functools import cached_property
3-
from itertools import chain
3+
from itertools import chain, groupby
44

55
import numpy as np
6-
from sympy import S, simplify
6+
from sympy import Mod, S, simplify
77

88
from devito.exceptions import CompilationError
99
from devito.ir import (
@@ -116,7 +116,7 @@ def key(f):
116116
# Then we inject them into the Clusters. This involves creating the
117117
# initializing Clusters, and replacing the buffered Functions with the buffers
118118
clusters = InjectBuffers(mapper, sregistry, options).process(clusters)
119-
119+
print(clusters)
120120
return clusters
121121

122122

@@ -142,14 +142,18 @@ def callback(self, clusters, prefix):
142142
return clusters
143143
d = prefix[-1].dim
144144

145-
key = lambda f, *args: f in self.mapper
145+
def key(f, *args):
146+
for (ff, _) in self.mapper:
147+
if f == ff:
148+
return True
149+
return False
146150
bfmap = map_buffered_functions(clusters, key)
147151

148152
# A BufferDescriptor is a simple data structure storing additional
149153
# information about a buffer, harvested from the subset of `clusters`
150154
# that access it
151-
descriptors = {b: BufferDescriptor(f, b, bfmap[f])
152-
for f, b in self.mapper.items()
155+
descriptors = {b: BufferDescriptor(f, b, bfmap[f], g)
156+
for (f, g), b in self.mapper.items()
153157
if f in bfmap}
154158

155159
# Are we inside the right `d`?
@@ -184,6 +188,8 @@ def callback(self, clusters, prefix):
184188
continue
185189
if c not in v.firstread:
186190
continue
191+
if not c.guards.get(d) == v.guards.get(d):
192+
continue
187193

188194
idxf = v.last_idx[c]
189195
idxb = mds[(v.xd, idxf)]
@@ -203,7 +209,7 @@ def callback(self, clusters, prefix):
203209
guards = c.guards
204210

205211
properties = c.properties.sequentialize(d)
206-
if not isinstance(d, BufferDimension):
212+
if not isinstance(d, BufferDimension) and c.guards[d].has(Mod):
207213
properties = properties.prefetchable(d)
208214
# `c` may be a HaloTouch Cluster, so with no vision of the `bdims`
209215
properties = properties.parallelize(v.bdims).affine(v.bdims)
@@ -227,6 +233,8 @@ def callback(self, clusters, prefix):
227233
continue
228234
if c not in v.lastwrite:
229235
continue
236+
if not c.guards.get(d) == v.guards.get(d):
237+
continue
230238

231239
idxf = v.last_idx[c]
232240
idxb = mds[(v.xd, idxf)]
@@ -358,15 +366,16 @@ def generate_buffers(clusters, key, sregistry, options, **kwargs):
358366
xds = {}
359367
mapper = {}
360368
for f, clusters in bfmap.items():
361-
exprs = flatten(c.exprs for c in clusters)
369+
for k, ck in groupby(clusters, key=lambda c: c.guards):
370+
exprs = flatten(c.exprs for c in ck)
362371

363-
bdims = key(f, exprs)
372+
bdims = key(f, exprs)
364373

365-
dims = [d for d in f.dimensions if d not in bdims]
366-
if len(dims) != 1:
367-
raise CompilationError(f"Unsupported multi-dimensional `buffering` "
368-
f"required by `{f}`")
369-
dim = dims.pop()
374+
dims = [d for d in f.dimensions if d not in bdims]
375+
if len(dims) != 1:
376+
raise CompilationError(f"Unsupported multi-dimensional `buffering` "
377+
f"required by `{f}`")
378+
dim = dims.pop()
370379

371380
if is_buffering(exprs):
372381
# Multi-level buffering
@@ -391,25 +400,25 @@ def generate_buffers(clusters, key, sregistry, options, **kwargs):
391400
else:
392401
size = async_degree
393402

394-
# A special CustomDimension to use in place of `dim` in the buffer
395-
try:
396-
xd = xds[(dim, size)]
397-
except KeyError:
398-
name = sregistry.make_name(prefix='db')
399-
xd = xds[(dim, size)] = BufferDimension(name, 0, size-1, size, dim)
400-
extra_kwargs = {}
401-
402-
# The buffer dimensions
403-
dimensions = list(f.dimensions)
404-
assert dim in f.dimensions
405-
dimensions[dimensions.index(dim)] = xd
406-
407-
# Finally create the actual buffer
408-
cls = callback or Array
409-
name = sregistry.make_name(prefix=f'{f.name}b')
410-
mapper[f] = cls(name=name, dimensions=dimensions, dtype=f.dtype,
411-
grid=f.grid, halo=f.halo,
412-
space='mapped', mapped=f, f=f, **extra_kwargs)
403+
# A special CustomDimension to use in place of `dim` in the buffer
404+
try:
405+
xd = xds[(dim, size)]
406+
except KeyError:
407+
name = sregistry.make_name(prefix='db')
408+
xd = xds[(dim, size)] = BufferDimension(name, 0, size-1, size, dim)
409+
extra_kwargs = {}
410+
411+
# The buffer dimensions
412+
dimensions = list(f.dimensions)
413+
assert dim in f.dimensions
414+
dimensions[dimensions.index(dim)] = xd
415+
416+
# Finally create the actual buffer
417+
cls = callback or Array
418+
name = sregistry.make_name(prefix=f'{f.name}b')
419+
mapper[(f, k)] = cls(name=name, dimensions=dimensions, dtype=f.dtype,
420+
grid=f.grid, halo=f.halo,
421+
space='mapped', mapped=f, f=f, **extra_kwargs)
413422

414423
return mapper
415424

@@ -429,10 +438,11 @@ def map_buffered_functions(clusters, key):
429438

430439
class BufferDescriptor:
431440

432-
def __init__(self, f, b, clusters):
441+
def __init__(self, f, b, clusters, guards):
433442
self.f = f
434443
self.b = b
435444
self.clusters = clusters
445+
self.guards = guards
436446

437447
self.xd, = b.find(BufferDimension)
438448
self.bdims = tuple(d for d in b.dimensions if d is not self.xd)
@@ -673,8 +683,9 @@ def make_mds(descriptors, prefix, sregistry):
673683
# same strategy is also applied in clusters/algorithms/Stepper
674684
key = lambda i: -np.inf if i - p == 0 else (i - p) # noqa: B023
675685
indices = sorted(v.indices, key=key)
686+
v_mds = None
676687

677-
for i in indices:
688+
for k, i in enumerate(indices):
678689
k = (v.xd, i)
679690
if k in mds:
680691
continue

devito/symbolics/extended_sympy.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
)
2020
from devito.types import Symbol
2121
from devito.types.basic import Basic
22-
from devito.types.relational import Ge
2322

2423
__all__ = ['CondEq', 'CondNe', 'BitwiseNot', 'BitwiseXor', 'BitwiseAnd', # noqa
2524
'LeftShift', 'RightShift', 'IntDiv', 'Terminal', 'CallFromPointer',
@@ -49,11 +48,6 @@ def canonical(self):
4948
def negated(self):
5049
return CondNe(*self.args, evaluate=False)
5150

52-
@property
53-
def _as_min(self):
54-
from devito.symbolics.extended_dtypes import INT
55-
return INT(Ge(*self.args))
56-
5751

5852
class CondNe(sympy.Ne):
5953

devito/types/relational.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -320,4 +320,5 @@ def _(expr, s):
320320
def _(expr, s):
321321
if isinstance(expr.lhs, sympy.Mod):
322322
return 0
323-
return expr._as_min
323+
from devito.symbolics.extended_dtypes import INT
324+
return INT(Ge(*expr.args))

tests/test_buffering.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -818,3 +818,45 @@ def test_multi_cond_v1():
818818
for i in range(nt-1):
819819
assert np.allclose(f.data[i], i*2)
820820
assert np.allclose(f.data[nt-1], ntmod - 2)
821+
822+
823+
@pytest.mark.parametrize("factor", [1, 2, 3])
824+
def test_buffering_multi_cond(factor):
825+
grid = Grid((16, 16))
826+
827+
nt = 5
828+
ntmod = (nt - 1) * factor + 1
829+
830+
ct0 = ConditionalDimension(name="ct0", parent=grid.time_dim, factor=factor,
831+
relation=Or)
832+
f = TimeFunction(grid=grid, name='f', time_order=0, space_order=0,
833+
time_dim=ct0, save=nt)
834+
T = TimeFunction(grid=grid, name='T', time_order=0, space_order=0)
835+
836+
eqs = []
837+
eqs.append(Eq(T, grid.time_dim))
838+
839+
# conditional dimension for the last sample in the operator
840+
ctend = ConditionalDimension(name="ctend", parent=grid.time_dim,
841+
condition=CondEq(grid.time_dim, ntmod - 2),
842+
relation=Or)
843+
844+
eqs.append(Eq(f, T)) # this to save times from 0 to nt - 2
845+
# this to save the last time sample nt - 1
846+
eqs.append(Eq(f.forward, T+1, implicit_dims=ctend))
847+
848+
# run operator with serialization
849+
op = Operator(eqs, opt='buffering')
850+
op.apply(time_m=0, time_M=ntmod-2)
851+
852+
# Now run backward as well with buffering
853+
854+
f_all = TimeFunction(grid=grid, name='f_all', time_order=0,
855+
space_order=0, time_dim=ct0, save=nt)
856+
857+
eq_all = [Eq(f_all, f)]
858+
eq_all.append(Eq(f_all.forward, f.forward, implicit_dims=ctend))
859+
op_all = Operator(eq_all, opt='buffering')
860+
op_all.apply(time_m=0, time_M=ntmod-2)
861+
862+
assert np.allclose(f_all.data[:, 11, 11], factor * np.arange(nt))

0 commit comments

Comments
 (0)