diff --git a/dotnet/src/Client.cs b/dotnet/src/Client.cs index 8e4d1eefb..24797770a 100644 --- a/dotnet/src/Client.cs +++ b/dotnet/src/Client.cs @@ -63,6 +63,7 @@ public sealed partial class CopilotClient : IDisposable, IAsyncDisposable /// Minimum protocol version this SDK can communicate with. /// private const int MinProtocolVersion = 2; + private static readonly TimeSpan StderrPumpShutdownTimeout = TimeSpan.FromSeconds(5); private readonly ConcurrentDictionary _sessions = new(); private readonly CopilotClientOptions _options; @@ -207,30 +208,60 @@ async Task StartCoreAsync(CancellationToken ct) _logger.LogDebug("Starting Copilot client"); _disconnected = false; - Task result; + Connection? connection = null; + Process? cliProcess = null; + ProcessStderrPump? stderrPump = null; - if (_optionsHost is not null && _optionsPort is not null) + try { - // External server (TCP) - _actualPort = _optionsPort; - result = ConnectToServerAsync(null, _optionsHost, _optionsPort, null, ct); + if (_optionsHost is not null && _optionsPort is not null) + { + // External server (TCP) + _actualPort = _optionsPort; + connection = await ConnectToServerAsync(null, _optionsHost, _optionsPort, null, ct); + } + else + { + // Child process (stdio or TCP) + var portOrNull = (int?)null; + (cliProcess, portOrNull, stderrPump) = await StartCliServerAsync(_options, _logger, ct); + _actualPort = portOrNull; + connection = await ConnectToServerAsync(cliProcess, portOrNull is null ? null : "localhost", portOrNull, stderrPump, ct); + } + + // Verify protocol version compatibility + await VerifyProtocolVersionAsync(connection, ct); + await ConfigureSessionFsAsync(ct); + + _logger.LogInformation("Copilot client connected"); + return connection; } - else + catch { - // Child process (stdio or TCP) - var (cliProcess, portOrNull, stderrBuffer) = await StartCliServerAsync(_options, _logger, ct); - _actualPort = portOrNull; - result = ConnectToServerAsync(cliProcess, portOrNull is null ? null : "localhost", portOrNull, stderrBuffer, ct); - } - - var connection = await result; + var cleanupErrors = new List(); + try + { + if (connection is not null) + { + await CleanupConnectionAsync(connection, cleanupErrors); + } + else if (cliProcess is not null) + { + await CleanupCliProcessAsync(cliProcess, stderrPump, _logger, cleanupErrors); + } - // Verify protocol version compatibility - await VerifyProtocolVersionAsync(connection, ct); - await ConfigureSessionFsAsync(ct); + foreach (var cleanupError in cleanupErrors) + { + _logger.LogDebug(cleanupError, "Failed to clean up Copilot client connection after startup failure"); + } + } + finally + { + _connectionTask = null; + } - _logger.LogInformation("Copilot client connected"); - return connection; + throw; + } } } @@ -334,9 +365,21 @@ private async Task CleanupConnectionAsync(List? errors) return; } - var ctx = await _connectionTask; - _connectionTask = null; + Connection ctx; + try + { + ctx = await _connectionTask; + } + finally + { + _connectionTask = null; + } + + await CleanupConnectionAsync(ctx, errors); + } + private async Task CleanupConnectionAsync(Connection ctx, List? errors) + { try { ctx.Rpc.Dispose(); } catch (Exception ex) { errors?.Add(ex); } @@ -358,13 +401,34 @@ private async Task CleanupConnectionAsync(List? errors) if (ctx.CliProcess is { } childProcess) { - try + await CleanupCliProcessAsync(childProcess, ctx.StderrPump, _logger, errors); + } + } + + private static async Task CleanupCliProcessAsync(Process childProcess, ProcessStderrPump? stderrPump, ILogger logger, List? errors) + { + stderrPump?.Cancel(); + + try + { + if (!childProcess.HasExited) childProcess.Kill(); + } + catch (Exception ex) { errors?.Add(ex); } + + if (stderrPump is not null) + { + try { await stderrPump.WaitForCompletionAsync(StderrPumpShutdownTimeout); } + catch (TimeoutException ex) { - if (!childProcess.HasExited) childProcess.Kill(); - childProcess.Dispose(); + logger.LogDebug(ex, "Timed out waiting for CLI stderr pump to stop"); + errors?.Add(ex); } catch (Exception ex) { errors?.Add(ex); } + finally { stderrPump.Dispose(); } } + + try { childProcess.Dispose(); } + catch (Exception ex) { errors?.Add(ex); } } private static (SystemMessageConfig? wireConfig, Dictionary>>? callbacks) ExtractTransformCallbacks(SystemMessageConfig? systemMessage) @@ -1152,7 +1216,7 @@ private async Task VerifyProtocolVersionAsync(Connection connection, Cancellatio _negotiatedProtocolVersion = serverVersion; } - private static async Task<(Process Process, int? DetectedLocalhostTcpPort, StringBuilder StderrBuffer)> StartCliServerAsync(CopilotClientOptions options, ILogger logger, CancellationToken cancellationToken) + private static async Task<(Process Process, int? DetectedLocalhostTcpPort, ProcessStderrPump StderrPump)> StartCliServerAsync(CopilotClientOptions options, ILogger logger, CancellationToken cancellationToken) { // Use explicit path, COPILOT_CLI_PATH env var (from options.Environment or process env), or bundled CLI - no PATH fallback var envCliPath = options.Environment is not null && options.Environment.TryGetValue("COPILOT_CLI_PATH", out var envValue) ? envValue @@ -1240,49 +1304,61 @@ private async Task VerifyProtocolVersionAsync(Connection connection, Cancellatio } var cliProcess = new Process { StartInfo = startInfo }; - cliProcess.Start(); + try + { + cliProcess.Start(); + } + catch + { + cliProcess.Dispose(); + throw; + } + + // Capture stderr for error messages and forward to logger. + // The pump has its own lifetime token and is later cancelled/observed + // by the owning Connection before the process is disposed. + var stderrPump = ProcessStderrPump.Start(cliProcess, logger); - // Capture stderr for error messages and forward to logger - var stderrBuffer = new StringBuilder(); - _ = Task.Run(async () => + var detectedLocalhostTcpPort = (int?)null; + try { - while (cliProcess != null && !cliProcess.HasExited) + if (!options.UseStdio) { - var line = await cliProcess.StandardError.ReadLineAsync(cancellationToken); - if (line != null) - { - lock (stderrBuffer) - { - stderrBuffer.AppendLine(line); - } + // Wait for port announcement + using var cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); + cts.CancelAfter(TimeSpan.FromSeconds(30)); - if (logger.IsEnabled(LogLevel.Debug)) + try + { + while (true) { - logger.LogDebug("[CLI] {Line}", line); + var line = await cliProcess.StandardOutput.ReadLineAsync(cts.Token) ?? throw new IOException("CLI process exited unexpectedly"); + if (ListeningOnPortRegex().Match(line) is { Success: true } match) + { + detectedLocalhostTcpPort = int.Parse(match.Groups[1].Value, CultureInfo.InvariantCulture); + break; + } } } + catch (OperationCanceledException) when (!cancellationToken.IsCancellationRequested && cts.IsCancellationRequested) + { + throw new IOException("Timed out waiting for Copilot CLI to report its TCP listening port."); + } } - }, cancellationToken); - var detectedLocalhostTcpPort = (int?)null; - if (!options.UseStdio) + return (cliProcess, detectedLocalhostTcpPort, stderrPump); + } + catch { - // Wait for port announcement - using var cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); - cts.CancelAfter(TimeSpan.FromSeconds(30)); - - while (!cts.Token.IsCancellationRequested) + var cleanupErrors = new List(); + await CleanupCliProcessAsync(cliProcess, stderrPump, logger, cleanupErrors); + foreach (var cleanupError in cleanupErrors) { - var line = await cliProcess.StandardOutput.ReadLineAsync(cts.Token) ?? throw new IOException("CLI process exited unexpectedly"); - if (ListeningOnPortRegex().Match(line) is { Success: true } match) - { - detectedLocalhostTcpPort = int.Parse(match.Groups[1].Value, CultureInfo.InvariantCulture); - break; - } + logger.LogDebug(cleanupError, "Failed to clean up Copilot CLI process after startup failure"); } - } - return (cliProcess, detectedLocalhostTcpPort, stderrBuffer); + throw; + } } private static string? GetBundledCliPath(out string searchedPath) @@ -1326,65 +1402,89 @@ private static (string FileName, IEnumerable Args) ResolveCliCommand(str return (cliPath, args); } - private async Task ConnectToServerAsync(Process? cliProcess, string? tcpHost, int? tcpPort, StringBuilder? stderrBuffer, CancellationToken cancellationToken) + private async Task ConnectToServerAsync(Process? cliProcess, string? tcpHost, int? tcpPort, ProcessStderrPump? stderrPump, CancellationToken cancellationToken) { - Stream inputStream, outputStream; TcpClient? tcpClient = null; NetworkStream? networkStream = null; + JsonRpc? rpc = null; - if (_options.UseStdio) - { - if (cliProcess == null) throw new InvalidOperationException("CLI process not started"); - inputStream = cliProcess.StandardOutput.BaseStream; - outputStream = cliProcess.StandardInput.BaseStream; - } - else + try { - if (tcpHost is null || tcpPort is null) + Stream inputStream, outputStream; + + if (_options.UseStdio) { - throw new InvalidOperationException("Cannot connect because TCP host or port are not available"); + if (cliProcess == null) throw new InvalidOperationException("CLI process not started"); + inputStream = cliProcess.StandardOutput.BaseStream; + outputStream = cliProcess.StandardInput.BaseStream; } + else + { + if (tcpHost is null || tcpPort is null) + { + throw new InvalidOperationException("Cannot connect because TCP host or port are not available"); + } - tcpClient = new(); - await tcpClient.ConnectAsync(tcpHost, tcpPort.Value, cancellationToken); - networkStream = tcpClient.GetStream(); - inputStream = networkStream; - outputStream = networkStream; - } + tcpClient = new(); + await tcpClient.ConnectAsync(tcpHost, tcpPort.Value, cancellationToken); + networkStream = tcpClient.GetStream(); + inputStream = networkStream; + outputStream = networkStream; + } - var rpc = new JsonRpc(new HeaderDelimitedMessageHandler( - outputStream, - inputStream, - CreateSystemTextJsonFormatter())) - { - TraceSource = new LoggerTraceSource(_logger), - }; + rpc = new JsonRpc(new HeaderDelimitedMessageHandler( + outputStream, + inputStream, + CreateSystemTextJsonFormatter())) + { + TraceSource = new LoggerTraceSource(_logger), + }; - var handler = new RpcHandler(this); - rpc.AddLocalRpcMethod("session.event", handler.OnSessionEvent); - rpc.AddLocalRpcMethod("session.lifecycle", handler.OnSessionLifecycle); - // Protocol v3 servers send tool calls / permission requests as broadcast events. - // Protocol v2 servers use the older tool.call / permission.request RPC model. - // We always register v2 adapters because handlers are set up before version - // negotiation; a v3 server will simply never send these requests. - rpc.AddLocalRpcMethod("tool.call", handler.OnToolCallV2); - rpc.AddLocalRpcMethod("permission.request", handler.OnPermissionRequestV2); - rpc.AddLocalRpcMethod("userInput.request", handler.OnUserInputRequest); - rpc.AddLocalRpcMethod("hooks.invoke", handler.OnHooksInvoke); - rpc.AddLocalRpcMethod("systemMessage.transform", handler.OnSystemMessageTransform); - ClientSessionApiRegistration.RegisterClientSessionApiHandlers(rpc, sessionId => - { - var session = GetSession(sessionId) ?? throw new ArgumentException($"Unknown session {sessionId}"); - return session.ClientSessionApis; - }); - rpc.StartListening(); + var handler = new RpcHandler(this); + rpc.AddLocalRpcMethod("session.event", handler.OnSessionEvent); + rpc.AddLocalRpcMethod("session.lifecycle", handler.OnSessionLifecycle); + // Protocol v3 servers send tool calls / permission requests as broadcast events. + // Protocol v2 servers use the older tool.call / permission.request RPC model. + // We always register v2 adapters because handlers are set up before version + // negotiation; a v3 server will simply never send these requests. + rpc.AddLocalRpcMethod("tool.call", handler.OnToolCallV2); + rpc.AddLocalRpcMethod("permission.request", handler.OnPermissionRequestV2); + rpc.AddLocalRpcMethod("userInput.request", handler.OnUserInputRequest); + rpc.AddLocalRpcMethod("hooks.invoke", handler.OnHooksInvoke); + rpc.AddLocalRpcMethod("systemMessage.transform", handler.OnSystemMessageTransform); + ClientSessionApiRegistration.RegisterClientSessionApiHandlers(rpc, sessionId => + { + var session = GetSession(sessionId) ?? throw new ArgumentException($"Unknown session {sessionId}"); + return session.ClientSessionApis; + }); + rpc.StartListening(); + + // Transition state to Disconnected if the JSON-RPC connection drops + _ = rpc.Completion.ContinueWith(_ => _disconnected = true, TaskScheduler.Default); - // Transition state to Disconnected if the JSON-RPC connection drops - _ = rpc.Completion.ContinueWith(_ => _disconnected = true, TaskScheduler.Default); + _rpc = new ServerRpc(rpc); - _rpc = new ServerRpc(rpc); + return new Connection(rpc, cliProcess, tcpClient, networkStream, stderrPump); + } + catch + { + try { rpc?.Dispose(); } + catch (Exception ex) { _logger.LogDebug(ex, "Failed to dispose JSON-RPC connection after startup failure"); } - return new Connection(rpc, cliProcess, tcpClient, networkStream, stderrBuffer); + if (networkStream is not null) + { + try { await networkStream.DisposeAsync(); } + catch (Exception ex) { _logger.LogDebug(ex, "Failed to dispose TCP stream after startup failure"); } + } + + if (tcpClient is not null) + { + try { tcpClient.Dispose(); } + catch (Exception ex) { _logger.LogDebug(ex, "Failed to dispose TCP client after startup failure"); } + } + + throw; + } } [UnconditionalSuppressMessage("Trimming", "IL2026", Justification = "Using happy path from https://microsoft.github.io/vs-streamjsonrpc/docs/nativeAOT.html")] @@ -1613,13 +1713,114 @@ private class Connection( Process? cliProcess, // Set if we created the child process TcpClient? tcpClient, // Set if using TCP NetworkStream? networkStream, // Set if using TCP - StringBuilder? stderrBuffer = null) // Captures stderr for error messages + ProcessStderrPump? stderrPump = null) // Captures stderr for error messages { public Process? CliProcess => cliProcess; public TcpClient? TcpClient => tcpClient; public JsonRpc Rpc => rpc; public NetworkStream? NetworkStream => networkStream; - public StringBuilder? StderrBuffer => stderrBuffer; + public ProcessStderrPump? StderrPump => stderrPump; + public StringBuilder? StderrBuffer => stderrPump?.Buffer; + } + + private sealed class ProcessStderrPump : IDisposable + { + private readonly CancellationTokenSource _cancellationTokenSource = new(); + private readonly Task _completion; + private int _disposeRequested; + + private ProcessStderrPump(Process process, ILogger logger) + { + _completion = Task.Run(() => PumpAsync(process, logger, _cancellationTokenSource.Token)); + } + + public StringBuilder Buffer { get; } = new(); + + public static ProcessStderrPump Start(Process process, ILogger logger) + { + return new ProcessStderrPump(process, logger); + } + + public void Cancel() + { + try + { + _cancellationTokenSource.Cancel(); + } + catch (ObjectDisposedException) + { + } + } + + public async Task WaitForCompletionAsync(TimeSpan timeout) + { + await _completion.WaitAsync(timeout); + } + + public void Dispose() + { + if (Interlocked.Exchange(ref _disposeRequested, 1) != 0) + { + return; + } + + Cancel(); + + if (_completion.IsCompleted) + { + _cancellationTokenSource.Dispose(); + } + else + { + _ = _completion.ContinueWith( + static (_, state) => ((CancellationTokenSource)state!).Dispose(), + _cancellationTokenSource, + CancellationToken.None, + TaskContinuationOptions.ExecuteSynchronously, + TaskScheduler.Default); + } + } + + private async Task PumpAsync(Process process, ILogger logger, CancellationToken cancellationToken) + { + try + { + while (true) + { + var line = await process.StandardError.ReadLineAsync(cancellationToken); + if (line is null) + { + break; + } + + lock (Buffer) + { + Buffer.AppendLine(line); + } + + if (logger.IsEnabled(LogLevel.Debug)) + { + logger.LogDebug("[CLI] {Line}", line); + } + } + } + catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested) + { + } + catch (InvalidOperationException) when (cancellationToken.IsCancellationRequested) + { + } + catch (ObjectDisposedException) when (cancellationToken.IsCancellationRequested) + { + } + catch (IOException) when (cancellationToken.IsCancellationRequested) + { + } + catch (Exception ex) + { + logger.LogDebug(ex, "CLI stderr pump stopped unexpectedly"); + } + } } private static class ProcessArgumentEscaper