Skip to content

Commit 27011b6

Browse files
committed
compiler: Improve kernel error check with MPI
1 parent 95a39e7 commit 27011b6

2 files changed

Lines changed: 38 additions & 13 deletions

File tree

devito/operator/operator.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -706,9 +706,14 @@ def _prepare_arguments(self, autotune=None, estimate_memory=False, **kwargs):
706706

707707
return args
708708

709-
def _postprocess_errors(self, retval):
709+
def _postprocess_errors(self, retval, comm=None):
710710
if retval == 0:
711711
return
712+
elif comm and comm is not MPI.COMM_NULL:
713+
# A rank-local Python exception can leave peer ranks blocked inside
714+
# generated MPI waits/exchanges. Abort the communicator instead of
715+
# risking an indefinite hang.
716+
comm.Abort(retval)
712717
elif retval == error_mapper['Stability']:
713718
raise ExecutionError("Detected nan/inf in some output Functions")
714719
elif retval == error_mapper['KernelLaunch']:
@@ -1008,7 +1013,7 @@ def apply(self, **kwargs):
10081013

10091014
with self._profiler.timer_on('arguments-postprocess'):
10101015
# Perform error checking
1011-
self._postprocess_errors(retval)
1016+
self._postprocess_errors(retval, comm=args.comm)
10121017
# Post-process runtime arguments
10131018
self._postprocess_arguments(args, **kwargs)
10141019

devito/passes/iet/errors.py

Lines changed: 31 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
Break, Call, Conditional, DummyExpr, EntryFunction, FindNodes, FindSymbols, Iteration,
99
KernelLaunch, List, Return, Transformer, make_callable
1010
)
11+
from devito.mpi.distributed import MPICommObject
1112
from devito.passes.iet.engine import iet_pass
1213
from devito.symbolics import CondEq, MathFunction
1314
from devito.tools import dtype_to_ctype
@@ -104,28 +105,34 @@ def _check_stability(iet, wmovs=(), rcompile=None, sregistry=None):
104105

105106
def check_launch(graph, options=None, **kwargs):
106107
"""
107-
Insert the CHECK_LAUNCH macro if errctl is set to ensure graceful handling of
108-
failed kernel launches. This macro should only be inserted if the kernel is
109-
directly within a loop, as compilation will fail otherwise.
108+
Insert the CHECK_LAUNCH* macros if errctl is set to ensure graceful
109+
handling of failed kernel launches. This macro is only inserted if the
110+
kernel launch is directly within a loop, as compilation would fail
111+
otherwise.
110112
"""
111113
if options is None or not options.get('errctl', False):
112114
return
113115

114116
langbb = kwargs['langbb']
115117

116-
definition = make_launch_macros(langbb)
118+
definition = make_launch_macros(langbb, options=options)
117119
if not definition:
118120
return
119121

120-
macro = [langbb['check-launch']]
121-
122-
_check_launch(graph, definition=definition, macro=macro, **kwargs)
122+
_check_launch(graph, definition=definition, options=options, **kwargs)
123123

124124

125125
@iet_pass
126-
def _check_launch(iet, definition=None, macro=None, **kwargs):
126+
def _check_launch(iet, definition=None, options=None, **kwargs):
127127
iterations = FindNodes(Iteration).visit(iet)
128128

129+
if options['mpi'] and \
130+
isinstance(iet, EntryFunction) and \
131+
any(isinstance(i, MPICommObject) for i in iet.parameters):
132+
check = [List(body=c.Statement('CHECK_LAUNCH_RETURN'))]
133+
else:
134+
check = [List(body=c.Statement('CHECK_LAUNCH'))]
135+
129136
mapper = {}
130137
for i in iterations:
131138
# Two stages of substitution to account for the edge case
@@ -135,7 +142,7 @@ def _check_launch(iet, definition=None, macro=None, **kwargs):
135142
launches = FindNodes(KernelLaunch).visit(i)
136143

137144
for launch in launches:
138-
launch_mapper[launch] = List(body=[launch] + macro)
145+
launch_mapper[launch] = List(body=[launch] + check)
139146

140147
if launch_mapper:
141148
mapper[i] = Transformer(launch_mapper).visit(i)
@@ -148,7 +155,7 @@ def _check_launch(iet, definition=None, macro=None, **kwargs):
148155
return iet, extras
149156

150157

151-
def make_launch_macros(langbb):
158+
def make_launch_macros(langbb, options=None):
152159
"""
153160
Define macros to check for errors to ensure graceful handling of failed kernel
154161
launches.
@@ -158,7 +165,20 @@ def make_launch_macros(langbb):
158165
with contextlib.suppress(NotImplementedError):
159166
peek = langbb['peek-error']
160167
success = langbb['error-none']
161-
return [('CHECK_LAUNCH', f'if ({peek().name}() != {success}) {{break;}}')]
168+
169+
headers = [
170+
('CHECK_LAUNCH',
171+
f'if ({peek().name}() != {success}) {{break;}}')
172+
]
173+
174+
if options['mpi']:
175+
headers.append(
176+
('CHECK_LAUNCH_RETURN',
177+
f'if ({peek().name}() != {success}) '
178+
f'{{return {error_mapper["KernelLaunch"]};}}')
179+
)
180+
181+
return headers
162182

163183
return []
164184

0 commit comments

Comments
 (0)