Skip to content

Commit 8a664ff

Browse files
committed
fix: add special handling for mellea global event loop when forked
1 parent cbd63bd commit 8a664ff

2 files changed

Lines changed: 40 additions & 0 deletions

File tree

mellea/helpers/event_loop_helper.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Helper for event loop management. Allows consistently running async generate requests in sync code."""
22

33
import asyncio
4+
import os
45
import threading
56
from collections.abc import Coroutine
67
from typing import Any, TypeVar
@@ -18,13 +19,21 @@ def __init__(self):
1819
1920
Do not instantiate this class. Rely on the exported `_run_async_in_thread` function.
2021
"""
22+
self._pid = os.getpid() # Store the pid incase users fork this process.
2123
self._event_loop = asyncio.new_event_loop()
2224
self._thread: threading.Thread = threading.Thread( # type: ignore[annotation-unchecked]
2325
target=self._event_loop.run_forever,
2426
daemon=True, # type: ignore
2527
)
2628
self._thread.start()
2729

30+
def _reinit_if_forked(self) -> None:
31+
"""Reinitialize the event loop and thread if we're in a forked child to prevent hanging on awaited tasks."""
32+
if os.getpid() != self._pid:
33+
# If the process has been forked, reset the event loop and thread.
34+
# Don't cleanup the parent's objects.
35+
self.__init__()
36+
2837
def __del__(self):
2938
"""Delete the event loop handler."""
3039
self._close_event_loop()
@@ -55,6 +64,7 @@ async def finalize_tasks():
5564

5665
def __call__(self, co: Coroutine[Any, Any, R]) -> R:
5766
"""Runs the coroutine in the event loop."""
67+
self._reinit_if_forked()
5868
if self._event_loop == get_current_event_loop():
5969
# If this gets called from the same event loop, launch in a separate thread to prevent blocking.
6070
return _EventLoopHandler()(co)

test/helpers/test_event_loop_helper.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import multiprocessing
2+
13
import pytest
24

35
import mellea.helpers.event_loop_helper as elh
@@ -32,6 +34,34 @@ async def testing() -> int:
3234
assert elh.__event_loop_handler is not None
3335

3436

37+
def test_event_loop_handler_with_forking():
38+
"""Importing mellea before fork must not crash the child process."""
39+
40+
ctx = multiprocessing.get_context("fork")
41+
42+
def child():
43+
import mellea.helpers.event_loop_helper as elh
44+
45+
async def hello():
46+
return 42
47+
48+
result = elh._run_async_in_thread(hello())
49+
assert result == 42
50+
51+
p = ctx.Process(target=child)
52+
53+
try:
54+
p.start()
55+
p.join(timeout=15)
56+
assert p.exitcode == 0, f"Child process failed after fork (exit code: {p.exitcode if p.exitcode is not None else 'timed out'})"
57+
58+
finally:
59+
# Make sure we always clean up the process.
60+
if p.is_alive():
61+
p.kill()
62+
p.join(timeout=15)
63+
64+
3565
if __name__ == "__main__":
3666
import pytest
3767

0 commit comments

Comments
 (0)