Skip to content
23 changes: 22 additions & 1 deletion bbot/core/helpers/helper.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import os
import sys
import signal
import ctypes
import ctypes.util
import asyncio
import logging
from pathlib import Path
Expand All @@ -24,6 +27,24 @@

log = logging.getLogger("bbot.core.helpers")

_PR_SET_PDEATHSIG = 1


def _pool_worker_init():
    """Set PR_SET_PDEATHSIG so pool workers die when the parent process dies.

    Prevents zombie worker accumulation after OOM kills, SIGKILL, etc.
    Uses SIGKILL because ProcessPoolExecutor's `except BaseException` catches
    SIGTERM's SystemExit, keeping workers alive until the broken pipe surfaces.

    prctl is Linux-specific, so this is a no-op elsewhere (the symbol is absent
    on other platforms and would otherwise raise, breaking the whole pool).

    This is strictly best-effort: an exception escaping a pool initializer marks
    the whole pool as broken, so a failure to load libc or to call prctl is
    logged at debug level and the worker simply runs without the death signal.

    Returns:
        None
    """
    if not sys.platform.startswith("linux"):
        return
    try:
        libc = ctypes.CDLL(ctypes.util.find_library("c"), use_errno=True)
        # prctl returns -1 and sets errno on failure; use_errno=True lets us read it back
        if libc.prctl(_PR_SET_PDEATHSIG, signal.SIGKILL, 0, 0, 0) != 0:
            err = ctypes.get_errno()
            log.debug(f"prctl(PR_SET_PDEATHSIG) failed in pool worker: [errno {err}] {os.strerror(err)}")
    except (OSError, AttributeError) as e:
        # OSError: libc could not be loaded; AttributeError: the loaded library has no prctl symbol
        log.debug(f"Unable to set PR_SET_PDEATHSIG in pool worker: {e}")


class ConfigAwareHelper:
"""
Expand Down Expand Up @@ -220,7 +241,7 @@ def _create_process_pool():
# we spawn 1 fewer processes than cores
# this helps to avoid locking up the system or competing with the main python process for cpu time
num_processes = max(1, mp.cpu_count() - 1)
pool_kwargs = {"max_workers": num_processes}
pool_kwargs = {"max_workers": num_processes, "initializer": _pool_worker_init}
# max_tasks_per_child replaces workers after N tasks, preventing memory leaks
# and reducing the chance of a degraded worker process causing hangs
if sys.version_info >= (3, 11):
Expand Down
75 changes: 75 additions & 0 deletions bbot/test/test_step_1/test_helpers.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import os
import sys
import time
import asyncio
import datetime
import ipaddress
Expand Down Expand Up @@ -1004,6 +1007,78 @@ async def test_run_in_executor_mp(helpers):
assert result == sum(range(50_000))


@pytest.mark.skipif(not sys.platform.startswith("linux"), reason="PR_SET_PDEATHSIG is Linux-only")
def test_pool_workers_die_with_parent():
    """Pool workers must not survive when the parent is SIGKILL'd (OOM, crash, etc.).

    A helper script is launched in a subprocess. It starts a two-worker
    ProcessPoolExecutor whose initializer sets PR_SET_PDEATHSIG, prints the
    worker PIDs as JSON and then sleeps. The test SIGKILLs that script and
    verifies via /proc that the workers go away shortly afterwards.

    Cleanup is unconditional: whatever assertion fails, the helper script, any
    surviving workers, the pipes and the temporary script file are all released.
    """
    import json
    import signal
    import subprocess
    import tempfile

    script = """
import os, sys, json, time, signal, ctypes, ctypes.util, multiprocessing as mp
from concurrent.futures import ProcessPoolExecutor

_PR_SET_PDEATHSIG = 1

def _init():
    libc = ctypes.CDLL(ctypes.util.find_library("c"), use_errno=True)
    libc.prctl(_PR_SET_PDEATHSIG, signal.SIGKILL, 0, 0, 0)

def _get_pid():
    time.sleep(1)
    return os.getpid()

# use fork context explicitly -- forkserver on 3.14 adds an intermediary process
# that complicates the parent-death chain; PR_SET_PDEATHSIG itself is start-method-agnostic
ctx = mp.get_context("fork")
pool = ProcessPoolExecutor(max_workers=2, initializer=_init, mp_context=ctx)
# submit concurrently so both workers are occupied (each takes 1s)
futs = [pool.submit(_get_pid) for _ in range(2)]
pids = list(set(f.result(timeout=30) for f in futs))
# keep workers busy so they stay alive
[pool.submit(time.sleep, 3600) for _ in range(2)]
print(json.dumps(pids), flush=True)
time.sleep(3600)
"""
    with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f:
        f.write(script)
        script_path = f.name

    def _is_running(pid):
        """Check /proc to distinguish running processes from zombies."""
        try:
            with open(f"/proc/{pid}/stat") as stat_file:
                # format: "pid (comm) state ..." -- state after the last ')'
                state = stat_file.read().split(")")[-1].strip().split()[0]
            return state not in ("Z", "X", "x")
        except (OSError, IndexError):
            return False

    proc = None
    worker_pids = []
    try:
        proc = subprocess.Popen([sys.executable, script_path], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        line = proc.stdout.readline()
        assert line, f"Worker script exited early, stderr: {proc.stderr.read().decode()}"
        worker_pids = json.loads(line)
        assert len(worker_pids) >= 2

        # simulate OOM kill
        os.kill(proc.pid, signal.SIGKILL)
        proc.wait()

        # poll instead of sleeping a fixed interval: returns as soon as the workers
        # are gone, while still tolerating a slow machine
        deadline = time.monotonic() + 5
        alive = [pid for pid in worker_pids if _is_running(pid)]
        while alive and time.monotonic() < deadline:
            time.sleep(0.05)
            alive = [pid for pid in alive if _is_running(pid)]

        assert not alive, f"Pool workers {alive} survived parent SIGKILL (zombie leak)"
    finally:
        if proc is not None:
            # an assertion may have fired before the SIGKILL above; never leave
            # the helper script sleeping in the background
            if proc.poll() is None:
                proc.kill()
            proc.wait()
            proc.stdout.close()
            proc.stderr.close()
        # clean up survivors so they don't leak into other tests
        for pid in worker_pids:
            if _is_running(pid):
                try:
                    os.kill(pid, signal.SIGKILL)
                except ProcessLookupError:
                    # worker exited between the /proc check and the kill
                    pass
        os.unlink(script_path)


def test_simhash_similarity(helpers):
"""Test SimHash helper with increasingly different HTML pages."""

Expand Down