Skip to content

Commit 36bcfb7

Browse files
committed
fix(utils): Improve exec_in_proc to handle more failure modes
1 parent ad0eac7 commit 36bcfb7

1 file changed

Lines changed: 149 additions & 21 deletions

File tree

context_chat_backend/utils.py

Lines changed: 149 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import sys
1010
import traceback
1111
from collections.abc import Callable
12+
from contextlib import suppress
1213
from functools import partial, wraps
1314
from multiprocessing.connection import Connection
1415
from time import perf_counter_ns
@@ -72,31 +73,95 @@ def JSONResponse(
7273

7374

7475
class SubprocessKilledError(RuntimeError):
    """Raised when a subprocess is terminated by a signal (for example SIGKILL).

    The signal number is reported as ``abs(exitcode)``; the raw exit code is
    included in the message as well.

    Attributes:
        pid: PID of the subprocess that was passed to the constructor.
        target_name: Name of the target the subprocess was running.
        exitcode: Raw exit code of the subprocess.
    """

    def __init__(self, pid: int, target_name: str, exitcode: int):
        super().__init__(
            f'Subprocess PID {pid} for {target_name} exited with signal {abs(exitcode)} '
            f'(raw exit code: {exitcode})'
        )
        self.pid = pid
        self.target_name = target_name
        self.exitcode = exitcode

    def __reduce__(self):
        # The default exception reduction replays only ``self.args`` (the formatted
        # message) into ``__init__``, which needs three arguments. Without this
        # override the exception pickles fine but raises TypeError on unpickling,
        # e.g. when it is forwarded across a process boundary.
        return (type(self), (self.pid, self.target_name, self.exitcode))
8384

8485

86+
class SubprocessExecutionError(RuntimeError):
    """Raised when a subprocess exits non-zero without a recoverable Python exception payload.

    Attributes:
        pid: PID of the subprocess that was passed to the constructor.
        target_name: Name of the target the subprocess was running.
        exitcode: Exit code reported for the subprocess.
        details: Optional free-form description appended to the message.
    """

    def __init__(self, pid: int, target_name: str, exitcode: int, details: str = ''):
        msg = f'Subprocess PID {pid} for {target_name} exited with non-zero exit code {exitcode}'
        if details:
            msg = f'{msg}: {details}'
        super().__init__(msg)
        self.pid = pid
        self.target_name = target_name
        self.exitcode = exitcode
        self.details = details

    def __reduce__(self):
        # The default exception reduction replays only ``self.args`` (the formatted
        # message) into ``__init__``, which requires at least three arguments.
        # Rebuild from the original constructor arguments so the exception
        # survives a pickle round trip (e.g. when sent between processes).
        return (type(self), (self.pid, self.target_name, self.exitcode, self.details))
95+
96+
97+
_MAX_STD_CAPTURE_CHARS = 64 * 1024
98+
99+
100+
def _truncate_capture(text: str) -> tuple[str, bool]:
101+
if len(text) <= _MAX_STD_CAPTURE_CHARS:
102+
return text, False
103+
104+
head = _MAX_STD_CAPTURE_CHARS // 2
105+
tail = _MAX_STD_CAPTURE_CHARS - head
106+
omitted = len(text) - _MAX_STD_CAPTURE_CHARS
107+
truncated = (
108+
f'[truncated {omitted} chars]\n'
109+
f'{text[:head]}\n'
110+
'[...snip...]\n'
111+
f'{text[-tail:]}'
112+
)
113+
return truncated, True
114+
115+
85116
def _describe_exception(exc: BaseException, tb: str) -> dict:
    """Build the string-only result payload describing *exc* (no live exception object)."""
    try:
        message = str(exc)
    except Exception:  # a broken __str__ must not prevent the error from being reported
        message = '<unprintable exception message>'
    return {
        'value': None,
        'error': None,
        'traceback': tb,
        'error_type': type(exc).__name__,
        'error_module': type(exc).__module__,
        'error_message': message,
    }


def exception_wrap(fun: Callable | None, *args, resconn: Connection, stdconn: Connection, **kwargs):
    """Run *fun* and report its outcome and captured stdio over two connections.

    ``sys.stdout`` and ``sys.stderr`` are redirected to in-memory buffers while
    *fun* runs and restored afterwards.

    On success ``{'value': <return value>, 'error': None}`` is sent on
    *resconn* (``value`` is ``None`` when *fun* is ``None``). On any exception
    a payload with ``error``, ``traceback``, ``error_type``, ``error_module``
    and ``error_message`` is sent instead. If the exception object cannot make
    a pickle round trip, ``error`` is ``None`` and a ``send_error`` string is
    added so the receiver can still report what happened.

    The captured output (possibly truncated, see ``_truncate_capture``) is
    always sent on *stdconn* as a dict with ``stdout``, ``stderr``,
    ``stdout_truncated`` and ``stderr_truncated``.

    Args:
        fun: Callable to execute, or ``None`` to only report an empty result.
        *args: Positional arguments forwarded to *fun*.
        resconn: Connection that receives the result/error payload.
        stdconn: Connection that receives the captured stdout/stderr payload.
        **kwargs: Keyword arguments forwarded to *fun*.
    """
    import pickle

    stdout_capture = io.StringIO()
    stderr_capture = io.StringIO()
    orig_stdout = sys.stdout
    orig_stderr = sys.stderr
    sys.stdout = stdout_capture
    sys.stderr = stderr_capture

    try:
        if fun is None:
            resconn.send({'value': None, 'error': None})
        else:
            resconn.send({'value': fun(*args, **kwargs), 'error': None})
    except BaseException as e:
        tb = traceback.format_exc()
        description = _describe_exception(e, tb)
        try:
            # An exception can pickle successfully yet fail to unpickle (e.g. a
            # custom __init__ signature), which would surface on the receiving
            # side as an unrelated error from recv(). Verify the round trip
            # here so the string-only fallback below is used instead.
            pickle.loads(pickle.dumps(e))
            resconn.send({**description, 'error': e})
        except Exception as send_err:
            # Fallback for exceptions that cannot cross the process boundary.
            # Best effort: if even this send fails there is nothing left to try.
            with suppress(Exception):
                resconn.send({**description, 'send_error': str(send_err)})
    finally:
        sys.stdout = orig_stdout
        sys.stderr = orig_stderr
        stdout_text, stdout_truncated = _truncate_capture(stdout_capture.getvalue())
        stderr_text, stderr_truncated = _truncate_capture(stderr_capture.getvalue())
        # Best effort: a failure to ship the captured output must not mask the result.
        with suppress(Exception):
            stdconn.send({
                'stdout': stdout_text,
                'stderr': stderr_text,
                'stdout_truncated': stdout_truncated,
                'stderr_truncated': stderr_truncated,
            })
100165

101166

102167
def exec_in_proc(group=None, target=None, name=None, args=(), kwargs={}, *, daemon=None): # noqa: B006
@@ -117,30 +182,93 @@ def exec_in_proc(group=None, target=None, name=None, args=(), kwargs={}, *, daem
117182
start = perf_counter_ns()
118183
p.start()
119184
_logger.debug('Subprocess PID %d started for %s, waiting for it to finish (no timeout)', p.pid, target_name)
185+
186+
result = None
187+
stdobj = {
188+
'stdout': '',
189+
'stderr': '',
190+
'stdout_truncated': False,
191+
'stderr_truncated': False,
192+
}
193+
got_result = False
194+
got_std = False
195+
196+
# Drain result/std pipes while child is still alive to avoid deadlock on full pipe buffers.
197+
while p.is_alive() and (not got_result or not got_std):
198+
if not got_result and pconn.poll(0.1):
199+
with suppress(EOFError, OSError, BrokenPipeError):
200+
result = pconn.recv()
201+
got_result = True
202+
if not got_std and std_pconn.poll():
203+
with suppress(EOFError, OSError, BrokenPipeError):
204+
stdobj = std_pconn.recv()
205+
got_std = True
206+
120207
p.join()
121208
elapsed_ms = (perf_counter_ns() - start) / 1e6
122209
_logger.debug(
123210
'Subprocess PID %d for %s finished in %.2f ms (exit code: %s)',
124211
p.pid, target_name, elapsed_ms, p.exitcode,
125212
)
126-
stdobj = std_pconn.recv()
127-
_logger.info(f'std info for {target_name}', extra={
128-
'stdout': stdobj['stdout'],
129-
'stderr': stdobj['stderr'],
130-
})
131-
132-
result = pconn.recv()
133-
if result['error'] is not None:
134-
_logger.error('original traceback: %s', result['traceback'])
213+
214+
if not got_std:
215+
with suppress(EOFError, OSError, BrokenPipeError):
216+
if std_pconn.poll():
217+
stdobj = std_pconn.recv()
218+
got_std = True
219+
if stdobj['stdout'] or stdobj['stderr']:
220+
extra = {
221+
'stdout': stdobj['stdout'],
222+
'stderr': stdobj['stderr'],
223+
}
224+
if stdobj.get('stdout_truncated') or stdobj.get('stderr_truncated'):
225+
extra['stdio_truncated'] = {
226+
'stdout': bool(stdobj.get('stdout_truncated')),
227+
'stderr': bool(stdobj.get('stderr_truncated')),
228+
}
229+
_logger.info('std info for %s', target_name, extra=extra)
230+
231+
if not got_result:
232+
with suppress(EOFError, OSError, BrokenPipeError):
233+
if pconn.poll():
234+
result = pconn.recv()
235+
got_result = True
236+
237+
if result is not None and result.get('error') is not None:
238+
_logger.error('original traceback: %s', result.get('traceback', ''))
135239
raise result['error']
136240

137-
if p.exitcode != 0:
241+
if result is not None and result.get('error_type'):
242+
details = (
243+
f"{result.get('error_module', '')}.{result.get('error_type', '')}: "
244+
f"{result.get('error_message', '')}"
245+
)
246+
if result.get('traceback'):
247+
_logger.error('remote traceback: %s', result['traceback'])
248+
raise SubprocessExecutionError(p.pid or 0, target_name, p.exitcode or 1, details)
249+
250+
if p.exitcode and p.exitcode < 0:
138251
_logger.warning(
139-
'Subprocess PID %d for %s exited with non-zero exit code %d after %.2f ms'
140-
' — possible OOM kill or unhandled signal',
141-
p.pid, target_name, p.exitcode, elapsed_ms,
252+
'Subprocess PID %d for %s exited due to signal %d after %.2f ms',
253+
p.pid, target_name, abs(p.exitcode), elapsed_ms,
254+
)
255+
raise SubprocessKilledError(p.pid or 0, target_name, p.exitcode)
256+
257+
if p.exitcode not in (None, 0):
258+
raise SubprocessExecutionError(
259+
p.pid or 0,
260+
target_name,
261+
p.exitcode,
262+
'No structured exception payload received from child process',
263+
)
264+
265+
if result is None:
266+
raise SubprocessExecutionError(
267+
p.pid or 0,
268+
target_name,
269+
0,
270+
'Subprocess exited successfully but returned no result payload',
142271
)
143-
raise SubprocessKilledError(p.pid or 0, target_name, p.exitcode or -1)
144272

145273
return result['value']
146274

0 commit comments

Comments
 (0)