Skip to content

Commit 7eefe42

Browse files
Merge pull request #2948 from devitocodes/smarter-tuner-3
compiler: Enable blocking-before-CIRE
2 parents c01281c + df8672d commit 7eefe42

26 files changed

Lines changed: 701 additions & 383 deletions

devito/arch/compiler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -769,7 +769,7 @@ def __init_finalize__(self, **kwargs):
769769
if not configuration['safe-math']:
770770
self.cflags.append('--use_fast_math')
771771

772-
if configuration['profiling'] == 'advanced2':
772+
if configuration['profiling'] in ('advanced2', 'ncu'):
773773
# Optionally print out per-kernel shared memory and register usage
774774
self.cflags.append('--ptxas-options=-v')
775775

devito/core/cpu.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
from devito.operator.operator import rcompile
66
from devito.passes import stream_dimensions
77
from devito.passes.clusters import (
8-
Lift, blocking, buffering, cire, cse, factorize, fission, fuse, optimize_hyperplanes,
9-
optimize_pows
8+
Lift, apply_par_tiles, blocking, buffering, cire, cse, factorize, fission, fuse,
9+
optimize_hyperplanes, optimize_pows
1010
)
1111
from devito.passes.equations import collect_derivatives
1212
from devito.passes.iet import (
@@ -67,6 +67,7 @@ def _normalize_kwargs(cls, **kwargs):
6767
reduce=oo.pop('par-tile-reduce', None))
6868

6969
# CIRE
70+
o['cire-block-temps'] = oo.pop('cire-block-temps', cls.CIRE_BLOCK_TEMPS)
7071
o['min-storage'] = oo.pop('min-storage', False)
7172
o['cire-rotate'] = oo.pop('cire-rotate', False)
7273
o['cire-maxpar'] = oo.pop('cire-maxpar', False)
@@ -198,6 +199,9 @@ def _specialize_clusters(cls, clusters, **kwargs):
198199
if options['blocklazy']:
199200
clusters = blocking(clusters, sregistry, options)
200201

202+
# Unfold the `par-tile`s, if any
203+
clusters = apply_par_tiles(clusters, **kwargs)
204+
201205
return clusters
202206

203207
@classmethod

