Skip to content

Commit f67a3c1

Browse files
committed
compiler: Revamp to enable blocking before CIRE
1 parent 3de03ca commit f67a3c1

14 files changed

Lines changed: 481 additions & 297 deletions

File tree

devito/core/cpu.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
from devito.operator.operator import rcompile
66
from devito.passes import stream_dimensions
77
from devito.passes.clusters import (
8-
Lift, blocking, buffering, cire, cse, factorize, fission, fuse, optimize_hyperplanes,
9-
optimize_pows
8+
Lift, apply_par_tiles, blocking, buffering, cire, cse, factorize, fission, fuse,
9+
optimize_hyperplanes, optimize_pows
1010
)
1111
from devito.passes.equations import collect_derivatives
1212
from devito.passes.iet import (
@@ -67,6 +67,7 @@ def _normalize_kwargs(cls, **kwargs):
6767
reduce=oo.pop('par-tile-reduce', None))
6868

6969
# CIRE
70+
o['cire-block-temps'] = oo.pop('cire-block-temps', cls.CIRE_BLOCK_TEMPS)
7071
o['min-storage'] = oo.pop('min-storage', False)
7172
o['cire-rotate'] = oo.pop('cire-rotate', False)
7273
o['cire-maxpar'] = oo.pop('cire-maxpar', False)
@@ -198,6 +199,9 @@ def _specialize_clusters(cls, clusters, **kwargs):
198199
if options['blocklazy']:
199200
clusters = blocking(clusters, sregistry, options)
200201

202+
# Unfold the `par-tile`s, if any
203+
clusters = apply_par_tiles(clusters, **kwargs)
204+
201205
return clusters
202206

203207
@classmethod

devito/core/gpu.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77
from devito.operator.operator import rcompile
88
from devito.passes import is_on_device, stream_dimensions
99
from devito.passes.clusters import (
10-
Lift, blocking, buffering, cire, cse, factorize, fission, fuse, memcpy_prefetch,
11-
optimize_pows, tasking
10+
Lift, apply_par_tiles, blocking, buffering, cire, cse, factorize, fission, fuse,
11+
memcpy_prefetch, optimize_pows, tasking
1212
)
1313
from devito.passes.equations import collect_derivatives
1414
from devito.passes.iet import (
@@ -38,7 +38,9 @@
3838

3939
class DeviceOperatorMixin:
4040

41+
# Overrides the default values in the main Operator class
4142
BLOCK_LEVELS = 0
43+
CIRE_BLOCK_TEMPS = False
4244
MPI_MODES = (True, 'basic',)
4345

4446
GPU_FIT = 'all-fallback'
@@ -76,6 +78,7 @@ def _normalize_kwargs(cls, **kwargs):
7678
o['skewing'] = oo.pop('skewing', False)
7779

7880
# CIRE
81+
o['cire-block-temps'] = oo.pop('cire-block-temps', cls.CIRE_BLOCK_TEMPS)
7982
o['min-storage'] = False
8083
o['cire-rotate'] = False
8184
o['cire-maxpar'] = oo.pop('cire-maxpar', 'basic')
@@ -239,6 +242,9 @@ def _specialize_clusters(cls, clusters, **kwargs):
239242
if options['blocklazy']:
240243
clusters = blocking(clusters, sregistry, options)
241244

245+
# Unfold the `par-tile`s, if any
246+
clusters = apply_par_tiles(clusters, **kwargs)
247+
242248
return clusters
243249

244250
@classmethod

devito/core/operator.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,13 @@ class BasicOperator(Operator):
5858
situations where the performance impact might be detrimental.
5959
"""
6060

61+
CIRE_BLOCK_TEMPS = True
62+
"""
63+
If an aliasing expression is computed within a blocked loop nest, all CIRE-
64+
generated temporaries will inherit the block shape. If set to False, the
65+
temporaries shape will systematically be defined by the root Dimensions.
66+
"""
67+
6168
CIRE_MINGAIN = 10
6269
"""
6370
Minimum operation count reduction for a redundant expression to be optimized

devito/ir/clusters/cluster.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
)
1414
from devito.mpi.halo_scheme import HaloScheme, HaloTouch
1515
from devito.mpi.reduction_scheme import DistReduce
16-
from devito.symbolics import estimate_cost
16+
from devito.symbolics import estimate_cost, uxreplace
1717
from devito.tools import as_tuple, filter_ordered, flatten, infer_dtype
1818
from devito.types import (
1919
CriticalRegion, Fence, Indexed, PhaseMarker, TensorMove, ThreadArrive, ThreadCommit,
@@ -128,6 +128,33 @@ def rebuild(self, *args, **kwargs):
128128
syncs=kwargs.get('syncs', self.syncs),
129129
halo_scheme=kwargs.get('halo_scheme', self.halo_scheme))
130130

