-
Notifications
You must be signed in to change notification settings - Fork 28
feat: add Docker-from-Docker path remapping and fix cancel-scope isolation #10
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -136,6 +136,11 @@ def _check_pydantic_compatibility() -> None: | |||||||||||||||
| DEFAULT_PIDS = int(os.environ.get("MCP_BRIDGE_PIDS", "128")) | ||||||||||||||||
| DEFAULT_CPUS = os.environ.get("MCP_BRIDGE_CPUS") | ||||||||||||||||
| CONTAINER_USER = os.environ.get("MCP_BRIDGE_CONTAINER_USER", "65534:65534") | ||||||||||||||||
| # Docker-from-Docker path remapping: when running inside a container whose | ||||||||||||||||
| # filesystem paths differ from the Docker daemon's view, set this to | ||||||||||||||||
| # "local_prefix:host_prefix" (e.g. "/tmp:/mnt/disks/data/tmp") so volume | ||||||||||||||||
| # mount source paths are translated for the daemon. | ||||||||||||||||
| _DOCKER_HOST_PATH_MAP = os.environ.get("MCP_BRIDGE_DOCKER_HOST_PATH_MAP", "") | ||||||||||||||||
| DEFAULT_RUNTIME_IDLE_TIMEOUT = int( | ||||||||||||||||
| os.environ.get("MCP_BRIDGE_RUNTIME_IDLE_TIMEOUT", "300") | ||||||||||||||||
| ) | ||||||||||||||||
|
|
@@ -633,71 +638,113 @@ def __init__(self, server_info: MCPServerInfo) -> None: | |||||||||||||||
| self._session: Optional[ClientSession] = None | ||||||||||||||||
| self._forward_task: Optional[asyncio.Task[None]] = None | ||||||||||||||||
| self._captured_stderr: Optional[io.TextIOBase] = None | ||||||||||||||||
| self._host_task: Optional[asyncio.Task[None]] = None | ||||||||||||||||
| self._shutdown_event: Optional[asyncio.Event] = None | ||||||||||||||||
|
|
||||||||||||||||
| async def start(self) -> None: | ||||||||||||||||
| if self._session: | ||||||||||||||||
| return | ||||||||||||||||
|
|
||||||||||||||||
| params = StdioServerParameters( | ||||||||||||||||
| command=self.server_info.command, | ||||||||||||||||
| args=self.server_info.args, | ||||||||||||||||
| env=self.server_info.env or None, | ||||||||||||||||
| cwd=self.server_info.cwd or None, | ||||||||||||||||
| ) | ||||||||||||||||
| # Run context-manager entries in a dedicated asyncio task to isolate | ||||||||||||||||
| # anyio cancel scopes from the caller's scope stack. Without this, | ||||||||||||||||
| # stdio_client.__aenter__() and ClientSession.__aenter__() push cancel | ||||||||||||||||
| # scopes onto the calling task, which breaks when the MCP server's | ||||||||||||||||
| # request-handler responder tries to exit its own cancel scope. | ||||||||||||||||
| ready = asyncio.Event() | ||||||||||||||||
| startup_error: list[BaseException] = [] | ||||||||||||||||
| self._shutdown_event = asyncio.Event() | ||||||||||||||||
|
|
||||||||||||||||
| async def _host() -> None: | ||||||||||||||||
| """Dedicated task that owns the cancel scopes for the MCP client.""" | ||||||||||||||||
| session: Optional[ClientSession] = None | ||||||||||||||||
| client_cm = None | ||||||||||||||||
| try: | ||||||||||||||||
| params = StdioServerParameters( | ||||||||||||||||
| command=self.server_info.command, | ||||||||||||||||
| args=self.server_info.args, | ||||||||||||||||
| env=self.server_info.env or None, | ||||||||||||||||
| cwd=self.server_info.cwd or None, | ||||||||||||||||
| ) | ||||||||||||||||
|
|
||||||||||||||||
| # Capture stderr in a real file object for cross-platform compatibility | ||||||||||||||||
| self._captured_stderr = tempfile.TemporaryFile(mode="w+t", encoding="utf-8") | ||||||||||||||||
| # Only pass errlog if stdio_client supports it (tests may patch stdio_client) | ||||||||||||||||
| if "errlog" in inspect.signature(stdio_client).parameters: | ||||||||||||||||
| client_cm = stdio_client(params, errlog=self._captured_stderr) | ||||||||||||||||
| else: | ||||||||||||||||
| client_cm = stdio_client(params) | ||||||||||||||||
| self._stdio_cm = client_cm | ||||||||||||||||
| raw_read_stream, write_stream = await client_cm.__aenter__() | ||||||||||||||||
| self._captured_stderr = tempfile.TemporaryFile( | ||||||||||||||||
| mode="w+t", encoding="utf-8" | ||||||||||||||||
| ) | ||||||||||||||||
| if "errlog" in inspect.signature(stdio_client).parameters: | ||||||||||||||||
| client_cm = stdio_client(params, errlog=self._captured_stderr) | ||||||||||||||||
| else: | ||||||||||||||||
| client_cm = stdio_client(params) | ||||||||||||||||
| self._stdio_cm = client_cm | ||||||||||||||||
| raw_read_stream, write_stream = await client_cm.__aenter__() | ||||||||||||||||
|
|
||||||||||||||||
| # Create a filtered reader stream to hide benign XML/blank-line JSON parse errors | ||||||||||||||||
| filtered_writer, filtered_read = anyio.create_memory_object_stream(0) | ||||||||||||||||
| filtered_writer, filtered_read = anyio.create_memory_object_stream(0) | ||||||||||||||||
|
|
||||||||||||||||
| async def _forward_read() -> None: | ||||||||||||||||
| try: | ||||||||||||||||
| async with filtered_writer: | ||||||||||||||||
| async for item in raw_read_stream: | ||||||||||||||||
| # Filter out JSON parse errors that are likely caused by stray blank lines | ||||||||||||||||
| if isinstance(item, Exception): | ||||||||||||||||
| message = str(item) | ||||||||||||||||
| if ( | ||||||||||||||||
| "Invalid JSON" in message | ||||||||||||||||
| and "EOF while parsing a value" in message | ||||||||||||||||
| and "input_value='\\n'" in message | ||||||||||||||||
| ): | ||||||||||||||||
| # ignore blank line parse errors | ||||||||||||||||
| continue | ||||||||||||||||
| await filtered_writer.send(item) | ||||||||||||||||
| except anyio.ClosedResourceError: | ||||||||||||||||
| await anyio.lowlevel.checkpoint() | ||||||||||||||||
|
|
||||||||||||||||
| # Launch the forwarder task | ||||||||||||||||
| self._forward_task = asyncio.create_task(_forward_read()) | ||||||||||||||||
|
|
||||||||||||||||
| session = ClientSession(filtered_read, write_stream) | ||||||||||||||||
| await session.__aenter__() | ||||||||||||||||
| try: | ||||||||||||||||
| await session.initialize() | ||||||||||||||||
| except Exception as exc: # pragma: no cover - initialization failure reporting | ||||||||||||||||
| # Read captured stderr content for diagnostics if present | ||||||||||||||||
| stderr_text = "" | ||||||||||||||||
| if self._captured_stderr is not None: | ||||||||||||||||
| async def _forward_read() -> None: | ||||||||||||||||
| try: | ||||||||||||||||
| async with filtered_writer: | ||||||||||||||||
| async for item in raw_read_stream: | ||||||||||||||||
| if isinstance(item, Exception): | ||||||||||||||||
| message = str(item) | ||||||||||||||||
| if ( | ||||||||||||||||
| "Invalid JSON" in message | ||||||||||||||||
| and "EOF while parsing a value" in message | ||||||||||||||||
| and "input_value='\\n'" in message | ||||||||||||||||
| ): | ||||||||||||||||
| continue | ||||||||||||||||
| await filtered_writer.send(item) | ||||||||||||||||
| except anyio.ClosedResourceError: | ||||||||||||||||
| await anyio.lowlevel.checkpoint() | ||||||||||||||||
|
|
||||||||||||||||
| self._forward_task = asyncio.create_task(_forward_read()) | ||||||||||||||||
|
|
||||||||||||||||
| session = ClientSession(filtered_read, write_stream) | ||||||||||||||||
| await session.__aenter__() | ||||||||||||||||
| try: | ||||||||||||||||
| self._captured_stderr.seek(0) | ||||||||||||||||
| stderr_text = self._captured_stderr.read() | ||||||||||||||||
| except Exception: | ||||||||||||||||
| stderr_text = "<failed to read captured stderr>" | ||||||||||||||||
| logger.debug( | ||||||||||||||||
| "Client session failed to initialize: %s (stderr=%s)", exc, stderr_text | ||||||||||||||||
| ) | ||||||||||||||||
| # Re-raise for callers to handle; captured stderr is useful for debugging | ||||||||||||||||
| raise | ||||||||||||||||
| self._session = session | ||||||||||||||||
| await session.initialize() | ||||||||||||||||
| except Exception as exc: | ||||||||||||||||
| stderr_text = "" | ||||||||||||||||
| if self._captured_stderr is not None: | ||||||||||||||||
| try: | ||||||||||||||||
| self._captured_stderr.seek(0) | ||||||||||||||||
| stderr_text = self._captured_stderr.read() | ||||||||||||||||
| except Exception: | ||||||||||||||||
| stderr_text = "<failed to read captured stderr>" | ||||||||||||||||
| logger.debug( | ||||||||||||||||
| "Client session failed to initialize: %s (stderr=%s)", | ||||||||||||||||
| exc, | ||||||||||||||||
| stderr_text, | ||||||||||||||||
| ) | ||||||||||||||||
| raise | ||||||||||||||||
|
|
||||||||||||||||
| self._session = session | ||||||||||||||||
| ready.set() | ||||||||||||||||
|
|
||||||||||||||||
| # Keep task alive — its cancel scope stack owns the client | ||||||||||||||||
| # connection. stop() signals this event to begin teardown. | ||||||||||||||||
| assert self._shutdown_event is not None | ||||||||||||||||
| await self._shutdown_event.wait() | ||||||||||||||||
|
|
||||||||||||||||
| except BaseException as e: | ||||||||||||||||
| if not ready.is_set(): | ||||||||||||||||
| startup_error.append(e) | ||||||||||||||||
| ready.set() | ||||||||||||||||
| finally: | ||||||||||||||||
| # Clean up in the same task that entered the scopes | ||||||||||||||||
| self._session = None | ||||||||||||||||
| if session is not None: | ||||||||||||||||
| try: | ||||||||||||||||
| await session.__aexit__(None, None, None) | ||||||||||||||||
| except BaseException: | ||||||||||||||||
| pass | ||||||||||||||||
| if client_cm is not None: | ||||||||||||||||
| try: | ||||||||||||||||
| await client_cm.__aexit__(None, None, None) | ||||||||||||||||
| except BaseException: | ||||||||||||||||
| pass | ||||||||||||||||
|
|
||||||||||||||||
| self._host_task = asyncio.create_task(_host()) | ||||||||||||||||
| await ready.wait() | ||||||||||||||||
| if startup_error: | ||||||||||||||||
| raise startup_error[0] | ||||||||||||||||
|
|
||||||||||||||||
| async def list_tools(self) -> List[Dict[str, object]]: | ||||||||||||||||
| if not self._session: | ||||||||||||||||
|
|
@@ -718,27 +765,25 @@ async def call_tool( | |||||||||||||||
| return call_result.model_dump(by_alias=True, exclude_none=True) | ||||||||||||||||
|
|
||||||||||||||||
| async def stop(self) -> None: | ||||||||||||||||
| if self._session: | ||||||||||||||||
| try: | ||||||||||||||||
| await self._session.__aexit__(None, None, None) | ||||||||||||||||
| except* Exception as exc: # pragma: no cover - defensive cleanup | ||||||||||||||||
| logger.debug("MCP session shutdown raised %s", exc, exc_info=True) | ||||||||||||||||
| finally: | ||||||||||||||||
| self._session = None | ||||||||||||||||
| if self._stdio_cm: | ||||||||||||||||
| try: | ||||||||||||||||
| await self._stdio_cm.__aexit__(None, None, None) # type: ignore[union-attr] | ||||||||||||||||
| except* Exception as exc: # pragma: no cover - defensive cleanup | ||||||||||||||||
| logger.debug("MCP stdio shutdown raised %s", exc, exc_info=True) | ||||||||||||||||
| finally: | ||||||||||||||||
| self._stdio_cm = None | ||||||||||||||||
| # Ensure the forwarder task is cancelled | ||||||||||||||||
| # Cancel the forwarder task first | ||||||||||||||||
| if self._forward_task: | ||||||||||||||||
| task = self._forward_task | ||||||||||||||||
| self._forward_task = None | ||||||||||||||||
| task.cancel() | ||||||||||||||||
| with suppress(asyncio.CancelledError): | ||||||||||||||||
| await task | ||||||||||||||||
| # Signal the host task to exit its context managers (session and | ||||||||||||||||
| # stdio_client __aexit__ run in the same task that entered them, | ||||||||||||||||
| # keeping the anyio cancel-scope stack consistent). | ||||||||||||||||
| if self._shutdown_event: | ||||||||||||||||
| self._shutdown_event.set() | ||||||||||||||||
| if self._host_task: | ||||||||||||||||
| task = self._host_task | ||||||||||||||||
| self._host_task = None | ||||||||||||||||
| with suppress(asyncio.CancelledError, Exception): | ||||||||||||||||
| await task | ||||||||||||||||
| self._session = None | ||||||||||||||||
| self._stdio_cm = None | ||||||||||||||||
| if self._captured_stderr is not None: | ||||||||||||||||
| try: | ||||||||||||||||
| self._captured_stderr.close() | ||||||||||||||||
|
|
@@ -1949,6 +1994,19 @@ def _filter_runtime_stderr(self, text: str) -> str: | |||||||||||||||
| return "\n".join(filtered_lines).strip("\n") | ||||||||||||||||
|
|
||||||||||||||||
|
|
||||||||||||||||
| def _remap_volume_source(path: str) -> str: | ||||||||||||||||
| """Translate a local path to the Docker daemon's view when running DfD.""" | ||||||||||||||||
| if not _DOCKER_HOST_PATH_MAP: | ||||||||||||||||
| return path | ||||||||||||||||
| parts = _DOCKER_HOST_PATH_MAP.split(":", 1) | ||||||||||||||||
| if len(parts) != 2: | ||||||||||||||||
| return path | ||||||||||||||||
| local_prefix, host_prefix = parts | ||||||||||||||||
| if path.startswith(local_prefix): | ||||||||||||||||
| return host_prefix + path[len(local_prefix):] | ||||||||||||||||
|
Comment on lines
+2005
to
+2006
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The current path remapping logic uses a simple string prefix check, which can lead to incorrect translations if a path shares a prefix but is not a sub-directory (e.g., remapping
Suggested change
|
||||||||||||||||
| return path | ||||||||||||||||
|
|
||||||||||||||||
|
|
||||||||||||||||
| def detect_runtime(preferred: Optional[str] = None) -> Optional[str]: | ||||||||||||||||
| """Return the first available container runtime, or None if not found.""" | ||||||||||||||||
|
|
||||||||||||||||
|
|
@@ -2016,8 +2074,8 @@ async def __aenter__(self) -> "SandboxInvocation": | |||||||||||||||
| os.chmod(host_dir, 0o755) | ||||||||||||||||
| self.host_dir = host_dir | ||||||||||||||||
|
|
||||||||||||||||
| self.volume_mounts.append(f"{host_dir}:/ipc:rw") | ||||||||||||||||
| self.volume_mounts.append(f"{user_tools_dir}:/projects:rw") | ||||||||||||||||
| self.volume_mounts.append(f"{_remap_volume_source(str(host_dir))}:/ipc:rw") | ||||||||||||||||
| self.volume_mounts.append(f"{_remap_volume_source(str(user_tools_dir))}:/projects:rw") | ||||||||||||||||
|
|
||||||||||||||||
| self.container_env["MCP_AVAILABLE_SERVERS"] = json.dumps( | ||||||||||||||||
| self.server_metadata, | ||||||||||||||||
|
|
||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The
PersistentMCPClientshould be protected against concurrentstart()calls. Sincestart()now spawns a background_hosttask and initializes several asynchronous resources, multiple concurrent calls could lead to race conditions where multiple tasks are created or state (like_shutdown_event) is overwritten. It is recommended to add anasyncio.Lockand wrap the body ofstart()withasync with self._start_lock:to ensure only one initialization sequence occurs at a time.