Skip to content

Commit a453d38

Browse files
committed
compiler: support mutli-buffering
1 parent cd1f1e4 commit a453d38

17 files changed

Lines changed: 690 additions & 248 deletions

File tree

devito/arch/archinfo.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -494,13 +494,15 @@ def parse_product_arch():
494494
return None
495495

496496

497+
device_vars = (
498+
'CUDA_VISIBLE_DEVICES',
499+
'NVIDIA_VISIBLE_DEVICES',
500+
'ROCR_VISIBLE_DEVICES',
501+
'HIP_VISIBLE_DEVICES'
502+
)
503+
504+
497505
def get_visible_devices():
498-
device_vars = (
499-
'CUDA_VISIBLE_DEVICES',
500-
'NVIDIA_VISIBLE_DEVICES',
501-
'ROCR_VISIBLE_DEVICES',
502-
'HIP_VISIBLE_DEVICES'
503-
)
504506
for v in device_vars:
505507
try:
506508
return v, tuple(int(i) for i in os.environ[v].split(','))

devito/ir/clusters/algorithms.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -259,10 +259,17 @@ def guard(clusters):
259259
# Separate out the indirect ConditionalDimensions, which only serve
260260
# the purpose of protecting from OOB accesses
261261
cds = [d for d in cds if not d.indirect]
262+
modes = [cd.relation for cd in cds]
263+
if modes.count('strict') > 1:
264+
raise CompilationError("Only one `strict` condition"
265+
"can be used in an equation")
266+
elif 'strict' in modes:
267+
mode = 'strict'
268+
else:
269+
mode = sympy.And if sympy.And in modes else sympy.Or
262270

263271
# Chain together all `cds` conditions from all expressions in `c`
264272
guards = {}
265-
mode = sympy.Or
266273
for cd in cds:
267274
# `BOTTOM` parent implies a guard that lives outside of
268275
# any iteration space, which corresponds to the placeholder None
@@ -279,7 +286,6 @@ def guard(clusters):
279286

280287
# Pull `cd` from any expr
281288
condition = guards.setdefault(k, [])
282-
mode = mode and cd.relation
283289
for e in exprs:
284290
try:
285291
condition.append(e.conditionals[cd])
@@ -296,7 +302,10 @@ def guard(clusters):
296302

297303
# Combination `mode` is And by default.
298304
# If all conditions are Or then Or combination `mode` is used.
299-
guards = {d: mode(*v, evaluate=False) for d, v in guards.items()}
305+
if mode == 'strict':
306+
guards = {d: v[0] for d, v in guards.items()}
307+
else:
308+
guards = {d: mode(*v, evaluate=False) for d, v in guards.items()}
300309

301310
# Construct a guarded Cluster
302311
processed.append(c.rebuild(exprs=exprs, guards=guards))

devito/ir/equations/algorithms.py

Lines changed: 53 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,19 @@
22
from functools import singledispatch
33

44
from devito.data.allocators import DataReference
5+
from devito.finite_differences.differentiable import diff2sympy
6+
from devito.ir.support import GuardFactor
57
from devito.logger import warning
68
from devito.symbolics import (
7-
retrieve_dimensions, retrieve_functions, retrieve_indexed, uxreplace
9+
IntDiv, retrieve_dimensions, retrieve_functions, retrieve_indexed, uxreplace
810
)
911
from devito.tools import (
1012
Ordering, as_tuple, filter_ordered, filter_sorted, flatten, frozendict
1113
)
12-
from devito.types import ConditionalDimension, Dimension, Eq, IgnoreDimSort, SubDimension
14+
from devito.types import (
15+
ConditionalDimension, Dimension, Eq, IgnoreDimSort, SubDimension, relational_min,
16+
relational_shift
17+
)
1318
from devito.types.array import Array
1419
from devito.types.basic import AbstractFunction
1520
from devito.types.dimension import MultiSubDimension, Thickness
@@ -339,3 +344,49 @@ def _(d, mapper, rebuilt, sregistry):
339344
kwargs['functions'] = functions
340345

341346
mapper[d] = d._rebuild(**kwargs)
347+
348+
349+
def generate_conditionals(expr, input_expr, ordering):
350+
"""
351+
Generate the conditionals for the given expression,
352+
based on the input expression and the ordering of dimensions.
353+
"""
354+
# Construct the conditionals
355+
conditionals = {}
356+
for d in ordering:
357+
if not d.is_Conditional:
358+
continue
359+
if d.condition is None:
360+
conditionals[d] = GuardFactor(d)
361+
else:
362+
cond = diff2sympy(lower_exprs(d.condition))
363+
if d._factor is not None:
364+
cond = d.relation(cond, GuardFactor(d))
365+
conditionals[d] = cond
366+
367+
# Merge conditionals when possible. E.g if we have an implicit_dim
368+
# and there is a dimension with the same parent, we can merge
369+
# their conditions
370+
for d in input_expr.implicit_dims:
371+
if d not in conditionals:
372+
continue
373+
for cd in list(conditionals):
374+
if cd.parent == d.parent and cd != d:
375+
cond = conditionals.pop(d)
376+
if d.relation == 'strict':
377+
conditionals[cd] = conditionals[d] = cond
378+
else:
379+
mode = cd.relation and d.relation
380+
conditionals[cd] = mode(cond, conditionals[cd])
381+
break
382+
383+
# Replace the ConditionalDimensions in `expr`
384+
for d, cond in conditionals.items():
385+
# Replace dimension with index
386+
index = d.index
387+
if d.condition is not None and d in expr.free_symbols:
388+
index = index - relational_min(cond, d.parent)
389+
shift = relational_shift(cond, d.parent)
390+
expr = uxreplace(expr, {d: IntDiv(index, d.symbolic_factor) + shift})
391+
392+
return expr, conditionals

