Skip to content

Commit 3bc8e9c

Browse files
authored
Merge pull request #3207 from blacklanternsecurity/pool-pdeathsig
Prevent process pool zombies on crash
2 parents 9cd2601 + 8e0431a commit 3bc8e9c

2 files changed

Lines changed: 97 additions & 1 deletion

File tree

bbot/core/helpers/helper.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
import os
22
import sys
3+
import signal
4+
import ctypes
5+
import ctypes.util
36
import asyncio
47
import logging
58
from pathlib import Path
@@ -24,6 +27,24 @@
2427

2528
log = logging.getLogger("bbot.core.helpers")
2629

30+
_PR_SET_PDEATHSIG = 1
31+
32+
33+
def _pool_worker_init():
34+
"""Set PR_SET_PDEATHSIG so pool workers die when the parent process dies.
35+
36+
Prevents zombie worker accumulation after OOM kills, SIGKILL, etc.
37+
Uses SIGKILL because ProcessPoolExecutor's `except BaseException` catches
38+
SIGTERM's SystemExit, keeping workers alive until the broken pipe surfaces.
39+
40+
prctl is Linux-specific, so this is a no-op elsewhere (the symbol is absent
41+
on other platforms and would otherwise raise, breaking the whole pool).
42+
"""
43+
if not sys.platform.startswith("linux"):
44+
return
45+
libc = ctypes.CDLL(ctypes.util.find_library("c"), use_errno=True)
46+
libc.prctl(_PR_SET_PDEATHSIG, signal.SIGKILL, 0, 0, 0)
47+
2748

2849
class ConfigAwareHelper:
2950
"""
@@ -220,7 +241,7 @@ def _create_process_pool():
220241
# we spawn 1 fewer processes than cores
221242
# this helps to avoid locking up the system or competing with the main python process for cpu time
222243
num_processes = max(1, mp.cpu_count() - 1)
223-
pool_kwargs = {"max_workers": num_processes}
244+
pool_kwargs = {"max_workers": num_processes, "initializer": _pool_worker_init}
224245
# max_tasks_per_child replaces workers after N tasks, preventing memory leaks
225246
# and reducing the chance of a degraded worker process causing hangs
226247
if sys.version_info >= (3, 11):

bbot/test/test_step_1/test_helpers.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
import os
2+
import sys
3+
import time
14
import asyncio
25
import datetime
36
import ipaddress
@@ -1004,6 +1007,78 @@ async def test_run_in_executor_mp(helpers):
10041007
assert result == sum(range(50_000))
10051008

10061009

1010+
@pytest.mark.skipif(not sys.platform.startswith("linux"), reason="PR_SET_PDEATHSIG is Linux-only")
1011+
def test_pool_workers_die_with_parent():
1012+
"""Pool workers must not survive when the parent is SIGKILL'd (OOM, crash, etc.)."""
1013+
import json
1014+
import signal
1015+
import subprocess
1016+
import tempfile
1017+
1018+
script = """
1019+
import os, sys, json, time, signal, ctypes, ctypes.util, multiprocessing as mp
1020+
from concurrent.futures import ProcessPoolExecutor
1021+
1022+
_PR_SET_PDEATHSIG = 1
1023+
1024+
def _init():
1025+
libc = ctypes.CDLL(ctypes.util.find_library("c"), use_errno=True)
1026+
libc.prctl(_PR_SET_PDEATHSIG, signal.SIGKILL, 0, 0, 0)
1027+
1028+
def _get_pid():
1029+
time.sleep(1)
1030+
return os.getpid()
1031+
1032+
# use fork context explicitly -- forkserver on 3.14 adds an intermediary process
1033+
# that complicates the parent-death chain; PR_SET_PDEATHSIG itself is start-method-agnostic
1034+
ctx = mp.get_context("fork")
1035+
pool = ProcessPoolExecutor(max_workers=2, initializer=_init, mp_context=ctx)
1036+
# submit concurrently so both workers are occupied (each takes 1s)
1037+
futs = [pool.submit(_get_pid) for _ in range(2)]
1038+
pids = list(set(f.result(timeout=30) for f in futs))
1039+
# keep workers busy so they stay alive
1040+
[pool.submit(time.sleep, 3600) for _ in range(2)]
1041+
print(json.dumps(pids), flush=True)
1042+
time.sleep(3600)
1043+
"""
1044+
with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f:
1045+
f.write(script)
1046+
script_path = f.name
1047+
1048+
def _is_running(pid):
1049+
"""Check /proc to distinguish running processes from zombies."""
1050+
try:
1051+
with open(f"/proc/{pid}/stat") as f:
1052+
# format: "pid (comm) state ..." -- state after the last ')'
1053+
state = f.read().split(")")[-1].strip().split()[0]
1054+
return state not in ("Z", "X", "x")
1055+
except (OSError, IndexError):
1056+
return False
1057+
1058+
try:
1059+
proc = subprocess.Popen([sys.executable, script_path], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
1060+
line = proc.stdout.readline()
1061+
assert line, f"Worker script exited early, stderr: {proc.stderr.read().decode()}"
1062+
worker_pids = json.loads(line)
1063+
assert len(worker_pids) >= 2
1064+
1065+
# simulate OOM kill
1066+
os.kill(proc.pid, signal.SIGKILL)
1067+
proc.wait()
1068+
1069+
time.sleep(2)
1070+
1071+
alive = [pid for pid in worker_pids if _is_running(pid)]
1072+
1073+
# clean up survivors so they don't leak into other tests
1074+
for pid in alive:
1075+
os.kill(pid, signal.SIGKILL)
1076+
1077+
assert not alive, f"Pool workers {alive} survived parent SIGKILL (zombie leak)"
1078+
finally:
1079+
os.unlink(script_path)
1080+
1081+
10071082
def test_simhash_similarity(helpers):
10081083
"""Test SimHash helper with increasingly different HTML pages."""
10091084

0 commit comments

Comments
 (0)