Skip to content

Commit 3d9032a

Browse files
authored
chore(asyncio): hardening pass — utilities, deprecated API cleanup, lifespan fix (#9552)
1 parent 73ec9d2 commit 3d9032a

20 files changed

Lines changed: 465 additions & 122 deletions

marimo/_server/ai/tools/tool_manager.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -306,11 +306,7 @@ async def _call_handler(
306306

307307
if inspect.iscoroutinefunction(handler):
308308
return await handler(arguments)
309-
else:
310-
# Run sync function in thread pool to avoid blocking
311-
return await asyncio.get_event_loop().run_in_executor(
312-
None, handler, arguments
313-
)
309+
return await asyncio.to_thread(handler, arguments)
314310

315311
async def _invoke_mcp_tool(
316312
self, tool_name: str, arguments: FunctionArgs

marimo/_server/api/endpoints/execution.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
from __future__ import annotations
33

44
import asyncio
5-
import contextlib
65
from typing import TYPE_CHECKING
76
from uuid import uuid4
87

@@ -33,6 +32,7 @@
3332
from marimo._server.uvicorn_utils import close_uvicorn
3433
from marimo._server.workspace import MarimoFileKey
3534
from marimo._types.ids import ConsumerId
35+
from marimo._utils.asyncio_utils import cancel_and_wait
3636

3737
if TYPE_CHECKING:
3838
from collections.abc import AsyncGenerator
@@ -335,9 +335,7 @@ async def sse_generator() -> AsyncGenerator[str, None]:
335335

336336
yield build_done_event(session, listener)
337337
finally:
338-
disconnect_task.cancel()
339-
with contextlib.suppress(asyncio.CancelledError):
340-
await disconnect_task
338+
await cancel_and_wait(disconnect_task)
341339

342340
return StreamingResponse(sse_generator(), media_type="text/event-stream")
343341

marimo/_server/api/endpoints/ws_endpoint.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -369,7 +369,7 @@ def _close() -> None:
369369
self.manager.close_session(self.params.session_id)
370370

371371
if session is not None:
372-
cancellation_handle = asyncio.get_event_loop().call_later(
372+
cancellation_handle = asyncio.get_running_loop().call_later(
373373
session.ttl_seconds, _close
374374
)
375375
self.cancel_close_handle = cancellation_handle

marimo/_server/api/interrupt.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ class InterruptHandler:
2121
def __init__(self, quiet: bool, shutdown: Callable[[], None]) -> None:
2222
self.quiet = quiet
2323
self.shutdown = shutdown
24-
self.loop = asyncio.get_event_loop()
24+
self.loop = asyncio.get_running_loop()
2525
self.original_handler = signal.getsignal(signal.SIGINT)
2626
self._time_of_last_confirmation: float | None = None
2727

marimo/_server/api/lifespans.py

Lines changed: 26 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from marimo._server.uvicorn_utils import close_uvicorn
2828
from marimo._server.workspace import NEW_FILE
2929
from marimo._session.model import SessionMode
30+
from marimo._utils.asyncio_utils import cancel_and_wait, supervised_task
3031
from marimo._utils.subprocess import cancel_pending_reaps
3132

3233
if TYPE_CHECKING:
@@ -57,18 +58,15 @@ async def lsp(app: Starlette) -> AsyncIterator[None]:
5758

5859
LOGGER.debug("Language Servers are enabled")
5960
# Start LSP server in background to avoid blocking server startup
60-
task = asyncio.create_task(session_mgr.start_lsp_server())
61-
background_tasks.add(task) # Keep a reference to prevent GC
62-
task.add_done_callback(background_tasks.discard) # Clean up when done
61+
task = supervised_task(
62+
session_mgr.start_lsp_server(),
63+
name="lsp.start",
64+
registry=background_tasks,
65+
)
6366

6467
yield
6568

66-
# Shutdown
67-
task.cancel()
68-
try:
69-
await task
70-
except asyncio.CancelledError:
71-
pass
69+
await cancel_and_wait(task)
7270

7371

7472
@contextlib.asynccontextmanager
@@ -119,24 +117,30 @@ async def background_connect_mcp_servers() -> MCPClient | None:
119117
LOGGER.warning(f"Failed to connect MCP servers: {e}")
120118
return None
121119

122-
task = asyncio.create_task(background_connect_mcp_servers())
123-
background_tasks.add(task) # Keep a reference to prevent GC
124-
task.add_done_callback(background_tasks.discard) # Clean up when done
120+
# Awaited below — opt out of supervisor logging to avoid duplicate logs.
121+
task = supervised_task(
122+
background_connect_mcp_servers(),
123+
name="mcp.connect",
124+
registry=background_tasks,
125+
on_exception=lambda _exc: None,
126+
)
125127

126128
yield
127129

128-
# Shutdown
129-
task.cancel()
130+
await cancel_and_wait(task)
131+
if task.cancelled():
132+
return
133+
134+
mcp_client = task.result()
135+
if not mcp_client:
136+
return
137+
130138
try:
131-
mcp_client = await task
132-
if mcp_client:
133-
LOGGER.info("Disconnecting from all MCP servers")
134-
await mcp_client.disconnect_from_all_servers()
135-
LOGGER.info("Successfully disconnected from all MCP servers")
136-
except asyncio.CancelledError:
137-
pass
139+
LOGGER.info("Disconnecting from all MCP servers")
140+
await mcp_client.disconnect_from_all_servers()
141+
LOGGER.info("Successfully disconnected from all MCP servers")
138142
except Exception as e:
139-
LOGGER.error(f"Error during MCP cleanup: {e}")
143+
LOGGER.error(f"Error during MCP disconnect: {e}")
140144

141145

142146
@contextlib.asynccontextmanager

marimo/_server/api/middleware.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -343,7 +343,7 @@ async def send(
343343
self, request: _URLRequest, stream: bool = False, max_retries: int = 2
344344
) -> _AsyncHTTPResponse:
345345
del stream
346-
loop = asyncio.get_event_loop()
346+
loop = asyncio.get_running_loop()
347347

348348
body = await self._collect_body(request)
349349

marimo/_server/export/exporter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -746,7 +746,7 @@ async def _save_file(
746746
filepath = export_dir / download_name
747747

748748
# Run blocking file I/O in thread pool
749-
loop = asyncio.get_event_loop()
749+
loop = asyncio.get_running_loop()
750750
await loop.run_in_executor(
751751
self._executor, self._write_file_sync, filepath, content
752752
)

marimo/_server/session_manager.py

Lines changed: 14 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,8 @@
88

99
from __future__ import annotations
1010

11-
import asyncio
1211
from pathlib import Path
13-
from typing import TYPE_CHECKING, TypeVar
12+
from typing import TYPE_CHECKING
1413

1514
from marimo import _loggers
1615
from marimo._cli.sandbox import SandboxMode
@@ -48,10 +47,11 @@
4847
from marimo._session.session_repository import SessionRepository
4948
from marimo._session.types import KernelState
5049
from marimo._types.ids import ConsumerId, SessionId
50+
from marimo._utils.asyncio_utils import fire_and_forget
5151
from marimo._utils.file_watcher import FileWatcherManager
5252

5353
if TYPE_CHECKING:
54-
from collections.abc import Awaitable, Coroutine, Mapping
54+
from collections.abc import Mapping
5555

5656
from marimo._session.notebook import AppFileManager
5757

@@ -241,7 +241,10 @@ def create_session(
241241
self._repository.add_sync(session_id, session)
242242

243243
# Emit session created event (triggers file watcher attachment, recents, etc.)
244-
run_async(self._event_bus.emit_session_created(session))
244+
fire_and_forget(
245+
self._event_bus.emit_session_created(session),
246+
name="session.created",
247+
)
245248

246249
return session
247250

@@ -333,10 +336,11 @@ def maybe_resume_session(
333336
if resumed_session:
334337
# Emit resume event (use new_session_id as both old and new since
335338
# the strategy already updated it)
336-
run_async(
339+
fire_and_forget(
337340
self._event_bus.emit_session_resumed(
338341
resumed_session, new_session_id
339-
)
342+
),
343+
name="session.resumed",
340344
)
341345

342346
return resumed_session
@@ -389,7 +393,10 @@ def close_session(self, session_id: SessionId) -> bool:
389393
if session is None:
390394
return False
391395

392-
run_async(self._event_bus.emit_session_closed(session))
396+
fire_and_forget(
397+
self._event_bus.emit_session_closed(session),
398+
name="session.closed",
399+
)
393400

394401
session.close()
395402
return True
@@ -418,28 +425,3 @@ def should_send_code_to_frontend(self) -> bool:
418425
def get_active_connection_count(self) -> int:
419426
"""Get the number of sessions with active connections."""
420427
return len(self._repository.get_active_sessions())
421-
422-
423-
T = TypeVar("T")
424-
425-
426-
def run_async(coro: Coroutine[None, None, T] | Awaitable[T]) -> T:
427-
"""Run an async coroutine, handling various event loop states.
428-
429-
1. Event loop is running: create a task
430-
2. Event loop exists but not running: run_until_complete
431-
3. No event loop: create one with asyncio.run
432-
"""
433-
try:
434-
loop = asyncio.get_event_loop()
435-
if loop.is_running():
436-
# Create a task and return it (fire and forget)
437-
# Note: This doesn't wait for completion
438-
task = asyncio.create_task(coro) # type: ignore
439-
return task # type: ignore
440-
else:
441-
# Run to completion
442-
return loop.run_until_complete(coro) # type: ignore
443-
except RuntimeError:
444-
# No event loop exists, create one
445-
return asyncio.run(coro) # type: ignore

marimo/_server/utils.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,20 @@
22
from __future__ import annotations
33

44
import asyncio
5-
import sys
65
from typing import TYPE_CHECKING, Any, TypeVar
76

7+
from marimo._utils.asyncio_utils import initialize_asyncio
8+
89
if TYPE_CHECKING:
910
from collections.abc import Coroutine
1011

12+
__all__ = [
13+
"asyncio_run",
14+
"initialize_asyncio",
15+
"initialize_fd_limit",
16+
"initialize_mimetypes",
17+
]
18+
1119

1220
def initialize_mimetypes() -> None:
1321
import mimetypes
@@ -34,16 +42,6 @@ def initialize_mimetypes() -> None:
3442
)
3543

3644

37-
def initialize_asyncio() -> None:
38-
"""Platform-specific initialization of asyncio.
39-
40-
Sessions use the `add_reader()` API, which is only available in the
41-
SelectorEventLoop policy; Windows uses the Proactor by default.
42-
"""
43-
if sys.platform == "win32":
44-
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
45-
46-
4745
def initialize_fd_limit(limit: int) -> None:
4846
"""Raise the limit on open file descriptors.
4947

marimo/_session/extensions/extensions.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -144,11 +144,12 @@ async def _heartbeat() -> None:
144144
await _check_alive()
145145

146146
try:
147-
loop = asyncio.get_event_loop()
148-
self.heartbeat_task = loop.create_task(_heartbeat())
147+
loop = asyncio.get_running_loop()
149148
except RuntimeError:
150-
# This can happen if there is no event loop running
149+
# No loop (tests, scripts) — nothing to schedule against.
151150
self.heartbeat_task = None
151+
return
152+
self.heartbeat_task = loop.create_task(_heartbeat())
152153

153154
def _stop(self) -> None:
154155
"""Stop the heartbeat monitoring."""

0 commit comments

Comments
 (0)