Skip to content

Commit f114164

Browse files
author
dummy
committed
compiler: Add fuse-tasks optimization option
1 parent b49845d commit f114164

4 files changed

Lines changed: 83 additions & 14 deletions

File tree

devito/core/cpu.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,9 @@ def _normalize_kwargs(cls, **kwargs):
7878
# Buffering
7979
o['buf-async-degree'] = oo.pop('buf-async-degree', None)
8080

81+
# Fusion
82+
o['fuse-tasks'] = oo.pop('fuse-tasks', False)
83+
8184
# Blocking
8285
o['blockinner'] = oo.pop('blockinner', False)
8386
o['blocklevels'] = oo.pop('blocklevels', cls.BLOCK_LEVELS)
@@ -298,13 +301,13 @@ def callback(f):
298301
'blocking': lambda i: blocking(i, options),
299302
'factorize': factorize,
300303
'fission': fission,
301-
'fuse': fuse,
304+
'fuse': lambda i: fuse(i, options=options),
302305
'lift': lambda i: Lift().process(cire(i, 'invariants', sregistry,
303306
options, platform)),
304307
'cire-sops': lambda i: cire(i, 'sops', sregistry, options, platform),
305308
'cse': lambda i: cse(i, sregistry),
306309
'opt-pows': optimize_pows,
307-
'topofuse': lambda i: fuse(i, toposort=True)
310+
'topofuse': lambda i: fuse(i, toposort=True, options=options)
308311
}
309312

310313
@classmethod

devito/core/gpu.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,9 @@ def _normalize_kwargs(cls, **kwargs):
6161
# Buffering
6262
o['buf-async-degree'] = oo.pop('buf-async-degree', None)
6363

64+
# Fusion
65+
o['fuse-tasks'] = oo.pop('fuse-tasks', False)
66+
6467
# Blocking
6568
o['blockinner'] = oo.pop('blockinner', True)
6669
o['blocklevels'] = oo.pop('blocklevels', cls.BLOCK_LEVELS)
@@ -148,7 +151,7 @@ def _specialize_clusters(cls, clusters, **kwargs):
148151
sregistry = kwargs['sregistry']
149152

150153
# Toposort+Fusion (the former to expose more fusion opportunities)
151-
clusters = fuse(clusters, toposort=True)
154+
clusters = fuse(clusters, toposort=True, options=options)
152155

153156
# Fission to increase parallelism
154157
clusters = fission(clusters)
@@ -245,13 +248,13 @@ def callback(f):
245248
'streaming': Streaming(reads_if_on_host).process,
246249
'factorize': factorize,
247250
'fission': fission,
248-
'fuse': fuse,
251+
'fuse': lambda i: fuse(i, options=options),
249252
'lift': lambda i: Lift().process(cire(i, 'invariants', sregistry,
250253
options, platform)),
251254
'cire-sops': lambda i: cire(i, 'sops', sregistry, options, platform),
252255
'cse': lambda i: cse(i, sregistry),
253256
'opt-pows': optimize_pows,
254-
'topofuse': lambda i: fuse(i, toposort=True)
257+
'topofuse': lambda i: fuse(i, toposort=True, options=options)
255258
}
256259

257260
@classmethod

devito/passes/clusters/misc.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from collections import Counter
1+
from collections import Counter, defaultdict
22
from itertools import groupby, product
33

