Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
204 changes: 131 additions & 73 deletions mcp_server_code_execution_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,11 @@ def _check_pydantic_compatibility() -> None:
DEFAULT_PIDS = int(os.environ.get("MCP_BRIDGE_PIDS", "128"))
DEFAULT_CPUS = os.environ.get("MCP_BRIDGE_CPUS")
CONTAINER_USER = os.environ.get("MCP_BRIDGE_CONTAINER_USER", "65534:65534")
# Docker-from-Docker path remapping: when running inside a container whose
# filesystem paths differ from the Docker daemon's view, set this to
# "local_prefix:host_prefix" (e.g. "/tmp:/mnt/disks/data/tmp") so volume
# mount source paths are translated for the daemon.
_DOCKER_HOST_PATH_MAP = os.environ.get("MCP_BRIDGE_DOCKER_HOST_PATH_MAP", "")
DEFAULT_RUNTIME_IDLE_TIMEOUT = int(
os.environ.get("MCP_BRIDGE_RUNTIME_IDLE_TIMEOUT", "300")
)
Expand Down Expand Up @@ -633,71 +638,113 @@ def __init__(self, server_info: MCPServerInfo) -> None:
self._session: Optional[ClientSession] = None
self._forward_task: Optional[asyncio.Task[None]] = None
self._captured_stderr: Optional[io.TextIOBase] = None
self._host_task: Optional[asyncio.Task[None]] = None
self._shutdown_event: Optional[asyncio.Event] = None
Comment on lines +641 to +642
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The PersistentMCPClient should be protected against concurrent start() calls. Since start() now spawns a background _host task and initializes several asynchronous resources, multiple concurrent calls could lead to race conditions where multiple tasks are created or state (like _shutdown_event) is overwritten. It is recommended to add an asyncio.Lock and wrap the body of start() with async with self._start_lock: to ensure only one initialization sequence occurs at a time.

Suggested change
self._host_task: Optional[asyncio.Task[None]] = None
self._shutdown_event: Optional[asyncio.Event] = None
self._host_task: Optional[asyncio.Task[None]] = None
self._shutdown_event: Optional[asyncio.Event] = None
self._start_lock = asyncio.Lock()


async def start(self) -> None:
if self._session:
return

params = StdioServerParameters(
command=self.server_info.command,
args=self.server_info.args,
env=self.server_info.env or None,
cwd=self.server_info.cwd or None,
)
# Run context-manager entries in a dedicated asyncio task to isolate
# anyio cancel scopes from the caller's scope stack. Without this,
# stdio_client.__aenter__() and ClientSession.__aenter__() push cancel
# scopes onto the calling task, which breaks when the MCP server's
# request-handler responder tries to exit its own cancel scope.
ready = asyncio.Event()
startup_error: list[BaseException] = []
self._shutdown_event = asyncio.Event()

async def _host() -> None:
"""Dedicated task that owns the cancel scopes for the MCP client."""
session: Optional[ClientSession] = None
client_cm = None
try:
params = StdioServerParameters(
command=self.server_info.command,
args=self.server_info.args,
env=self.server_info.env or None,
cwd=self.server_info.cwd or None,
)

# Capture stderr in a real file object for cross-platform compatibility
self._captured_stderr = tempfile.TemporaryFile(mode="w+t", encoding="utf-8")
# Only pass errlog if stdio_client supports it (tests may patch stdio_client)
if "errlog" in inspect.signature(stdio_client).parameters:
client_cm = stdio_client(params, errlog=self._captured_stderr)
else:
client_cm = stdio_client(params)
self._stdio_cm = client_cm
raw_read_stream, write_stream = await client_cm.__aenter__()
self._captured_stderr = tempfile.TemporaryFile(
mode="w+t", encoding="utf-8"
)
if "errlog" in inspect.signature(stdio_client).parameters:
client_cm = stdio_client(params, errlog=self._captured_stderr)
else:
client_cm = stdio_client(params)
self._stdio_cm = client_cm
raw_read_stream, write_stream = await client_cm.__aenter__()

# Create a filtered reader stream to hide benign XML/blank-line JSON parse errors
filtered_writer, filtered_read = anyio.create_memory_object_stream(0)
filtered_writer, filtered_read = anyio.create_memory_object_stream(0)

async def _forward_read() -> None:
try:
async with filtered_writer:
async for item in raw_read_stream:
# Filter out JSON parse errors that are likely caused by stray blank lines
if isinstance(item, Exception):
message = str(item)
if (
"Invalid JSON" in message
and "EOF while parsing a value" in message
and "input_value='\\n'" in message
):
# ignore blank line parse errors
continue
await filtered_writer.send(item)
except anyio.ClosedResourceError:
await anyio.lowlevel.checkpoint()

# Launch the forwarder task
self._forward_task = asyncio.create_task(_forward_read())

session = ClientSession(filtered_read, write_stream)
await session.__aenter__()
try:
await session.initialize()
except Exception as exc: # pragma: no cover - initialization failure reporting
# Read captured stderr content for diagnostics if present
stderr_text = ""
if self._captured_stderr is not None:
async def _forward_read() -> None:
try:
async with filtered_writer:
async for item in raw_read_stream:
if isinstance(item, Exception):
message = str(item)
if (
"Invalid JSON" in message
and "EOF while parsing a value" in message
and "input_value='\\n'" in message
):
continue
await filtered_writer.send(item)
except anyio.ClosedResourceError:
await anyio.lowlevel.checkpoint()

self._forward_task = asyncio.create_task(_forward_read())