devito/ir/equations/equation.py

Lines changed: 5 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,15 @@
55
import sympy
66

77
from devito.finite_differences.differentiable import diff2sympy
8-
from devito.ir.equations.algorithms import dimension_sort, lower_exprs
8+
from devito.ir.equations.algorithms import dimension_sort, generate_conditionals
99
from devito.ir.support import (
10-
GuardFactor, Interval, IntervalGroup, IterationSpace, Stencil, detect_accesses
10+
Interval, IntervalGroup, IterationSpace, Stencil, detect_accesses
1111
)
12-
from devito.symbolics import IntDiv, limits_mapper, retrieve_accesses, uxreplace
12+
from devito.symbolics import limits_mapper, retrieve_accesses
1313
from devito.tools import (
1414
Pickable, Tag, as_hashable, filter_sorted, frozendict, reuse_if_unchanged
1515
)
16-
from devito.symbolics import IntDiv, limits_mapper, uxreplace
17-
from devito.tools import Pickable, Tag, frozendict
18-
from devito.types import (
19-
Eq, Inc, ReduceMax, ReduceMin, ReduceMinMax, relational_min, relational_shift
20-
)
16+
from devito.types import Eq, Inc, ReduceMax, ReduceMin, ReduceMinMax
2117

2218
__all__ = [
2319
'ClusterizedEq',
@@ -296,44 +292,7 @@ def __new__(cls, *args, **kwargs):
296292
relations=ordering.relations, mode='partial')
297293
ispace = IterationSpace(intervals, iterators)
298294

299-
# Construct the conditionals
300-
conditionals = {}
301-
for d in ordering:
302-
if not d.is_Conditional:
303-
continue
304-
if d.condition is None:
305-
conditionals[d] = GuardFactor(d)
306-
else:
307-
cond = diff2sympy(lower_exprs(d.condition))
308-
if d._factor is not None:
309-
cond = d.relation(cond, GuardFactor(d))
310-
conditionals[d] = cond
311-
312-
# Replace the ConditionalDimensions in `expr`
313-
for d, cond in conditionals.items():
314-
# Replace dimension with index
315-
index = d.index
316-
index = index - relational_min(cond, d.parent)
317-
shift = relational_shift(cond, d.parent)
318-
expr = uxreplace(expr, {d: IntDiv(index, d.symbolic_factor) + shift})
319-
320-
# Merge conditionals when possible. E.g if we have an implicit_dim
321-
# and there is a dimension with the same parent, we ca merged
322-
# its condition
323-
for d in input_expr.implicit_dims:
324-
if d not in conditionals:
325-
continue
326-
for cd in dict(conditionals):
327-
if cd.parent == d.parent and cd != d:
328-
cond = conditionals.pop(d)
329-
mode = cd.relation and d.relation
330-
if issubclass(mode, sympy.Or):
331-
conditionals[d] = cond
332-
conditionals.pop(cd)
333-
else:
334-
conditionals[cd] = mode(cond, conditionals[cd])
335-
break
336-
295+
expr, conditionals = generate_conditionals(expr, input_expr, ordering)
337296
# Lower all Differentiable operations into SymPy operations
338297
rhs = diff2sympy(expr.rhs)
339298

devito/ir/support/guards.py

Lines changed: 58 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@
1313
from sympy.logic.boolalg import BooleanFunction
1414

1515
from devito.ir.support.space import Forward, IterationDirection
16-
from devito.symbolics import CondEq, CondNe, search
16+
from devito.symbolics import CondEq, CondNe, IntDiv, search
17+
from devito.symbolics.manipulation import _uxreplace_handle, _uxreplace_registry
1718
from devito.tools import Pickable, as_tuple, frozendict, split
1819
from devito.types import Dimension, LocalObject
1920

@@ -31,6 +32,34 @@
3132
]
3233

3334

