Skip to content

Commit 13c3f53

Browse files
authored
chore(asyncio): migrate remaining background-task sites to asyncio_utils (#9596)
1 parent cdf6025 commit 13c3f53

5 files changed

Lines changed: 24 additions & 22 deletions

File tree

marimo/_server/ai/mcp/client.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
MCPTransportRegistry,
2121
)
2222
from marimo._server.ai.mcp.types import MCPToolArgs
23+
from marimo._utils.asyncio_utils import cancel_and_wait, supervised_task
2324

2425
if TYPE_CHECKING:
2526
from anyio.streams.memory import (
@@ -262,8 +263,9 @@ async def _connection_lifecycle(self, server_name: str) -> None:
262263
await self._discover_tools(connection)
263264

264265
if server_name not in self.health_check_tasks:
265-
self.health_check_tasks[server_name] = asyncio.create_task(
266-
self._monitor_server_health(server_name)
266+
self.health_check_tasks[server_name] = supervised_task(
267+
self._monitor_server_health(server_name),
268+
name=f"mcp.health.{server_name}",
267269
)
268270

269271
# Signal that connection is established
@@ -331,9 +333,12 @@ async def connect_to_server(self, server_name: str) -> bool:
331333
self._update_server_status(server_name, MCPServerStatus.CONNECTING)
332334
self._remove_server_tools(server_name)
333335

334-
# Create task to run existing connection logic
336+
# Create task to run existing connection logic. Not supervised:
337+
# this task is awaited in disconnect_from_server(), so supervisor
338+
# logging would duplicate the awaiter's error handling.
335339
connection_task = asyncio.create_task(
336-
self._connection_lifecycle(server_name)
340+
self._connection_lifecycle(server_name),
341+
name=f"mcp.lifecycle.{server_name}",
337342
)
338343
connection.connection_task = connection_task
339344

@@ -790,12 +795,7 @@ async def _cancel_health_monitoring(
790795
if server_name is not None:
791796
# Cancel single server monitoring
792797
if server_name in self.health_check_tasks:
793-
task = self.health_check_tasks[server_name]
794-
task.cancel()
795-
try:
796-
await task
797-
except asyncio.CancelledError:
798-
pass
798+
await cancel_and_wait(self.health_check_tasks[server_name])
799799
del self.health_check_tasks[server_name]
800800
LOGGER.debug(f"Cancelled health monitoring for {server_name}")
801801
else:

marimo/_server/api/endpoints/terminal.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from marimo._server.codes import WebSocketCodes
2020
from marimo._server.router import APIRouter
2121
from marimo._session.model import SessionMode
22+
from marimo._utils.asyncio_utils import cancel_and_wait
2223
from marimo._utils.platform import is_pyodide, is_windows
2324

2425
if TYPE_CHECKING:
@@ -331,12 +332,7 @@ async def _write_to_pty(
331332

332333
async def _cancel_tasks(tasks: Iterable[asyncio.Task[Any]]) -> None:
333334
for task in tasks:
334-
if not task.done():
335-
task.cancel()
336-
try:
337-
await task
338-
except asyncio.CancelledError:
339-
pass
335+
await cancel_and_wait(task)
340336

341337

342338
def supports_terminal() -> bool:

marimo/_server/api/middleware.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
from marimo._server.uvicorn_utils import close_uvicorn
4242
from marimo._session.model import SessionMode
4343
from marimo._tracer import server_tracer
44+
from marimo._utils.asyncio_utils import supervised_task
4445
from marimo._utils.print import print_tabbed
4546

4647
if TYPE_CHECKING:
@@ -700,8 +701,9 @@ def __init__(
700701
self.app_state.timeout_tracker = time.time()
701702
self.timeout_duration_minutes = timeout_duration_minutes
702703

703-
# Hold a strong reference so the monitor task isn't GC'd.
704-
self._monitor_task = asyncio.create_task(self.monitor())
704+
self._monitor_task = supervised_task(
705+
self.monitor(), name="timeout.monitor"
706+
)
705707

706708
async def __call__(
707709
self, scope: Scope, receive: Receive, send: Send

marimo/_server/lsp.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
LspServerStatus,
2626
)
2727
from marimo._tracer import server_tracer
28+
from marimo._utils.asyncio_utils import supervised_task
2829
from marimo._utils.net import find_free_port
2930
from marimo._utils.paths import marimo_package_path
3031
from marimo._utils.platform import is_windows
@@ -178,8 +179,9 @@ async def _start_internal(self) -> AlertNotification | None:
178179
self._health_check_task is None
179180
or self._health_check_task.done()
180181
):
181-
self._health_check_task = asyncio.create_task(
182-
self._monitor_process_health()
182+
self._health_check_task = supervised_task(
183+
self._monitor_process_health(),
184+
name=f"lsp.health.{self.id}",
183185
)
184186

185187
except Exception as e:

marimo/_server/rtc/doc.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from marimo import _loggers
88
from marimo._server.workspace import FileKey
99
from marimo._types.ids import CellId_t
10+
from marimo._utils.asyncio_utils import supervised_task
1011

1112
if TYPE_CHECKING:
1213
from loro import LoroDoc
@@ -143,8 +144,9 @@ async def remove_client(
143144

144145
# Create the cleaner task outside the lock to avoid deadlocks
145146
if should_create_cleaner:
146-
self.loro_docs_cleaners[file_key] = asyncio.create_task(
147-
self._clean_loro_doc(file_key, 60.0)
147+
self.loro_docs_cleaners[file_key] = supervised_task(
148+
self._clean_loro_doc(file_key, 60.0),
149+
name=f"rtc.cleaner.{file_key}",
148150
)
149151

150152
async def _do_remove_doc(self, file_key: FileKey) -> None:

0 commit comments

Comments
 (0)