|
20 | 20 | MCPTransportRegistry, |
21 | 21 | ) |
22 | 22 | from marimo._server.ai.mcp.types import MCPToolArgs |
| 23 | +from marimo._utils.asyncio_utils import cancel_and_wait, supervised_task |
23 | 24 |
|
24 | 25 | if TYPE_CHECKING: |
25 | 26 | from anyio.streams.memory import ( |
@@ -262,8 +263,9 @@ async def _connection_lifecycle(self, server_name: str) -> None: |
262 | 263 | await self._discover_tools(connection) |
263 | 264 |
|
264 | 265 | if server_name not in self.health_check_tasks: |
265 | | - self.health_check_tasks[server_name] = asyncio.create_task( |
266 | | - self._monitor_server_health(server_name) |
| 266 | + self.health_check_tasks[server_name] = supervised_task( |
| 267 | + self._monitor_server_health(server_name), |
| 268 | + name=f"mcp.health.{server_name}", |
267 | 269 | ) |
268 | 270 |
|
269 | 271 | # Signal that connection is established |
@@ -331,9 +333,12 @@ async def connect_to_server(self, server_name: str) -> bool: |
331 | 333 | self._update_server_status(server_name, MCPServerStatus.CONNECTING) |
332 | 334 | self._remove_server_tools(server_name) |
333 | 335 |
|
334 | | - # Create task to run existing connection logic |
| 336 | + # Create task to run existing connection logic. Not supervised: |
| 337 | + # this task is awaited in disconnect_from_server(), so supervisor |
| 338 | + # logging would duplicate the awaiter's error handling. |
335 | 339 | connection_task = asyncio.create_task( |
336 | | - self._connection_lifecycle(server_name) |
| 340 | + self._connection_lifecycle(server_name), |
| 341 | + name=f"mcp.lifecycle.{server_name}", |
337 | 342 | ) |
338 | 343 | connection.connection_task = connection_task |
339 | 344 |
|
@@ -790,12 +795,7 @@ async def _cancel_health_monitoring( |
790 | 795 | if server_name is not None: |
791 | 796 | # Cancel single server monitoring |
792 | 797 | if server_name in self.health_check_tasks: |
793 | | - task = self.health_check_tasks[server_name] |
794 | | - task.cancel() |
795 | | - try: |
796 | | - await task |
797 | | - except asyncio.CancelledError: |
798 | | - pass |
| 798 | + await cancel_and_wait(self.health_check_tasks[server_name]) |
799 | 799 | del self.health_check_tasks[server_name] |
800 | 800 | LOGGER.debug(f"Cancelled health monitoring for {server_name}") |
801 | 801 | else: |
|
0 commit comments