35+
@singledispatch
36+
def bound_index(expr, dim, dir):
37+
if dir == Forward:
38+
return expr._subs(dim, dim + 1)
39+
else:
40+
return expr._subs(dim, dim - 1)
41+
42+
43+
@bound_index.register(Expr)
44+
def _(expr, dim, dir):
45+
if not expr.args:
46+
if dir == Forward:
47+
return expr._subs(dim, dim + 1)
48+
else:
49+
return expr._subs(dim, dim - 1)
50+
return expr.func(*[bound_index(a, dim, dir) for a in expr.args])
51+
52+
53+
@bound_index.register(IntDiv)
54+
def _(expr, dim, dir):
55+
v = dim.symbolic_factor
56+
p0 = dim.root
57+
if dir == Forward:
58+
return Mul((((p0 + 1) + v - 1) / v), v, evaluate=False)
59+
else:
60+
return (p0 - 1) - abs(p0 - 1) % v
61+
62+
3463
class AbstractGuard:
3564
pass
3665

@@ -138,37 +167,29 @@ class BaseGuardBoundNext(Guard, Pickable):
138167
given `direction`.
139168
"""
140169

141-
__rargs__ = ('d', 'direction')
170+
__rargs__ = ('d', 'index', 'direction')
171+
__rkwargs__ = ('d_min', 'd_max')
142172

143-
def __new__(cls, d, direction, **kwargs):
173+
def __new__(cls, d, index, direction,
174+
d_min=None, d_max=None, **kwargs):
144175
assert isinstance(d, Dimension)
145176
assert isinstance(direction, IterationDirection)
146177

147-
if direction == Forward:
148-
p0 = d.root
149-
p1 = d.root.symbolic_max
178+
# Always take the next index in the iteration direction
179+
next_index = bound_index(index, d, direction)
150180

151-
if d.is_Conditional:
152-
v = d.symbolic_factor
153-
# Round `p0 + 1` up to the nearest multiple of `v`
154-
p0 = Mul((((p0 + 1) + v - 1) / v), v, evaluate=False)
155-
else:
156-
p0 = p0 + 1
181+
# The direction might be forward but accessing c - d
182+
# making the access backward w.r.t
183+
# Update direction according to access direction for valid guard
184+
if index.has(-d):
185+
direction = -direction
157186

187+
if direction == Forward:
188+
p0 = next_index
189+
p1 = d_max or d.root.symbolic_max
158190
else:
159-
p0 = d.root.symbolic_min
160-
p1 = d.root
161-
162-
if d.is_Conditional:
163-
v = d.symbolic_factor
164-
# Round `p1 - 1` down to the nearest sub-multiple of `v`
165-
# NOTE: we use ABS to make sure we handle negative values properly.
166-
# Once `p1 - 1` is negative (e.g. `iteration=time - 1` and `time=0`),
167-
# as long as we get a negative number, rather than 0 and even if it's
168-
# not `-v`, we're good
169-
p1 = (p1 - 1) - abs(p1 - 1) % v
170-
else:
171-
p1 = p1 - 1
191+
p0 = d_min if d_min is not None else d.root.symbolic_min
192+
p1 = next_index
172193

173194
try:
174195
if cls.__base__._eval_relation(p0, p1) is true:
@@ -180,12 +201,15 @@ def __new__(cls, d, direction, **kwargs):
180201

181202
obj.d = d
182203
obj.direction = direction
204+
obj.index = index
205+
obj.d_min = d_min
206+
obj.d_max = d_max
183207

184208
return obj
185209

186210
@property
187211
def _args_rebuild(self):
188-
return (self.d, self.direction)
212+
return (self.d, self.index, self.direction)
189213

190214

191215
class GuardBoundNextLe(BaseGuardBoundNext, Le):
@@ -544,3 +568,11 @@ def pairwise_or(*guards):
544568
pass
545569

546570
return guard
571+
572+
573+
_uxreplace_registry.register(BaseGuardBoundNext)
574+
575+
576+
@_uxreplace_handle.register(BaseGuardBoundNext)
577+
def _(expr, args, kwargs):
578+
return expr.func(expr.d, expr.index, expr.direction, **kwargs)

devito/ir/support/space.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -628,6 +628,14 @@ def __repr__(self):
628628
def __hash__(self):
629629
return hash(self._name)
630630

631+
def __neg__(self):
632+
if self._name == '++':
633+
return Backward
634+
elif self._name == '--':
635+
return Forward
636+
else:
637+
return Any
638+
631639

632640
Forward = IterationDirection('++')
633641
"""Forward iteration direction ('++')."""

devito/parameters.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,13 @@ def __init__(self, params):
298298
self.params = params
299299

300300
def __enter__(self):
301+
# Prevent having multiple conflicting device vars, e.g
302+
# switching CUDA_VISIBLE_DEVICES but having NVIDIA_VISIBLE_DEVICES set.
303+
from devito.arch.archinfo import device_vars
304+
if any(k in device_vars for k in self.params):
305+
for dk in device_vars:
306+
os.environ.pop(dk, None)
307+
301308
for k, v in self.params.items():
302309
if v is None:
303310
os.environ.pop(k, None)

0 commit comments

Comments
 (0)