88 Break , Call , Conditional , DummyExpr , EntryFunction , FindNodes , FindSymbols , Iteration ,
99 KernelLaunch , List , Return , Transformer , make_callable
1010)
11+ from devito .mpi .distributed import MPICommObject
1112from devito .passes .iet .engine import iet_pass
1213from devito .symbolics import CondEq , MathFunction
1314from devito .tools import dtype_to_ctype
@@ -104,28 +105,34 @@ def _check_stability(iet, wmovs=(), rcompile=None, sregistry=None):
104105
105106def check_launch (graph , options = None , ** kwargs ):
106107 """
107- Insert the CHECK_LAUNCH macro if errctl is set to ensure graceful handling of
108- failed kernel launches. This macro should only be inserted if the kernel is
109- directly within a loop, as compilation will fail otherwise.
108+ Insert the CHECK_LAUNCH* macros if errctl is set to ensure graceful
109+ handling of failed kernel launches. This macro is only inserted if the
110+ kernel launch is directly within a loop, as compilation would fail
111+ otherwise.
110112 """
111113 if options is None or not options .get ('errctl' , False ):
112114 return
113115
114116 langbb = kwargs ['langbb' ]
115117
116- definition = make_launch_macros (langbb )
118+ definition = make_launch_macros (langbb , options = options )
117119 if not definition :
118120 return
119121
120- macro = [langbb ['check-launch' ]]
121-
122- _check_launch (graph , definition = definition , macro = macro , ** kwargs )
122+ _check_launch (graph , definition = definition , options = options , ** kwargs )
123123
124124
125125@iet_pass
126- def _check_launch (iet , definition = None , macro = None , ** kwargs ):
126+ def _check_launch (iet , definition = None , options = None , ** kwargs ):
127127 iterations = FindNodes (Iteration ).visit (iet )
128128
129+ if options ['mpi' ] and \
130+ isinstance (iet , EntryFunction ) and \
131+ any (isinstance (i , MPICommObject ) for i in iet .parameters ):
132+ check = [List (body = c .Statement ('CHECK_LAUNCH_RETURN' ))]
133+ else :
134+ check = [List (body = c .Statement ('CHECK_LAUNCH' ))]
135+
129136 mapper = {}
130137 for i in iterations :
131138 # Two stages of substitution to account for the edge case
@@ -135,7 +142,7 @@ def _check_launch(iet, definition=None, macro=None, **kwargs):
135142 launches = FindNodes (KernelLaunch ).visit (i )
136143
137144 for launch in launches :
138- launch_mapper [launch ] = List (body = [launch ] + macro )
145+ launch_mapper [launch ] = List (body = [launch ] + check )
139146
140147 if launch_mapper :
141148 mapper [i ] = Transformer (launch_mapper ).visit (i )
@@ -148,7 +155,7 @@ def _check_launch(iet, definition=None, macro=None, **kwargs):
148155 return iet , extras
149156
150157
151- def make_launch_macros (langbb ):
158+ def make_launch_macros (langbb , options = None ):
152159 """
153160 Define macros to check for errors to ensure graceful handling of failed kernel
154161 launches.
@@ -158,7 +165,20 @@ def make_launch_macros(langbb):
158165 with contextlib .suppress (NotImplementedError ):
159166 peek = langbb ['peek-error' ]
160167 success = langbb ['error-none' ]
161- return [('CHECK_LAUNCH' , f'if ({ peek ().name } () != { success } ) {{break;}}' )]
168+
169+ headers = [
170+ ('CHECK_LAUNCH' ,
171+ f'if ({ peek ().name } () != { success } ) {{break;}}' )
172+ ]
173+
174+ if options ['mpi' ]:
175+ headers .append (
176+ ('CHECK_LAUNCH_RETURN' ,
177+ f'if ({ peek ().name } () != { success } ) '
178+ f'{{return { error_mapper ["KernelLaunch" ]} ;}}' )
179+ )
180+
181+ return headers
162182
163183 return []
164184
0 commit comments