session = ClientSession(filtered_read, write_stream)
await session.__aenter__()
try:
self._captured_stderr.seek(0)
stderr_text = self._captured_stderr.read()
except Exception:
stderr_text = "<failed to read captured stderr>"
logger.debug(
"Client session failed to initialize: %s (stderr=%s)", exc, stderr_text
)
# Re-raise for callers to handle; captured stderr is useful for debugging
raise
self._session = session
await session.initialize()
except Exception as exc:
stderr_text = ""
if self._captured_stderr is not None:
try:
self._captured_stderr.seek(0)
stderr_text = self._captured_stderr.read()
except Exception:
stderr_text = "<failed to read captured stderr>"
logger.debug(
"Client session failed to initialize: %s (stderr=%s)",
exc,
stderr_text,
)
raise

self._session = session
ready.set()

# Keep task alive — its cancel scope stack owns the client
# connection. stop() signals this event to begin teardown.
assert self._shutdown_event is not None
await self._shutdown_event.wait()

except BaseException as e:
if not ready.is_set():
startup_error.append(e)
ready.set()
finally:
# Clean up in the same task that entered the scopes
self._session = None
if session is not None:
try:
await session.__aexit__(None, None, None)
except BaseException:
pass
if client_cm is not None:
try:
await client_cm.__aexit__(None, None, None)
except BaseException:
pass

self._host_task = asyncio.create_task(_host())
await ready.wait()
if startup_error:
raise startup_error[0]

async def list_tools(self) -> List[Dict[str, object]]:
if not self._session:
Expand All @@ -718,27 +765,25 @@ async def call_tool(
return call_result.model_dump(by_alias=True, exclude_none=True)

async def stop(self) -> None:
if self._session:
try:
await self._session.__aexit__(None, None, None)
except* Exception as exc: # pragma: no cover - defensive cleanup
logger.debug("MCP session shutdown raised %s", exc, exc_info=True)
finally:
self._session = None
if self._stdio_cm:
try:
await self._stdio_cm.__aexit__(None, None, None) # type: ignore[union-attr]
except* Exception as exc: # pragma: no cover - defensive cleanup
logger.debug("MCP stdio shutdown raised %s", exc, exc_info=True)
finally:
self._stdio_cm = None
# Ensure the forwarder task is cancelled
# Cancel the forwarder task first
if self._forward_task:
task = self._forward_task
self._forward_task = None
task.cancel()
with suppress(asyncio.CancelledError):
await task
# Signal the host task to exit its context managers (session and
# stdio_client __aexit__ run in the same task that entered them,
# keeping the anyio cancel-scope stack consistent).
if self._shutdown_event:
self._shutdown_event.set()
if self._host_task:
task = self._host_task
self._host_task = None
with suppress(asyncio.CancelledError, Exception):
await task
self._session = None
self._stdio_cm = None
if self._captured_stderr is not None:
try:
self._captured_stderr.close()
Expand Down Expand Up @@ -1949,6 +1994,19 @@ def _filter_runtime_stderr(self, text: str) -> str:
return "\n".join(filtered_lines).strip("\n")


def _remap_volume_source(path: str) -> str:
    """Translate a local path to the Docker daemon's view when running DfD.

    When ``MCP_BRIDGE_DOCKER_HOST_PATH_MAP`` is set to
    ``"local_prefix:host_prefix"``, a path located under ``local_prefix`` is
    rewritten so the Docker daemon sees its host-side location.  Matching
    respects directory boundaries: a prefix of ``/tmp`` remaps ``/tmp/x``
    but not ``/tmpfiles/x``.  A malformed or unset mapping returns the path
    unchanged.

    Args:
        path: Absolute local filesystem path (a volume mount source).

    Returns:
        The translated path for the daemon, or ``path`` unchanged when no
        remapping applies.
    """
    if not _DOCKER_HOST_PATH_MAP:
        return path
    parts = _DOCKER_HOST_PATH_MAP.split(":", 1)
    if len(parts) != 2:
        # Malformed mapping (no ":" separator) — leave the path alone.
        return path
    local_prefix, host_prefix = parts
    # Normalize away a trailing slash so the boundary checks below are
    # consistent regardless of how the mapping was written.
    local_prefix = local_prefix.rstrip("/")
    if not local_prefix:
        # Mapping from "/": every absolute path is under it; join without
        # doubling slashes (the old concatenation produced "/mntfoo").
        return host_prefix.rstrip("/") + path
    if path == local_prefix:
        return host_prefix
    # Require a "/" immediately after the prefix so that sibling paths
    # sharing only a string prefix (e.g. "/tmpfiles" vs "/tmp") are not
    # incorrectly remapped.
    if path.startswith(local_prefix + "/"):
        return host_prefix + path[len(local_prefix):]
    return path


def detect_runtime(preferred: Optional[str] = None) -> Optional[str]:
"""Return the first available container runtime, or None if not found."""

Expand Down Expand Up @@ -2016,8 +2074,8 @@ async def __aenter__(self) -> "SandboxInvocation":
os.chmod(host_dir, 0o755)
self.host_dir = host_dir

self.volume_mounts.append(f"{host_dir}:/ipc:rw")
self.volume_mounts.append(f"{user_tools_dir}:/projects:rw")
self.volume_mounts.append(f"{_remap_volume_source(str(host_dir))}:/ipc:rw")
self.volume_mounts.append(f"{_remap_volume_source(str(user_tools_dir))}:/projects:rw")

self.container_env["MCP_AVAILABLE_SERVERS"] = json.dumps(
self.server_metadata,
Expand Down