Skip to content

Commit 0483968

Browse files
Merge pull request #1736 from devitocodes/fuse-withlocks
compiler: Add optimization option to fuse WithLocks tasks
2 parents 0b84fe1 + f114164 commit 0483968

9 files changed

Lines changed: 134 additions & 43 deletions

File tree

devito/core/cpu.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,9 @@ def _normalize_kwargs(cls, **kwargs):
7878
# Buffering
7979
o['buf-async-degree'] = oo.pop('buf-async-degree', None)
8080

81+
# Fusion
82+
o['fuse-tasks'] = oo.pop('fuse-tasks', False)
83+
8184
# Blocking
8285
o['blockinner'] = oo.pop('blockinner', False)
8386
o['blocklevels'] = oo.pop('blocklevels', cls.BLOCK_LEVELS)
@@ -298,13 +301,13 @@ def callback(f):
298301
'blocking': lambda i: blocking(i, options),
299302
'factorize': factorize,
300303
'fission': fission,
301-
'fuse': fuse,
304+
'fuse': lambda i: fuse(i, options=options),
302305
'lift': lambda i: Lift().process(cire(i, 'invariants', sregistry,
303306
options, platform)),
304307
'cire-sops': lambda i: cire(i, 'sops', sregistry, options, platform),
305308
'cse': lambda i: cse(i, sregistry),
306309
'opt-pows': optimize_pows,
307-
'topofuse': lambda i: fuse(i, toposort=True)
310+
'topofuse': lambda i: fuse(i, toposort=True, options=options)
308311
}
309312

310313
@classmethod

devito/core/gpu.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,9 @@ def _normalize_kwargs(cls, **kwargs):
6161
# Buffering
6262
o['buf-async-degree'] = oo.pop('buf-async-degree', None)
6363

64+
# Fusion
65+
o['fuse-tasks'] = oo.pop('fuse-tasks', False)
66+
6467
# Blocking
6568
o['blockinner'] = oo.pop('blockinner', True)
6669
o['blocklevels'] = oo.pop('blocklevels', cls.BLOCK_LEVELS)
@@ -148,7 +151,7 @@ def _specialize_clusters(cls, clusters, **kwargs):
148151
sregistry = kwargs['sregistry']
149152

150153
# Toposort+Fusion (the former to expose more fusion opportunities)
151-
clusters = fuse(clusters, toposort=True)
154+
clusters = fuse(clusters, toposort=True, options=options)
152155

153156
# Fission to increase parallelism
154157
clusters = fission(clusters)
@@ -245,13 +248,13 @@ def callback(f):
245248
'streaming': Streaming(reads_if_on_host).process,
246249
'factorize': factorize,
247250
'fission': fission,
248-
'fuse': fuse,
251+
'fuse': lambda i: fuse(i, options=options),
249252
'lift': lambda i: Lift().process(cire(i, 'invariants', sregistry,
250253
options, platform)),
251254
'cire-sops': lambda i: cire(i, 'sops', sregistry, options, platform),
252255
'cse': lambda i: cse(i, sregistry),
253256
'opt-pows': optimize_pows,
254-
'topofuse': lambda i: fuse(i, toposort=True)
257+
'topofuse': lambda i: fuse(i, toposort=True, options=options)
255258
}
256259

257260
@classmethod

devito/ir/iet/nodes.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@
2222
'MetaCall', 'PointerCast', 'ForeignExpression', 'HaloSpot', 'IterationTree',
2323
'ExpressionBundle', 'AugmentedExpression', 'Increment', 'Return', 'While',
2424
'ParallelIteration', 'ParallelBlock', 'Dereference', 'Lambda', 'SyncSpot',
25-
'PragmaTransfer', 'DummyExpr', 'BlankLine', 'ParallelTree', 'BusyWait',
26-
'CallableBody']
25+
'Pragma', 'PragmaTransfer', 'DummyExpr', 'BlankLine', 'ParallelTree',
26+
'BusyWait', 'CallableBody']
2727

2828
# First-class IET nodes
2929

devito/passes/clusters/misc.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from collections import Counter
1+
from collections import Counter, defaultdict
22
from itertools import groupby, product
33

