55
66from devito .data import FULL
77from devito .ir .iet import (Call , Callable , Conditional , List , SyncSpot , FindNodes ,
8- Transformer , BlankLine , BusyWait , PragmaTransfer ,
8+ Transformer , BlankLine , BusyWait , Pragma , PragmaTransfer ,
99 DummyExpr , derive_parameters , make_thread_ctx )
1010from devito .passes .iet .engine import iet_pass
1111from devito .passes .iet .langbase import LangBB
@@ -54,17 +54,21 @@ def _make_withlock(self, iet, sync_ops, pieces, root):
5454 # will never be more than 2 threads in flight concurrently
5555 npthreads = min (i .size for i in locks )
5656
57- preactions = []
58- postactions = []
57+ preactions = [BlankLine ]
5958 for s in sync_ops :
6059 imask = [s .handle .indices [d ] if d .root in s .lock .locked_dimensions else FULL
6160 for d in s .target .dimensions ]
62- update = PragmaTransfer (self .lang ._map_update_wait_host , s .target ,
61+ update = PragmaTransfer (self .lang ._map_update_host_async , s .target ,
6362 imask = imask , queueid = SharedData ._field_id )
64- preactions .append (List (body = [BlankLine , update , DummyExpr (s .handle , 1 )]))
65- postactions .append (DummyExpr (s .handle , 2 ))
63+ preactions .append (update )
64+ wait = self .lang ._map_wait (SharedData ._field_id )
65+ if wait is not None :
66+ preactions .append (Pragma (wait ))
67+ preactions .extend ([DummyExpr (s .handle , 1 ) for s in sync_ops ])
6668 preactions .append (BlankLine )
67- postactions .insert (0 , BlankLine )
69+
70+ postactions = [BlankLine ]
71+ postactions .extend ([DummyExpr (s .handle , 2 ) for s in sync_ops ])
6872
6973 # Turn `iet` into a ThreadFunction so that it can be executed
7074 # asynchronously by a pthread in the `npthreads` pool
@@ -120,7 +124,7 @@ def _make_fetchupdate(self, iet, sync_ops, pieces, *args):
120124 def _make_prefetchupdate (self , iet , sync_ops , pieces , root ):
121125 fid = SharedData ._field_id
122126
123- postactions = []
127+ postactions = [BlankLine ]
124128 for s in sync_ops :
125129 # `pcond` is not None, but we won't use it here because the condition
126130 # is actually already encoded in `iet` itself (it stems from the
@@ -129,8 +133,11 @@ def _make_prefetchupdate(self, iet, sync_ops, pieces, root):
129133
130134 imask = [(s .tstore , s .size ) if d .root is s .dim .root else FULL
131135 for d in s .dimensions ]
132- postactions .append (PragmaTransfer (self .lang ._map_update_wait_device ,
136+ postactions .append (PragmaTransfer (self .lang ._map_update_device_async ,
133137 s .target , imask = imask , queueid = fid ))
138+ wait = self .lang ._map_wait (fid )
139+ if wait is not None :
140+ postactions .append (Pragma (wait ))
134141
135142 # Turn prefetch IET into a ThreadFunction
136143 name = self .sregistry .make_name (prefix = 'prefetch_host_to_device' )
@@ -156,8 +163,8 @@ def _make_waitprefetch(self, iet, sync_ops, pieces, *args):
156163 ff = SharedData ._field_flag
157164
158165 waits = []
159- for s in sync_ops :
160- sdata , threads = pieces . objs . get ( s )
166+ objs = filter_ordered ( pieces . objs . get ( s ) for s in sync_ops )
167+ for sdata , threads in objs :
161168 wait = BusyWait (CondNe (FieldFromComposite (ff , sdata [threads .index ]), 1 ))
162169 waits .append (wait )
163170
0 commit comments