Skip to content

Commit cbd91e1

Browse files
committed
fix: add special handling for mellea global event loop when forked
1 parent cbd63bd commit cbd91e1

2 files changed

Lines changed: 53 additions & 5 deletions

File tree

mellea/helpers/event_loop_helper.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Helper for event loop management. Allows consistently running async generate requests in sync code."""
22

33
import asyncio
4+
import os
45
import threading
56
from collections.abc import Coroutine
67
from typing import Any, TypeVar
@@ -13,18 +14,32 @@
1314
class _EventLoopHandler:
1415
"""A class that handles the event loop for Mellea code. Do not directly instantiate this. Use `_run_async_in_thread`."""
1516

16-
def __init__(self):
17-
"""Instantiates an EventLoopHandler. Used to ensure consistency when calling async code from sync code in Mellea.
18-
19-
Do not instantiate this class. Rely on the exported `_run_async_in_thread` function.
20-
"""
17+
def _event_loop_setup(self):
18+
"""Sets up the event loop and thread."""
19+
# This code lives in a helper function since both __init__ and _reinit_if_forked
20+
# will need to use it.
21+
self._pid = os.getpid() # Store the pid in case users fork this process.
2122
self._event_loop = asyncio.new_event_loop()
2223
self._thread: threading.Thread = threading.Thread( # type: ignore[annotation-unchecked]
2324
target=self._event_loop.run_forever,
2425
daemon=True, # type: ignore
2526
)
2627
self._thread.start()
2728

29+
def __init__(self):
30+
"""Instantiates an EventLoopHandler. Used to ensure consistency when calling async code from sync code in Mellea.
31+
32+
Do not instantiate this class. Rely on the exported `_run_async_in_thread` function.
33+
"""
34+
self._event_loop_setup()
35+
36+
def _reinit_if_forked(self) -> None:
37+
"""Reinitialize the event loop and thread if we're in a forked child to prevent hanging on awaited tasks."""
38+
if os.getpid() != self._pid:
39+
# If the process has been forked, reset the event loop and thread.
40+
# Don't cleanup the parent's objects.
41+
self._event_loop_setup()
42+
2843
def __del__(self):
2944
"""Delete the event loop handler."""
3045
self._close_event_loop()
@@ -55,6 +70,7 @@ async def finalize_tasks():
5570

5671
def __call__(self, co: Coroutine[Any, Any, R]) -> R:
5772
"""Runs the coroutine in the event loop."""
73+
self._reinit_if_forked()
5874
if self._event_loop == get_current_event_loop():
5975
# If this gets called from the same event loop, launch in a separate thread to prevent blocking.
6076
return _EventLoopHandler()(co)

test/helpers/test_event_loop_helper.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import multiprocessing
2+
13
import pytest
24

35
import mellea.helpers.event_loop_helper as elh
@@ -32,6 +34,36 @@ async def testing() -> int:
3234
assert elh.__event_loop_handler is not None
3335

3436

37+
def test_event_loop_handler_with_forking():
38+
"""Importing mellea before fork must not crash the child process."""
39+
40+
ctx = multiprocessing.get_context("fork")
41+
42+
def child():
43+
import mellea.helpers.event_loop_helper as elh
44+
45+
async def hello():
46+
return 42
47+
48+
result = elh._run_async_in_thread(hello())
49+
assert result == 42
50+
51+
p = ctx.Process(target=child)
52+
53+
try:
54+
p.start()
55+
p.join(timeout=15)
56+
assert p.exitcode == 0, (
57+
f"Child process failed after fork (exit code: {p.exitcode if p.exitcode is not None else 'timed out'})"
58+
)
59+
60+
finally:
61+
# Make sure we always clean up the process.
62+
if p.is_alive():
63+
p.kill()
64+
p.join(timeout=15)
65+
66+
3567
if __name__ == "__main__":
3668
import pytest
3769

0 commit comments

Comments
 (0)