44
from devito.ir.clusters import Cluster, ClusterGroup, Queue
@@ -91,9 +91,13 @@ class Fusion(Queue):
9191
Fuse Clusters with compatible IterationSpace.
9292
"""
9393

94-
def __init__(self, toposort):
95-
super(Fusion, self).__init__()
94+
def __init__(self, toposort, options=None):
95+
options = options or {}
96+
9697
self.toposort = toposort
98+
self.fusetasks = options.get('fuse-tasks', False)
99+
100+
super().__init__()
97101

98102
def _make_key_hook(self, cgroup, level):
99103
assert level > 0
@@ -137,15 +141,24 @@ def _key(self, c):
137141

138142
key = (frozenset(c.itintervals), c.guards)
139143

140-
# We allow fusing Clusters/ClusterGroups with WaitLocks over different Locks,
141-
# while the WithLocks are to be kept separated (i.e. the remain separate tasks)
144+
# We allow fusing Clusters/ClusterGroups even in presence of WaitLocks and
145+
# WithLocks, but not with any other SyncOps
142146
if isinstance(c, Cluster):
143147
sync_locks = (c.sync_locks,)
144148
else:
145149
sync_locks = c.sync_locks
146150
for i in sync_locks:
147-
key += (frozendict({k: frozenset(type(i) if i.is_WaitLock else i for i in v)
148-
for k, v in i.items()}),)
151+
mapper = defaultdict(set)
152+
for k, v in i.items():
153+
for s in v:
154+
if s.is_WaitLock or \
155+
(self.fusetasks and s.is_WithLock):
156+
mapper[k].add(type(s))
157+
else:
158+
mapper[k].add(s)
159+
mapper[k] = frozenset(mapper[k])
160+
mapper = frozendict(mapper)
161+
key += (mapper,)
149162

150163
return key
151164

@@ -243,14 +256,14 @@ def _build_dag(self, cgroups, prefix):
243256

244257

245258
@timed_pass()
246-
def fuse(clusters, toposort=False):
259+
def fuse(clusters, toposort=False, options=None):
247260
"""
248261
Clusters fusion.
249262
250263
If ``toposort=True``, then the Clusters are reordered to maximize the likelihood
251264
of fusion; the new ordering is computed such that all data dependencies are honored.
252265
"""
253-
return Fusion(toposort=toposort).process(clusters)
266+
return Fusion(toposort, options).process(clusters)
254267

255268

256269
@cluster_pass(mode='all')

devito/passes/iet/langbase.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,13 @@ def _map_present(cls, f, imask=None):
7171
"""
7272
raise NotImplementedError
7373