44
from devito.ir.clusters import Cluster, ClusterGroup, Queue
@@ -91,9 +91,13 @@ class Fusion(Queue):
9191
Fuse Clusters with compatible IterationSpace.
9292
"""
9393

94-
def __init__(self, toposort):
95-
super(Fusion, self).__init__()
94+
def __init__(self, toposort, options=None):
95+
options = options or {}
96+
9697
self.toposort = toposort
98+
self.fusetasks = options.get('fuse-tasks', False)
99+
100+
super().__init__()
97101

98102
def _make_key_hook(self, cgroup, level):
99103
assert level > 0
@@ -137,15 +141,24 @@ def _key(self, c):
137141

138142
key = (frozenset(c.itintervals), c.guards)
139143

140-
# We allow fusing Clusters/ClusterGroups with WaitLocks over different Locks,
141-
# while the WithLocks are to be kept separated (i.e. the remain separate tasks)
144+
# We allow fusing Clusters/ClusterGroups even in presence of WaitLocks and
145+
# WithLocks, but not with any other SyncOps
142146
if isinstance(c, Cluster):
143147
sync_locks = (c.sync_locks,)
144148
else:
145149
sync_locks = c.sync_locks
146150
for i in sync_locks:
147-
key += (frozendict({k: frozenset(type(i) if i.is_WaitLock else i for i in v)
148-
for k, v in i.items()}),)
151+
mapper = defaultdict(set)
152+
for k, v in i.items():
153+
for s in v:
154+
if s.is_WaitLock or \
155+
(self.fusetasks and s.is_WithLock):
156+
mapper[k].add(type(s))
157+
else:
158+
mapper[k].add(s)
159+
mapper[k] = frozenset(mapper[k])
160+
mapper = frozendict(mapper)
161+
key += (mapper,)
149162

150163
return key
151164

@@ -243,14 +256,14 @@ def _build_dag(self, cgroups, prefix):
243256

244257

245258
@timed_pass()
246-
def fuse(clusters, toposort=False):
259+
def fuse(clusters, toposort=False, options=None):
247260
"""
248261
Clusters fusion.
249262
250263
If ``toposort=True``, then the Clusters are reordered to maximize the likelihood
251264
of fusion; the new ordering is computed such that all data dependencies are honored.
252265
"""
253-
return Fusion(toposort=toposort).process(clusters)
266+
return Fusion(toposort, options).process(clusters)
254267

255268

256269
@cluster_pass(mode='all')

tests/test_gpu_common.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,56 @@ def test_tasking_unfused_two_locks(self):
211211
assert np.all(u.data[nt-1] == 9)
212212
assert np.all(v.data[nt-1] == 9)
213213

214+
def test_tasking_forcefuse(self):
215+
nt = 10
216+
bundle0 = Bundle()
217+
grid = Grid(shape=(10, 10, 10), subdomains=bundle0)
218+
219+
tmp0 = Function(name='tmp0', grid=grid)
220+
tmp1 = Function(name='tmp1', grid=grid)
221+
u = TimeFunction(name='u', grid=grid, save=nt)
222+
v = TimeFunction(name='v', grid=grid, save=nt)
223+
w = TimeFunction(name='w', grid=grid)
224+
225+
eqns = [Eq(w.forward, w + 1),
226+
Eq(tmp0, w.forward),
227+
Eq(tmp1, w.forward),
228+
Eq(u.forward, tmp0, subdomain=bundle0),
229+
Eq(v.forward, tmp1, subdomain=bundle0)]
230+
231+
op = Operator(eqns, opt=('tasking', 'fuse', 'orchestrate', {'fuse-tasks': True}))
232+
233+
# Check generated code
234+
assert len(retrieve_iteration_tree(op)) == 5
235+
assert len([i for i in FindSymbols().visit(op) if isinstance(i, Lock)]) == 2
236+
sections = FindNodes(Section).visit(op)
237+
assert len(sections) == 3
238+
assert (str(sections[1].body[0].body[0].body[0].body[0]) ==
239+
'while(lock0[0] == 0 || lock1[0] == 0);') # Wait-lock
240+
body = sections[2].body[0].body[0]
241+
assert (str(body.body[1].condition) ==
242+
'Ne(lock0[0], 2) | '
243+
'Ne(lock1[0], 2) | '
244+
'Ne(FieldFromComposite(sdata0[wi0]), 1)') # Wait-thread
245+
assert (str(body.body[1].body[0]) ==
246+
'wi0 = (wi0 + 1)%(npthreads0);')
247+
assert str(body.body[2]) == 'sdata0[wi0].time = time;'
248+
assert str(body.body[3]) == 'lock0[0] = 0;' # Set-lock
249+
assert str(body.body[4]) == 'lock1[0] = 0;' # Set-lock
250+
assert str(body.body[5]) == 'sdata0[wi0].flag = 2;'
251+
assert len(op._func_table) == 2
252+
exprs = FindNodes(Expression).visit(op._func_table['copy_device_to_host0'].root)
253+
assert len(exprs) == 22
254+
assert str(exprs[15]) == 'lock0[0] = 1;'
255+
assert str(exprs[16]) == 'lock1[0] = 1;'
256+
assert exprs[17].write is u
257+
assert exprs[18].write is v
258+
259+
op.apply(time_M=nt-2)
260+
261+
assert np.all(u.data[nt-1] == 9)
262+
assert np.all(v.data[nt-1] == 9)
263+
214264
@pytest.mark.parametrize('opt', [
215265
('tasking', 'orchestrate'),
216266
('tasking', 'streaming', 'orchestrate'),

0 commit comments

Comments
 (0)