|
40 | 40 |
|
41 | 41 | from __future__ import annotations |
42 | 42 |
|
| 43 | +import asyncio |
43 | 44 | import logging |
44 | 45 | import shutil |
45 | 46 | from collections.abc import Awaitable, Callable |
46 | 47 | from typing import Any |
47 | 48 |
|
48 | 49 | from amplifier_core import ChatResponse, ToolCall |
49 | 50 |
|
50 | | -from .client import AuthStatus, SessionInfo, SessionListResult |
| 51 | +from ._constants import DEFAULT_TIMEOUT |
| 52 | +from .client import AuthStatus, CopilotClientWrapper, SessionInfo, SessionListResult |
51 | 53 | from .exceptions import ( |
52 | 54 | CopilotAbortError, |
53 | 55 | CopilotAuthenticationError, |
|
121 | 123 |
|
122 | 124 | logger = logging.getLogger(__name__) |
123 | 125 |
|
| 126 | +# ═══════════════════════════════════════════════════════════════════════════════ |
| 127 | +# Process-Level Singleton State |
| 128 | +# ═══════════════════════════════════════════════════════════════════════════════ |
| 129 | +# |
| 130 | +# Sub-agents spawned by the task tool run as async coroutines in the SAME |
| 131 | +# Python process and asyncio event loop as the parent session (kernel-guaranteed). |
| 132 | +# Each sub-agent gets its own fresh ModuleCoordinator — coordinators are not shared. |
| 133 | +# |
| 134 | +# This singleton ensures all mounts in a process share ONE CopilotClientWrapper |
| 135 | +# (one copilot CLI subprocess) regardless of how many sub-agents are spawned. |
| 136 | +# Without this, N sub-agents spawn N processes × ~500 MB each. |
| 137 | +# |
| 138 | +# Reference: docs/plans/2026-02-23-process-singleton-design.md |
| 139 | + |
# Module-level singleton state for the shared Copilot client.
# Guards the two fields below; stays None until _get_lock() builds it.
_shared_client_lock: asyncio.Lock | None = None
# The single shared wrapper, or None while no mount currently holds one.
_shared_client: CopilotClientWrapper | None = None
# Count of acquisitions that have not yet been released.
_shared_client_refcount: int = 0
| 143 | + |
| 144 | + |
def _get_lock() -> asyncio.Lock:
    """Hand back the module-wide lock that guards the shared-client state.

    The lock is constructed on the first request instead of at import time,
    so importing this module never instantiates an asyncio primitive (the
    original author notes this matters in test environments where no event
    loop exists yet). Every later call returns that same instance.

    Returns:
        The single ``asyncio.Lock`` stored in ``_shared_client_lock``.
    """
    global _shared_client_lock
    lock = _shared_client_lock
    if lock is None:
        # First caller: build the lock and publish it for everyone else.
        lock = _shared_client_lock = asyncio.Lock()
    return lock
| 155 | + |
| 156 | + |
async def _acquire_shared_client(
    config: dict[str, Any],
    timeout: float,
) -> CopilotClientWrapper:
    """Take one reference on the shared CopilotClientWrapper.

    The first caller constructs the wrapper from ``config`` and ``timeout``;
    later callers receive that same instance. Each successful call bumps the
    reference count by one and must be paired with a call to
    ``_release_shared_client()``.

    Args:
        config: Provider configuration forwarded to ``CopilotClientWrapper``
            when the wrapper has to be created. Ignored on reuse.
        timeout: Timeout forwarded to ``CopilotClientWrapper`` on creation.
            On reuse, a value that differs from the existing wrapper's
            ``_timeout`` attribute is only reported at DEBUG level; the
            existing wrapper is returned unchanged.

    Returns:
        The shared wrapper instance.
    """
    global _shared_client, _shared_client_refcount
    async with _get_lock():
        client = _shared_client
        if client is not None:
            # Reuse path: the wrapper keeps whatever timeout it was built with.
            # If the attribute is missing, fall back to the caller's value so
            # the comparison is a no-op.
            pinned_timeout = getattr(client, "_timeout", timeout)
            if pinned_timeout != timeout:
                logger.debug(
                    f"[MOUNT] Ignoring timeout={timeout} for shared client "
                    f"(already created with timeout={pinned_timeout})"
                )
        else:
            logger.info("[MOUNT] Creating shared Copilot subprocess (first mount in process)")
            client = _shared_client = CopilotClientWrapper(config=config, timeout=timeout)

        _shared_client_refcount += 1
        logger.debug(f"[MOUNT] Shared client refcount: {_shared_client_refcount}")
        return client
| 186 | + |
| 187 | + |
async def _release_shared_client() -> None:
    """Drop one reference to the shared client, closing it on the last release.

    Decrements the reference count under the module lock. When the count
    reaches zero (or below, if release is called more often than acquire),
    the shared wrapper is detached from module state and then closed.

    The module state is reset *before* ``close()`` is awaited. If ``close()``
    raises or the awaiting task is cancelled, the globals are already back to
    "no client, count 0", so the next ``_acquire_shared_client()`` builds a
    fresh wrapper instead of handing out a half-closed one.

    Calling this when the count is already zero is harmless: the count is
    clamped back to zero and nothing is closed.

    NOTE: If the Python process is killed outright (SIGKILL/crash), this
    function never runs and ``close()`` is never called. The original author
    considers that acceptable and relies on the OS to reclaim the subprocess.

    Raises:
        Whatever ``CopilotClientWrapper.close()`` raises; it is propagated
        unchanged after the shared state has been reset.
    """
    global _shared_client, _shared_client_refcount
    async with _get_lock():
        _shared_client_refcount -= 1
        logger.debug(f"[MOUNT] Shared client refcount after release: {_shared_client_refcount}")
        if _shared_client_refcount > 0:
            # Other mounts still hold references; leave the client running.
            return

        # Detach first, close second: state stays consistent even if close()
        # fails or is cancelled part-way through.
        client, _shared_client = _shared_client, None
        _shared_client_refcount = 0  # safety floor: prevent negative counts

        if client is not None:
            logger.info(
                "[MOUNT] Last mount cleaned up — shutting down shared Copilot subprocess"
            )
            await client.close()
| 211 | + |
124 | 212 |
|
125 | 213 | async def mount( |
126 | 214 | coordinator: Any, # ModuleCoordinator |
@@ -178,26 +266,39 @@ async def mount( |
178 | 266 | # Set CLI path in config for provider to use |
179 | 267 | config["cli_path"] = cli_path |
180 | 268 |
|
| 269 | + # Track whether this call acquired a shared client reference, |
| 270 | + # so the error path only releases what this call actually acquired. |
| 271 | + acquired_client: CopilotClientWrapper | None = None |
| 272 | + |
181 | 273 | try: |
182 | | - # Create provider (api_key is None for Copilot - uses GitHub auth) |
183 | | - provider = CopilotSdkProvider(None, config, coordinator) |
| 274 | + timeout = float(config.get("timeout", DEFAULT_TIMEOUT)) |
| 275 | + |
| 276 | + # Acquire (or reuse) the process-level shared client. |
| 277 | + # All mounts in this Python process share one CopilotClientWrapper instance. |
| 278 | + acquired_client = await _acquire_shared_client(config, timeout) |
| 279 | + |
| 280 | + # Create provider, injecting the shared client |
| 281 | + provider = CopilotSdkProvider(None, config, coordinator, client=acquired_client) |
184 | 282 |
|
185 | 283 | # Register with coordinator |
186 | 284 | await coordinator.mount("providers", provider, name="github-copilot") |
187 | 285 |
|
188 | 286 | logger.info("[MOUNT] CopilotSdkProvider mounted successfully") |
189 | 287 |
|
190 | | - # Return cleanup function |
| 288 | + # Return cleanup function — releases the shared reference, not the provider |
191 | 289 | async def cleanup() -> None: |
192 | 290 | """Cleanup function called when unmounting.""" |
193 | 291 | logger.info("[MOUNT] Unmounting CopilotSdkProvider...") |
194 | | - await provider.close() |
| 292 | + await _release_shared_client() |
195 | 293 | logger.info("[MOUNT] CopilotSdkProvider unmounted") |
196 | 294 |
|
197 | 295 | return cleanup |
198 | 296 |
|
199 | 297 | except Exception as e: |
200 | 298 | logger.error(f"[MOUNT] Failed to mount CopilotSdkProvider: {e}") |
| 299 | + # Only release if this call successfully acquired a reference before the failure |
| 300 | + if acquired_client is not None: |
| 301 | + await _release_shared_client() |
201 | 302 | return None |
202 | 303 |
|
203 | 304 |
|
|
0 commit comments