|
1 | 1 | """Helper for event loop management. Allows consistently running async generate requests in sync code.""" |
2 | 2 |
|
3 | 3 | import asyncio |
| 4 | +import os |
4 | 5 | import threading |
5 | 6 | from collections.abc import Coroutine |
6 | 7 | from typing import Any, TypeVar |
|
13 | 14 | class _EventLoopHandler: |
14 | 15 | """A class that handles the event loop for Mellea code. Do not directly instantiate this. Use `_run_async_in_thread`.""" |
15 | 16 |
|
16 | | - def __init__(self): |
17 | | - """Instantiates an EventLoopHandler. Used to ensure consistency when calling async code from sync code in Mellea. |
18 | | -
|
19 | | - Do not instantiate this class. Rely on the exported `_run_async_in_thread` function. |
20 | | - """ |
| 17 | + def _event_loop_setup(self): |
| 18 | + """Sets up the event loop and thread.""" |
| 19 | + # This code lives in a helper function since both __init__ and _reinit_if_forked |
| 20 | + # will need to use it. |
| 21 | + self._pid = os.getpid() # Store the pid in case users fork this process. |
21 | 22 | self._event_loop = asyncio.new_event_loop() |
22 | 23 | self._thread: threading.Thread = threading.Thread( # type: ignore[annotation-unchecked] |
23 | 24 | target=self._event_loop.run_forever, |
24 | 25 | daemon=True, # type: ignore |
25 | 26 | ) |
26 | 27 | self._thread.start() |
27 | 28 |
|
| 29 | + def __init__(self): |
| 30 | + """Instantiates an EventLoopHandler. Used to ensure consistency when calling async code from sync code in Mellea. |
| 31 | +
|
| 32 | + Do not instantiate this class. Rely on the exported `_run_async_in_thread` function. |
| 33 | + """ |
| 34 | + self._event_loop_setup() |
| 35 | + |
| 36 | + def _reinit_if_forked(self) -> None: |
| 37 | + """Reinitialize the event loop and thread if we're in a forked child to prevent hanging on awaited tasks.""" |
| 38 | + if os.getpid() != self._pid: |
| 39 | + # If the process has been forked, reset the event loop and thread. |
| 40 | + # Don't cleanup the parent's objects. |
| 41 | + self._event_loop_setup() |
| 42 | + |
28 | 43 | def __del__(self): |
29 | 44 | """Delete the event loop handler.""" |
30 | 45 | self._close_event_loop() |
@@ -55,6 +70,7 @@ async def finalize_tasks(): |
55 | 70 |
|
56 | 71 | def __call__(self, co: Coroutine[Any, Any, R]) -> R: |
57 | 72 | """Runs the coroutine in the event loop.""" |
| 73 | + self._reinit_if_forked() |
58 | 74 | if self._event_loop == get_current_event_loop(): |
59 | 75 | # If this gets called from the same event loop, launch in a separate thread to prevent blocking. |
60 | 76 | return _EventLoopHandler()(co) |
|
0 commit comments