diff --git a/mcp_server_code_execution_mode.py b/mcp_server_code_execution_mode.py index 1b8eabf..49a99c2 100644 --- a/mcp_server_code_execution_mode.py +++ b/mcp_server_code_execution_mode.py @@ -136,6 +136,11 @@ def _check_pydantic_compatibility() -> None: DEFAULT_PIDS = int(os.environ.get("MCP_BRIDGE_PIDS", "128")) DEFAULT_CPUS = os.environ.get("MCP_BRIDGE_CPUS") CONTAINER_USER = os.environ.get("MCP_BRIDGE_CONTAINER_USER", "65534:65534") +# Docker-from-Docker path remapping: when running inside a container whose +# filesystem paths differ from the Docker daemon's view, set this to +# "local_prefix:host_prefix" (e.g. "/tmp:/mnt/disks/data/tmp") so volume +# mount source paths are translated for the daemon. +_DOCKER_HOST_PATH_MAP = os.environ.get("MCP_BRIDGE_DOCKER_HOST_PATH_MAP", "") DEFAULT_RUNTIME_IDLE_TIMEOUT = int( os.environ.get("MCP_BRIDGE_RUNTIME_IDLE_TIMEOUT", "300") ) @@ -633,71 +638,113 @@ def __init__(self, server_info: MCPServerInfo) -> None: self._session: Optional[ClientSession] = None self._forward_task: Optional[asyncio.Task[None]] = None self._captured_stderr: Optional[io.TextIOBase] = None + self._host_task: Optional[asyncio.Task[None]] = None + self._shutdown_event: Optional[asyncio.Event] = None async def start(self) -> None: if self._session: return - params = StdioServerParameters( - command=self.server_info.command, - args=self.server_info.args, - env=self.server_info.env or None, - cwd=self.server_info.cwd or None, - ) + # Run context-manager entries in a dedicated asyncio task to isolate + # anyio cancel scopes from the caller's scope stack. Without this, + # stdio_client.__aenter__() and ClientSession.__aenter__() push cancel + # scopes onto the calling task, which breaks when the MCP server's + # request-handler responder tries to exit its own cancel scope. + ready = asyncio.Event() + startup_error: list[BaseException] = [] + self._shutdown_event = asyncio.Event() + + async def _host() -> None: + """Dedicated task that owns the cancel scopes for the MCP client.""" + session: Optional[ClientSession] = None + client_cm = None + try: + params = StdioServerParameters( + command=self.server_info.command, + args=self.server_info.args, + env=self.server_info.env or None, + cwd=self.server_info.cwd or None, + ) - # Capture stderr in a real file object for cross-platform compatibility - self._captured_stderr = tempfile.TemporaryFile(mode="w+t", encoding="utf-8") - # Only pass errlog if stdio_client supports it (tests may patch stdio_client) - if "errlog" in inspect.signature(stdio_client).parameters: - client_cm = stdio_client(params, errlog=self._captured_stderr) - else: - client_cm = stdio_client(params) - self._stdio_cm = client_cm - raw_read_stream, write_stream = await client_cm.__aenter__() + self._captured_stderr = tempfile.TemporaryFile( + mode="w+t", encoding="utf-8" + ) + if "errlog" in inspect.signature(stdio_client).parameters: + client_cm = stdio_client(params, errlog=self._captured_stderr) + else: + client_cm = stdio_client(params) + self._stdio_cm = client_cm + raw_read_stream, write_stream = await client_cm.__aenter__() - # Create a filtered reader stream to hide benign XML/blank-line JSON parse errors - filtered_writer, filtered_read = anyio.create_memory_object_stream(0) + filtered_writer, filtered_read = anyio.create_memory_object_stream(0) - async def _forward_read() -> None: - try: - async with filtered_writer: - async for item in raw_read_stream: - # Filter out JSON parse errors that are likely caused by stray blank lines - if isinstance(item, Exception): - message = str(item) - if ( - "Invalid JSON" in message - and "EOF while parsing a value" in message - and "input_value='\\n'" in message - ): - # ignore blank line parse errors - continue - await filtered_writer.send(item) - except anyio.ClosedResourceError: - await anyio.lowlevel.checkpoint() - - # Launch the forwarder task - self._forward_task = asyncio.create_task(_forward_read()) - - session = ClientSession(filtered_read, write_stream) - await session.__aenter__() - try: - await session.initialize() - except Exception as exc: # pragma: no cover - initialization failure reporting - # Read captured stderr content for diagnostics if present - stderr_text = "" - if self._captured_stderr is not None: + async def _forward_read() -> None: + try: + async with filtered_writer: + async for item in raw_read_stream: + if isinstance(item, Exception): + message = str(item) + if ( + "Invalid JSON" in message + and "EOF while parsing a value" in message + and "input_value='\\n'" in message + ): + continue + await filtered_writer.send(item) + except anyio.ClosedResourceError: + await anyio.lowlevel.checkpoint() + + self._forward_task = asyncio.create_task(_forward_read()) + + session = ClientSession(filtered_read, write_stream) + await session.__aenter__() try: - self._captured_stderr.seek(0) - stderr_text = self._captured_stderr.read() - except Exception: - stderr_text = "" - logger.debug( - "Client session failed to initialize: %s (stderr=%s)", exc, stderr_text - ) - # Re-raise for callers to handle; captured stderr is useful for debugging - raise - self._session = session + await session.initialize() + except Exception as exc: + stderr_text = "" + if self._captured_stderr is not None: + try: + self._captured_stderr.seek(0) + stderr_text = self._captured_stderr.read() + except Exception: + stderr_text = "" + logger.debug( + "Client session failed to initialize: %s (stderr=%s)", + exc, + stderr_text, + ) + raise + + self._session = session + ready.set() + + # Keep task alive — its cancel scope stack owns the client + # connection. stop() signals this event to begin teardown. + assert self._shutdown_event is not None + await self._shutdown_event.wait() + + except BaseException as e: + if not ready.is_set(): + startup_error.append(e) + ready.set() + finally: + # Clean up in the same task that entered the scopes + self._session = None + if session is not None: + try: + await session.__aexit__(None, None, None) + except BaseException: + pass + if client_cm is not None: + try: + await client_cm.__aexit__(None, None, None) + except BaseException: + pass + + self._host_task = asyncio.create_task(_host()) + await ready.wait() + if startup_error: + raise startup_error[0] async def list_tools(self) -> List[Dict[str, object]]: if not self._session: @@ -718,27 +765,25 @@ async def call_tool( return call_result.model_dump(by_alias=True, exclude_none=True) async def stop(self) -> None: - if self._session: - try: - await self._session.__aexit__(None, None, None) - except* Exception as exc: # pragma: no cover - defensive cleanup - logger.debug("MCP session shutdown raised %s", exc, exc_info=True) - finally: - self._session = None - if self._stdio_cm: - try: - await self._stdio_cm.__aexit__(None, None, None) # type: ignore[union-attr] - except* Exception as exc: # pragma: no cover - defensive cleanup - logger.debug("MCP stdio shutdown raised %s", exc, exc_info=True) - finally: - self._stdio_cm = None - # Ensure the forwarder task is cancelled + # Cancel the forwarder task first if self._forward_task: task = self._forward_task self._forward_task = None task.cancel() with suppress(asyncio.CancelledError): await task + # Signal the host task to exit its context managers (session and + # stdio_client __aexit__ run in the same task that entered them, + # keeping the anyio cancel-scope stack consistent). + if self._shutdown_event: + self._shutdown_event.set() + if self._host_task: + task = self._host_task + self._host_task = None + with suppress(asyncio.CancelledError, Exception): + await task + self._session = None + self._stdio_cm = None if self._captured_stderr is not None: try: self._captured_stderr.close() @@ -1949,6 +1994,19 @@ def _filter_runtime_stderr(self, text: str) -> str: return "\n".join(filtered_lines).strip("\n") +def _remap_volume_source(path: str) -> str: + """Translate a local path to the Docker daemon's view when running DfD.""" + if not _DOCKER_HOST_PATH_MAP: + return path + parts = _DOCKER_HOST_PATH_MAP.split(":", 1) + if len(parts) != 2: + return path + local_prefix, host_prefix = parts + if path.startswith(local_prefix): + return host_prefix + path[len(local_prefix):] + return path + + def detect_runtime(preferred: Optional[str] = None) -> Optional[str]: """Return the first available container runtime, or None if not found.""" @@ -2016,8 +2074,8 @@ async def __aenter__(self) -> "SandboxInvocation": os.chmod(host_dir, 0o755) self.host_dir = host_dir - self.volume_mounts.append(f"{host_dir}:/ipc:rw") - self.volume_mounts.append(f"{user_tools_dir}:/projects:rw") + self.volume_mounts.append(f"{_remap_volume_source(str(host_dir))}:/ipc:rw") + self.volume_mounts.append(f"{_remap_volume_source(str(user_tools_dir))}:/projects:rw") self.container_env["MCP_AVAILABLE_SERVERS"] = json.dumps( self.server_metadata,