devito/core/gpu.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77
from devito.operator.operator import rcompile
88
from devito.passes import is_on_device, stream_dimensions
99
from devito.passes.clusters import (
10-
Lift, blocking, buffering, cire, cse, factorize, fission, fuse, memcpy_prefetch,
11-
optimize_pows, tasking
10+
Lift, apply_par_tiles, blocking, buffering, cire, cse, factorize, fission, fuse,
11+
memcpy_prefetch, optimize_pows, tasking
1212
)
1313
from devito.passes.equations import collect_derivatives
1414
from devito.passes.iet import (
@@ -38,7 +38,9 @@
3838

3939
class DeviceOperatorMixin:
4040

41+
# Overrides the default values in the main Operator class
4142
BLOCK_LEVELS = 0
43+
CIRE_BLOCK_TEMPS = False
4244
MPI_MODES = (True, 'basic',)
4345

4446
GPU_FIT = 'all-fallback'
@@ -76,9 +78,10 @@ def _normalize_kwargs(cls, **kwargs):
7678
o['skewing'] = oo.pop('skewing', False)
7779

7880
# CIRE
81+
o['cire-block-temps'] = oo.pop('cire-block-temps', cls.CIRE_BLOCK_TEMPS)
7982
o['min-storage'] = False
8083
o['cire-rotate'] = False
81-
o['cire-maxpar'] = oo.pop('cire-maxpar', True)
84+
o['cire-maxpar'] = oo.pop('cire-maxpar', 'basic')
8285
o['cire-ftemps'] = oo.pop('cire-ftemps', False)
8386
o['cire-mingain'] = oo.pop('cire-mingain', cls.CIRE_MINGAIN)
8487
o['cire-minmem'] = oo.pop('cire-minmem', cls.CIRE_MINMEM)
@@ -239,6 +242,9 @@ def _specialize_clusters(cls, clusters, **kwargs):
239242
if options['blocklazy']:
240243
clusters = blocking(clusters, sregistry, options)
241244

245+
# Unfold the `par-tile`s, if any
246+
clusters = apply_par_tiles(clusters, **kwargs)
247+
242248
return clusters
243249

244250
@classmethod

devito/core/operator.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,13 @@ class BasicOperator(Operator):
5858
situations where the performance impact might be detrimental.
5959
"""
6060

61+
CIRE_BLOCK_TEMPS = True
62+
"""
63+
If an aliasing expression is computed within a blocked loop nest, all CIRE-
64+
generated temporaries will inherit the block shape. If set to False, the
65+
temporaries' shape will systematically be defined by the root Dimensions.
66+
"""
67+
6168
CIRE_MINGAIN = 10
6269
"""
6370
Minimum operation count reduction for a redundant expression to be optimized
@@ -240,6 +247,9 @@ def _check_kwargs(cls, **kwargs):
240247
if oo['mpi'] and oo['mpi'] not in cls.MPI_MODES:
241248
raise InvalidOperator(f"Unsupported MPI mode `{oo['mpi']}`")
242249

250+
if oo['cire-maxpar'] not in (False, 'basic', 'compact'):
251+
raise InvalidOperator("Illegal `cire-maxpar` value")
252+
243253
if oo['cse-algo'] not in ('basic', 'smartsort', 'advanced'):
244254
raise InvalidOperator("Illegal `cse-algo` value")
245255

devito/finite_differences/differentiable.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -985,8 +985,9 @@ def value(self, idx):
985985
class IndexDerivative(IndexSum):
986986

987987
__rargs__ = ('expr', 'mapper')
988+
__rkwargs__ = IndexSum.__rkwargs__ + ('deriv_order',)
988989

989-
def __new__(cls, expr, mapper, **kwargs):
990+
def __new__(cls, expr, mapper, deriv_order=None, **kwargs):
990991
dimensions = as_tuple(set(mapper.values()))
991992

992993
# Detect the Weights among the arguments
@@ -1008,6 +1009,8 @@ def __new__(cls, expr, mapper, **kwargs):
10081009
obj._weights = weights
10091010
obj._mapper = frozendict(mapper)
10101011

1012+
obj._deriv_order = deriv_order
1013+
10111014
return obj
10121015

10131016
def _hashable_content(self):
@@ -1040,6 +1043,10 @@ def weights(self):
10401043
def mapper(self):
10411044
return self._mapper
10421045

1046+
@property
1047+
def deriv_order(self):
1048+
return self._deriv_order
1049+
10431050
@property
10441051
def depth(self):
10451052
iderivs = self.expr.find(IndexDerivative)
@@ -1216,7 +1223,8 @@ def _diff2sympy(obj):
12161223

12171224
# Handle special objects
12181225
if isinstance(obj, DiffDerivative):
1219-
return IndexDerivative(*args, obj.mapper), True
1226+
return IndexDerivative(*args, obj.mapper,
1227+
deriv_order=obj.deriv_order), True
12201228

12211229
# Handle generic objects such as arithmetic operations
12221230
try:

devito/finite_differences/finite_difference.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,9 @@ def make_derivative(expr, dim, fd_order, deriv_order, side, matvec, x0, coeffici
210210
with suppress(AttributeError):
211211
expr = expr._evaluate(expand=False)
212212

213-
deriv = DiffDerivative(expr*weights, {dim: indices.free_dim})
213+
deriv = DiffDerivative(
214+
expr*weights, {dim: indices.free_dim}, deriv_order=deriv_order
215+
)
214216
else:
215217
terms = []
216218
for i, c in zip(indices, weights, strict=True):

devito/ir/clusters/cluster.py

Lines changed: 51 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
)
1414
from devito.mpi.halo_scheme import HaloScheme, HaloTouch
1515
from devito.mpi.reduction_scheme import DistReduce
16-
from devito.symbolics import estimate_cost
16+
from devito.symbolics import estimate_cost, uxreplace
1717
from devito.tools import (
1818
CacheInstances, as_tuple, cached_hash, filter_ordered, flatten, infer_dtype
1919
)
@@ -112,7 +112,7 @@ def dimensions(self):
112112
@cached_property
113113
def exprs_dimensions(self):
114114
"""
115-
The Dimensions that appear explicitly in the Cluster expressions.
115+
The Dimensions that appear explicitly in the expressions.
116116
"""
117117
dims_explicit = {i for i in self.free_symbols if i.is_Dimension}
118118
dims_implicit = {d for e in self.exprs for d in e.implicit_dims}
@@ -121,7 +121,7 @@ def exprs_dimensions(self):
121121
@cached_property
122122
def guards_dimensions(self):
123123
"""
124-
The Dimensions that appear explicitly in the Cluster guards.
124+
The Dimensions that appear explicitly in the guards.
125125
"""
126126
syms_guards = {d for e in self.guards.values() for d in e.free_symbols}
127127
dims_guards = {i for i in syms_guards if i.is_Dimension}
@@ -142,7 +142,7 @@ def used_dimensions(self):
142142
@cached_property
143143
def dist_dimensions(self):
144144
"""
145-
The Cluster's distributed Dimensions.
145+
The distributed Dimensions.
146146
"""
147147
ret = set()
148148
for f in self.functions:
@@ -168,7 +168,7 @@ def grid(self):
168168
elif len(grids) == 1:
169169
return grids.pop()
170170
else:
171-
raise ValueError("Cluster has no unique Grid")
171+
raise ValueError("Multiple Grids detected")
172172

173173
@cached_property
174174
def is_scalar(self):
@@ -296,31 +296,27 @@ def is_glb_load_to_mem_shared(self):
296296
@cached_property
297297
def is_async(self):
298298
"""
299-
True if an asynchronous Cluster, False otherwise.
299+
True if asynchronous, False otherwise.
300300
"""
301301
return any(isinstance(s, (WithLock, PrefetchUpdate))
302302
for s in flatten(self.syncs.values()))
303303

304304
@cached_property
305305
def is_wait(self):
306306
"""
307-
True if a Cluster waiting on a lock (that is a special synchronization
308-
operation), False otherwise.
307+
True if waiting on a lock (that is a special synchronization operation),
308+
False otherwise.
309309
"""
310310
return any(isinstance(s, WaitLock)
311311
for s in flatten(self.syncs.values()))
312312

313313
@cached_property
314314
def dtype(self):
315315
"""
316-
The arithmetic data type of the Cluster.
316+
The arithmetic data type of the enclosed expressions.
317317
318-
If the Cluster performs floating point arithmetic, then the expressions
319-
performing integer arithmetic are ignored, assuming that they are only
320-
carrying out array index calculations.
321-
322-
If two expressions perform calculations with different precision,
323-
the data type with highest precision is returned.
318+
If two expressions perform calculations with different precision, the data
319+
type with highest precision is returned.
324320
"""
325321
dtypes = set()
326322
for i in self.exprs:
@@ -336,8 +332,8 @@ def dtype(self):
336332
@cached_property
337333
def dspace(self):
338334
"""
339-
Derive the DataSpace of the Cluster from its expressions,
340-
IterationSpace, and Guards.
335+
The DataSpace deriving from the enclosed expressions, IterationSpace,
336+
and Guards.
341337
"""
342338
accesses = detect_accesses(self.exprs)
343339

@@ -421,8 +417,8 @@ def ops(self):
421417
@cached_property
422418
def traffic(self):
423419
"""
424-
The Cluster compulsory traffic (number of reads/writes), as a mapper
425-
from Functions to IntervalGroups.
420+
The compulsory traffic (number of reads/writes), as a mapper from
421+
Functions to IntervalGroups.
426422
427423
Notes
428424
-----
@@ -509,30 +505,6 @@ def __getattr__(self, name):
509505
raise AttributeError(name) from None
510506
return getattr(block, name)
511507

512-
@property
513-
def exprs(self):
514-
return self._block.exprs
515-
516-
@property
517-
def ispace(self):
518-
return self._block.ispace
519-
520-
@property
521-
def guards(self):
522-
return self._block.guards
523-
524-
@property
525-
def properties(self):
526-
return self._block.properties
527-
528-
@property
529-
def syncs(self):
530-
return self._block.syncs
531-
532-
@property
533-
def halo_scheme(self):
534-
return self._block.halo_scheme
535-
536508
@classmethod
537509
def from_clusters(cls, *clusters):
538510
"""
@@ -612,6 +584,33 @@ def rebuild(self, *args, **kwargs):
612584
syncs=syncs,
613585
halo_scheme=halo_scheme)
614586

