From cbd91e1953cf747423e0b7db819657a3d0a94b86 Mon Sep 17 00:00:00 2001 From: Jake LoRocco Date: Wed, 11 Mar 2026 13:50:32 -0400 Subject: [PATCH] fix: add special handling for mellea global event loop when forked --- mellea/helpers/event_loop_helper.py | 26 +++++++++++++++++---- test/helpers/test_event_loop_helper.py | 32 ++++++++++++++++++++++++++ 2 files changed, 53 insertions(+), 5 deletions(-) diff --git a/mellea/helpers/event_loop_helper.py b/mellea/helpers/event_loop_helper.py index 749be4d02..d40f95df5 100644 --- a/mellea/helpers/event_loop_helper.py +++ b/mellea/helpers/event_loop_helper.py @@ -1,6 +1,7 @@ """Helper for event loop management. Allows consistently running async generate requests in sync code.""" import asyncio +import os import threading from collections.abc import Coroutine from typing import Any, TypeVar @@ -13,11 +14,11 @@ class _EventLoopHandler: """A class that handles the event loop for Mellea code. Do not directly instantiate this. Use `_run_async_in_thread`.""" - def __init__(self): - """Instantiates an EventLoopHandler. Used to ensure consistency when calling async code from sync code in Mellea. - - Do not instantiate this class. Rely on the exported `_run_async_in_thread` function. - """ + def _event_loop_setup(self): + """Sets up the event loop and thread.""" + # This code lives in a helper function since both __init__ and _reinit_if_forked + # will need to use it. + self._pid = os.getpid() # Store the pid in case users fork this process. self._event_loop = asyncio.new_event_loop() self._thread: threading.Thread = threading.Thread( # type: ignore[annotation-unchecked] target=self._event_loop.run_forever, @@ -25,6 +26,20 @@ def __init__(self): ) self._thread.start() + def __init__(self): + """Instantiates an EventLoopHandler. Used to ensure consistency when calling async code from sync code in Mellea. + + Do not instantiate this class. Rely on the exported `_run_async_in_thread` function. + """ + self._event_loop_setup() + + def _reinit_if_forked(self) -> None: + """Reinitialize the event loop and thread if we're in a forked child to prevent hanging on awaited tasks.""" + if os.getpid() != self._pid: + # If the process has been forked, reset the event loop and thread. + # Don't cleanup the parent's objects. + self._event_loop_setup() + def __del__(self): """Delete the event loop handler.""" self._close_event_loop() @@ -55,6 +70,7 @@ async def finalize_tasks(): def __call__(self, co: Coroutine[Any, Any, R]) -> R: """Runs the coroutine in the event loop.""" + self._reinit_if_forked() if self._event_loop == get_current_event_loop(): # If this gets called from the same event loop, launch in a separate thread to prevent blocking. return _EventLoopHandler()(co) diff --git a/test/helpers/test_event_loop_helper.py b/test/helpers/test_event_loop_helper.py index e8469ece4..28d234282 100644 --- a/test/helpers/test_event_loop_helper.py +++ b/test/helpers/test_event_loop_helper.py @@ -1,3 +1,5 @@ +import multiprocessing + import pytest import mellea.helpers.event_loop_helper as elh @@ -32,6 +34,36 @@ async def testing() -> int: assert elh.__event_loop_handler is not None +def test_event_loop_handler_with_forking(): + """Importing mellea before fork must not crash the child process.""" + + ctx = multiprocessing.get_context("fork") + + def child(): + import mellea.helpers.event_loop_helper as elh + + async def hello(): + return 42 + + result = elh._run_async_in_thread(hello()) + assert result == 42 + + p = ctx.Process(target=child) + + try: + p.start() + p.join(timeout=15) + assert p.exitcode == 0, ( + f"Child process failed after fork (exit code: {p.exitcode if p.exitcode is not None else 'timed out'})" + ) + + finally: + # Make sure we always clean up the process. + if p.is_alive(): + p.kill() + p.join(timeout=15) + + if __name__ == "__main__": import pytest