131+
def subs(self, mapper, compact=()):
132+
"""
133+
Build a new Cluster applying substitutions rules to `self`.
134+
"""
135+
if not mapper:
136+
return self
137+
138+
if self.halo_scheme:
139+
raise NotImplementedError
140+
141+
key0 = lambda i: i.is_Block
142+
subs0 = {d: self.ispace[d].promote(key0).dim for d in compact}
143+
144+
subs = {**mapper, **subs0}
145+
exprs = [uxreplace(e, subs) for e in self.exprs]
146+
147+
ispace = self.ispace.switch(mapper)
148+
key = lambda i: key0(i) and i in flatten(d._defines for d in subs0)
149+
ispace = ispace.promote(key, mode='total')
150+
151+
guards = self.guards.subs(mapper).promote(subs0)
152+
properties = self.properties.subs(mapper).promote(subs0)
153+
syncs = self.syncs.subs(mapper)
154+
155+
return self.__class__(exprs=exprs, ispace=ispace, guards=guards,
156+
properties=properties, syncs=syncs)
157+
131158
@property
132159
def exprs(self):
133160
return self._exprs
@@ -591,6 +618,14 @@ def dspace(self):
591618
"""Return the DataSpace of this ClusterGroup."""
592619
return DataSpace.union(*[i.dspace.reset() for i in self])
593620

621+
@property
622+
def is_dense(self):
623+
return all(i.is_dense for i in self)
624+
625+
@property
626+
def is_wild(self):
627+
return all(i.is_wild for i in self)
628+
594629
@property
595630
def is_halo_touch(self):
596631
return all(i.is_halo_touch for i in self)

devito/ir/support/guards.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -342,6 +342,31 @@ def as_map(self, d, cls):
342342

343343
return dict(i.args for i in search(self.get(d), cls))
344344

345+
def subs(self, mapper):
346+
m = {mapper.get(d, d): v.xreplace(mapper) for d, v in self.items()}
347+
348+
return Guards(m)
349+
350+
def promote(self, subs):
351+
m = self
352+
for d, v in subs.items():
353+
guards = {self.get(i) for i in d._defines} - {true}
354+
if len(guards) > 1:
355+
raise NotImplementedError(
356+
f"Cannot promote {d} to {v} due to multiple guards: {guards}"
357+
)
358+
elif len(guards) == 0:
359+
continue
360+
361+
guard = guards.pop()
362+
guard = guard.xreplace({d: v})
363+
364+
m = m.impose(v, guard)
365+
366+
m = m.popany(subs)
367+
368+
return m
369+
345370

346371
class GuardExpr(LocalObject, BooleanFunction):
347372

devito/ir/support/properties.py

Lines changed: 29 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,11 @@ def __init__(self, name, val=None):
3939
TILABLE = Property('tilable')
4040
"""A fully parallel Dimension that would benefit from tiling (or "blocking")."""
4141

42-
TILABLE_SMALL = Property('tilable*')
42+
NO_TUNING = Property('notuning')
4343
"""
44-
Like TILABLE, but it would benefit from relatively small block, since the
45-
iteration space is likely to be very small.
44+
A Dimension that would be unlikely to benefit from tuning. For example, the
45+
underlying iteration space is relatively small, and/or the enclosed expressions
46+
are not characterized by data reuse, etc.
4647
"""
4748

4849
SKEWABLE = Property('skewable')
@@ -115,7 +116,6 @@ def __init__(self, name, val=None):
115116

116117
# Bundles
117118
PARALLELS = {PARALLEL, PARALLEL_INDEP, PARALLEL_IF_ATOMIC, PARALLEL_IF_PVT}
118-
TILABLES = {TILABLE, TILABLE_SMALL}
119119

120120

121121
def normalize_properties(*args):
@@ -253,16 +253,10 @@ def prefetchable(self, dims, v=PREFETCHABLE):
253253
m[d] = self.get(d, set()) | {v}
254254
return Properties(m)
255255

256-
def block(self, dims, kind='default'):
257-
if kind == 'default':
258-
p = TILABLE
259-
elif kind == 'small':
260-
p = TILABLE_SMALL
261-
else:
262-
raise ValueError
256+
def block(self, dims):
263257
m = dict(self)
264258
for d in as_tuple(dims):
265-
m[d] = set(self.get(d, [])) | {p}
259+
m[d] = set(self.get(d, [])) | {TILABLE}
266260
return Properties(m)
267261