587+
def subs(self, mapper, compact=()):
588+
"""
589+
Build a new Cluster applying substitution rules to `self`.
590+
"""
591+
if not mapper:
592+
return self
593+
594+
if self.halo_scheme:
595+
raise NotImplementedError
596+
597+
key0 = lambda i: i.is_Block
598+
subs0 = {d: self.ispace[d].promote(key0).dim for d in compact}
599+
600+
subs = {**mapper, **subs0}
601+
exprs = [uxreplace(e, subs) for e in self.exprs]
602+
603+
ispace = self.ispace.switch(mapper)
604+
key = lambda i: key0(i) and i in flatten(d._defines for d in subs0)
605+
ispace = ispace.promote(key, mode='total')
606+
607+
guards = self.guards.subs(mapper).promote(subs0)
608+
properties = self.properties.subs(mapper).promote(subs0)
609+
syncs = self.syncs.subs(mapper)
610+
611+
return self.__class__(exprs=exprs, ispace=ispace, guards=guards,
612+
properties=properties, syncs=syncs)
613+
615614

616615
class ClusterGroup(tuple):
617616

@@ -691,7 +690,15 @@ def dspace(self):
691690
"""Return the DataSpace of this ClusterGroup."""
692691
return DataSpace.union(*[i.dspace.reset() for i in self])
693692