74+
@classmethod
75+
def _map_wait(cls, queueid=None):
76+
"""
77+
Explicitly wait on event.
78+
"""
79+
raise NotImplementedError
80+
7481
@classmethod
7582
def _map_update(cls, f, imask=None):
7683
"""
@@ -86,9 +93,9 @@ def _map_update_host(cls, f, imask=None, queueid=None):
8693
raise NotImplementedError
8794

8895
@classmethod
89-
def _map_update_wait_host(cls, f, imask=None, queueid=None):
96+
def _map_update_host_async(cls, f, imask=None, queueid=None):
9097
"""
91-
Copy Function from device to host memory and explicitly wait.
98+
Asynchronously copy Function from device to host memory.
9299
"""
93100
raise NotImplementedError
94101

@@ -100,9 +107,9 @@ def _map_update_device(cls, f, imask=None, queueid=None):
100107
raise NotImplementedError
101108

102109
@classmethod
103-
def _map_update_wait_device(cls, f, imask=None, queueid=None):
110+
def _map_update_device_async(cls, f, imask=None, queueid=None):
104111
"""
105-
Copy Function from host to device memory and explicitly wait.
112+
Asynchronously copy Function from host to device memory and explicitly wait.
106113
"""
107114
raise NotImplementedError
108115

devito/passes/iet/languages/openacc.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -93,18 +93,18 @@ class AccBB(PragmaLangBB):
9393
c.Pragma('acc enter data create(%s%s)' % (i, j)),
9494
'map-present': lambda i, j:
9595
c.Pragma('acc data present(%s%s)' % (i, j)),
96+
'map-wait': lambda i:
97+
c.Pragma('acc wait(%s)' % i),
9698
'map-update': lambda i, j:
9799
c.Pragma('acc exit data copyout(%s%s)' % (i, j)),
98100
'map-update-host': lambda i, j:
99101
c.Pragma('acc update self(%s%s)' % (i, j)),
100-
'map-update-wait-host': lambda i, j, k:
101-
(c.Pragma('acc update self(%s%s) async(%s)' % (i, j, k)),
102-
c.Pragma('acc wait(%s)' % k)),
102+
'map-update-host-async': lambda i, j, k:
103+
c.Pragma('acc update self(%s%s) async(%s)' % (i, j, k)),
103104
'map-update-device': lambda i, j:
104105
c.Pragma('acc update device(%s%s)' % (i, j)),
105-
'map-update-wait-device': lambda i, j, k:
106-
(c.Pragma('acc update device(%s%s) async(%s)' % (i, j, k)),
107-
c.Pragma('acc wait(%s)' % k)),
106+
'map-update-device-async': lambda i, j, k:
107+
c.Pragma('acc update device(%s%s) async(%s)' % (i, j, k)),
108108
'map-release': lambda i, j, k:
109109
c.Pragma('acc exit data delete(%s%s)%s' % (i, j, k)),
110110
'map-exit-delete': lambda i, j, k:
@@ -147,14 +147,14 @@ def _map_delete(cls, f, imask=None, devicerm=None):
147147
return cls.mapper['map-exit-delete'](f.name, sections, cond)
148148

149149
@classmethod
150-
def _map_update_wait_host(cls, f, imask=None, queueid=None):
150+
def _map_update_host_async(cls, f, imask=None, queueid=None):
151151
sections = cls._make_sections_from_imask(f, imask)
152-
return cls.mapper['map-update-wait-host'](f.name, sections, queueid)
152+
return cls.mapper['map-update-host-async'](f.name, sections, queueid)
153153

154154
@classmethod
155-
def _map_update_wait_device(cls, f, imask=None, queueid=None):
155+
def _map_update_device_async(cls, f, imask=None, queueid=None):
156156
sections = cls._make_sections_from_imask(f, imask)
157-
return cls.mapper['map-update-wait-device'](f.name, sections, queueid)
157+
return cls.mapper['map-update-device-async'](f.name, sections, queueid)
158158

159159

160160
class DeviceAccizer(PragmaDeviceAwareTransformer):

devito/passes/iet/orchestration.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from devito.data import FULL
77
from devito.ir.iet import (Call, Callable, Conditional, List, SyncSpot, FindNodes,
8-
Transformer, BlankLine, BusyWait, PragmaTransfer,
8+
Transformer, BlankLine, BusyWait, Pragma, PragmaTransfer,
99
DummyExpr, derive_parameters, make_thread_ctx)
1010
from devito.passes.iet.engine import iet_pass
1111
from devito.passes.iet.langbase import LangBB
@@ -54,17 +54,21 @@ def _make_withlock(self, iet, sync_ops, pieces, root):
5454
# will never be more than 2 threads in flight concurrently
5555
npthreads = min(i.size for i in locks)
5656

57-
preactions = []
58-
postactions = []
57+
preactions = [BlankLine]
5958
for s in sync_ops:
6059
imask = [s.handle.indices[d] if d.root in s.lock.locked_dimensions else FULL
6160
for d in s.target.dimensions]
62-
update = PragmaTransfer(self.lang._map_update_wait_host, s.target,
61+
update = PragmaTransfer(self.lang._map_update_host_async, s.target,
6362
imask=imask, queueid=SharedData._field_id)
64-
preactions.append(List(body=[BlankLine, update, DummyExpr(s.handle, 1)]))
65-
postactions.append(DummyExpr(s.handle, 2))
63+
preactions.append(update)
64+
wait = self.lang._map_wait(SharedData._field_id)
65+
if wait is not None:
66+
preactions.append(Pragma(wait))
67+
preactions.extend([DummyExpr(s.handle, 1) for s in sync_ops])
6668
preactions.append(BlankLine)
67-
postactions.insert(0, BlankLine)
69+
70+
postactions = [BlankLine]
71+
postactions.extend([DummyExpr(s.handle, 2) for s in sync_ops])
6872

6973
# Turn `iet` into a ThreadFunction so that it can be executed
7074
# asynchronously by a pthread in the `npthreads` pool
@@ -120,7 +124,7 @@ def _make_fetchupdate(self, iet, sync_ops, pieces, *args):
120124
def _make_prefetchupdate(self, iet, sync_ops, pieces, root):
121125
fid = SharedData._field_id
122126

123-
postactions = []
127+
postactions = [BlankLine]
124128
for s in sync_ops:
125129
# `pcond` is not None, but we won't use it here because the condition
126130
# is actually already encoded in `iet` itself (it stems from the
@@ -129,8 +133,11 @@ def _make_prefetchupdate(self, iet, sync_ops, pieces, root):
129133

130134
imask = [(s.tstore, s.size) if d.root is s.dim.root else FULL
131135
for d in s.dimensions]
132-
postactions.append(PragmaTransfer(self.lang._map_update_wait_device,
136+
postactions.append(PragmaTransfer(self.lang._map_update_device_async,
133137
s.target, imask=imask, queueid=fid))
138+
wait = self.lang._map_wait(fid)
139+
if wait is not None:
140+
postactions.append(Pragma(wait))
134141

135142
# Turn prefetch IET into a ThreadFunction
136143
name = self.sregistry.make_name(prefix='prefetch_host_to_device')
@@ -156,8 +163,8 @@ def _make_waitprefetch(self, iet, sync_ops, pieces, *args):
156163
ff = SharedData._field_flag
157164

158165
waits = []
159-
for s in sync_ops:
160-
sdata, threads = pieces.objs.get(s)
166+
objs = filter_ordered(pieces.objs.get(s) for s in sync_ops)
167+
for sdata, threads in objs:
161168
wait = BusyWait(CondNe(FieldFromComposite(ff, sdata[threads.index]), 1))
162169
waits.append(wait)
163170

devito/passes/iet/parpragma.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -536,6 +536,14 @@ def _map_alloc(cls, f, imask=None):
536536
def _map_present(cls, f, imask=None):
537537
return
538538

539+
@classmethod
540+
def _map_wait(cls, queueid=None):
541+
try:
542+
return cls.mapper['map-wait'](queueid)
543+
except KeyError:
544+
# Not all languages may provide an explicit wait construct
545+
return None
546+
539547
@classmethod
540548
def _map_update(cls, f, imask=None):
541549
sections = cls._make_sections_from_imask(f, imask)
@@ -546,14 +554,14 @@ def _map_update_host(cls, f, imask=None, queueid=None):
546554
sections = cls._make_sections_from_imask(f, imask)
547555
return cls.mapper['map-update-host'](f.name, sections)
548556

549-
_map_update_wait_host = _map_update_host
557+
_map_update_host_async = _map_update_host
550558

551559
@classmethod
552560
def _map_update_device(cls, f, imask=None, queueid=None):
553561
sections = cls._make_sections_from_imask(f, imask)
554562
return cls.mapper['map-update-device'](f.name, sections)
555563

556-
_map_update_wait_device = _map_update_device
564+
_map_update_device_async = _map_update_device
557565

558566
@classmethod
559567
def _map_release(cls, f, imask=None, devicerm=None):

tests/test_gpu_common.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,56 @@ def test_tasking_unfused_two_locks(self):
211211
assert np.all(u.data[nt-1] == 9)
212212
assert np.all(v.data[nt-1] == 9)
213213

214+
def test_tasking_forcefuse(self):
215+
nt = 10
216+
bundle0 = Bundle()
217+
grid = Grid(shape=(10, 10, 10), subdomains=bundle0)
218+
219+
tmp0 = Function(name='tmp0', grid=grid)
220+
tmp1 = Function(name='tmp1', grid=grid)
221+
u = TimeFunction(name='u', grid=grid, save=nt)
222+
v = TimeFunction(name='v', grid=grid, save=nt)
223+
w = TimeFunction(name='w', grid=grid)
224+
225+
eqns = [Eq(w.forward, w + 1),
226+
Eq(tmp0, w.forward),
227+
Eq(tmp1, w.forward),
228+
Eq(u.forward, tmp0, subdomain=bundle0),
229+
Eq(v.forward, tmp1, subdomain=bundle0)]
230+
231+
op = Operator(eqns, opt=('tasking', 'fuse', 'orchestrate', {'fuse-tasks': True}))
232+
233+
# Check generated code
234+
assert len(retrieve_iteration_tree(op)) == 5
235+
assert len([i for i in FindSymbols().visit(op) if isinstance(i, Lock)]) == 2
236+
sections = FindNodes(Section).visit(op)
237+
assert len(sections) == 3
238+
assert (str(sections[1].body[0].body[0].body[0].body[0]) ==
239+
'while(lock0[0] == 0 || lock1[0] == 0);') # Wait-lock
240+
body = sections[2].body[0].body[0]
241+
assert (str(body.body[1].condition) ==
242+
'Ne(lock0[0], 2) | '
243+
'Ne(lock1[0], 2) | '
244+
'Ne(FieldFromComposite(sdata0[wi0]), 1)') # Wait-thread
245+
assert (str(body.body[1].body[0]) ==
246+
'wi0 = (wi0 + 1)%(npthreads0);')
247+
assert str(body.body[2]) == 'sdata0[wi0].time = time;'
248+
assert str(body.body[3]) == 'lock0[0] = 0;' # Set-lock
249+
assert str(body.body[4]) == 'lock1[0] = 0;' # Set-lock
250+
assert str(body.body[5]) == 'sdata0[wi0].flag = 2;'
251+
assert len(op._func_table) == 2
252+
exprs = FindNodes(Expression).visit(op._func_table['copy_device_to_host0'].root)
253+
assert len(exprs) == 22
254+
assert str(exprs[15]) == 'lock0[0] = 1;'
255+
assert str(exprs[16]) == 'lock1[0] = 1;'
256+
assert exprs[17].write is u
257+
assert exprs[18].write is v
258+
259+
op.apply(time_M=nt-2)
260+
261+
assert np.all(u.data[nt-1] == 9)
262+
assert np.all(v.data[nt-1] == 9)
263+
214264
@pytest.mark.parametrize('opt', [
215265
('tasking', 'orchestrate'),
216266
('tasking', 'streaming', 'orchestrate'),

0 commit comments

Comments
 (0)