268262
def inbound(self, dims):
@@ -289,6 +283,9 @@ def init_halo_right_shm(self, dims):
289283
INIT_HALO_LEFT_SHM})
290284
return properties
291285

286+
def notune(self, dims):
287+
return self.add(dims, NO_TUNING)
288+
292289
def is_parallel(self, dims):
293290
return any(len(self[d] & {PARALLEL, PARALLEL_INDEP}) > 0
294291
for d in as_tuple(dims))
@@ -310,10 +307,7 @@ def is_sequential(self, dims):
310307
return any(SEQUENTIAL in self.get(d, ()) for d in as_tuple(dims))
311308

312309
def is_blockable(self, d):
313-
return bool(self.get(d, set()) & {TILABLE, TILABLE_SMALL})
314-
315-
def is_blockable_small(self, d):
316-
return TILABLE_SMALL in self.get(d, set())
310+
return bool(TILABLE in self.get(d, ()))
317311

318312
def _is_property_any(self, dims, v):
319313
if dims is None:
@@ -335,6 +329,25 @@ def is_halo_right_init(self, dims=None):
335329
def is_halo_init(self, dims=None):
336330
return self.is_halo_left_init(dims) or self.is_halo_right_init(dims)
337331

332+
def avoid_tuning(self, dims):
333+
return any(NO_TUNING in self.get(d, set()) for d in as_tuple(dims))
334+
335+
def subs(self, mapper):
336+
return Properties({mapper.get(d, d): v for d, v in self.items()})
337+
338+
def promote(self, subs):
339+
m = self
340+
for d, pd in subs.items():
341+
if pd not in d._defines:
342+
raise ValueError(f"Cannot promote {d} to {pd} as {pd} does not "
343+
f"belong to {d}'s hierarchy")
344+
345+
v = normalize_properties(*[self.get(i, set()) for i in d._defines])
346+
347+
m = self.drop(d._defines).add(pd, v)
348+
349+
return m
350+
338351
@property
339352
def nblockable(self):
340353
return sum([self.is_blockable(d) for d in self])

devito/ir/support/space.py

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -478,7 +478,12 @@ def relaxed(self, d=None):
478478

479479
def promote(self, cond, mode='unordered'):
480480
intervals = [i.promote(cond) for i in self]
481-
intervals = IntervalGroup(intervals, relations=None, mode=mode)
481+
482+
if mode == 'total':
483+
relations = [filter_ordered(i.dim for i in intervals)]
484+
else:
485+
relations = None
486+
intervals = IntervalGroup(intervals, relations=relations, mode=mode)
482487

483488
# There could be duplicate Dimensions at this point, so we sum up the
484489
# Intervals defined over the same Dimension to produce a well-defined
@@ -899,16 +904,29 @@ def augment(self, sub_iterators):
899904

900905
return IterationSpace(self.intervals, items, self.directions)
901906

902-
def switch(self, d0, d1, direction=None):
903-
intervals = self.intervals.switch(d0, d1)
907+
def switch(self, d0, d1=None, direction=None):
908+
"""
909+
Construct a new IterationSpace in which the Dimension `d0` is replaced with
910+
`d1`. Optionally, `d0` could be a mapper, in which case multiple Dimensions
911+
may be switched.
912+
"""
913+
if isinstance(d0, dict):
914+
mapper = d0
915+
else:
916+
mapper = {d0: d1}
904917

918+
intervals = self.intervals
905919
sub_iterators = dict(self.sub_iterators)
906-
sub_iterators.pop(d0, None)
907-
sub_iterators[d1] = ()
908-
909920
directions = dict(self.directions)
910-
v = directions.pop(d0, None)
911-
directions[d1] = direction or v
921+
922+
for d0, d1 in mapper.items():
923+
intervals = intervals.switch(d0, d1)
924+
925+
sub_iterators.pop(d0, None)
926+
sub_iterators[d1] = ()
927+
928+
v = directions.pop(d0, None)
929+
directions[d1] = direction or v
912930

913931
return IterationSpace(intervals, sub_iterators, directions)
914932

devito/ir/support/syncs.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,9 @@ def update(self, ops):
175175
m[d] = set(self.get(d, [])) | set(v)
176176
return Ops(m)
177177

178+
def subs(self, mapper):
179+
return Ops({mapper.get(d, d): v for d, v in self.items()})
180+
178181
def _get_sync(self, cls, dims=None):
179182
if dims is None:
180183
dims = list(self)

0 commit comments

Comments
 (0)