694-
@property
693+
@cached_property
694+
def is_dense(self):
695+
return all(i.is_dense for i in self)
696+
697+
@cached_property
698+
def is_wild(self):
699+
return all(i.is_wild for i in self)
700+
701+
@cached_property
695702
def is_halo_touch(self):
696703
return all(i.is_halo_touch for i in self)
697704

devito/ir/clusters/visitors.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -37,15 +37,7 @@ def _make_key(self, cluster, level):
3737
assert self._q_ispace_in_key
3838
ispace = cluster.ispace[:level]
3939

40-
if self._q_guards_in_key:
41-
try:
42-
guards = tuple(cluster.guards.get(i.dim) for i in ispace)
43-
except AttributeError:
44-
# `cluster` is actually a ClusterGroup
45-
assert len(cluster.guards) == 1
46-
guards = tuple(cluster.guards[0].get(i.dim) for i in ispace)
47-
else:
48-
guards = None
40+
guards = self._make_key_guards(cluster, ispace)
4941

5042
if self._q_properties_in_key:
5143
properties = cluster.properties.drop(cluster.ispace[level:].itdims)
@@ -68,6 +60,17 @@ def _make_key(self, cluster, level):
6860

6961
return (prefix,) + subkey
7062

63+
def _make_key_guards(self, cluster, ispace):
64+
if not self._q_guards_in_key:
65+
return None
66+
67+
try:
68+
return tuple(cluster.guards.get(i.dim) for i in ispace)
69+
except AttributeError:
70+
# `cluster` is actually a ClusterGroup
71+
assert len(cluster.guards) == 1
72+
return tuple(cluster.guards[0].get(i.dim) for i in ispace)
73+
7174
def _make_key_hook(self, cluster, level):
7275
return ()
7376

0 commit comments

Comments
 (0)