Skip to content

Commit ed1ed14

Browse files
Copilotericstj
andcommitted
Avoid string intermediates in MCP transport read side using PipeReader
Co-authored-by: ericstj <8918108+ericstj@users.noreply.github.com>
1 parent eab4708 commit ed1ed14

5 files changed

Lines changed: 200 additions & 62 deletions

File tree

src/Common/Polyfills/System/Text/EncodingExtensions.cs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,25 @@ public static int GetBytes(this Encoding encoding, ReadOnlySpan<char> chars, Spa
4545
}
4646
}
4747
}
48+
49+
/// <summary>
50+
/// Decodes all the bytes in the specified span into a string.
51+
/// </summary>
52+
public static string GetString(this Encoding encoding, ReadOnlySpan<byte> bytes)
53+
{
54+
if (bytes.IsEmpty)
55+
{
56+
return string.Empty;
57+
}
58+
59+
unsafe
60+
{
61+
fixed (byte* bytesPtr = bytes)
62+
{
63+
return encoding.GetString(bytesPtr, bytes.Length);
64+
}
65+
}
66+
}
4867
}
4968

5069
#endif

src/ModelContextProtocol.Core/Client/StdioClientSessionTransport.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ namespace ModelContextProtocol.Client;
77
/// <summary>Provides the client side of a stdio-based session transport.</summary>
88
internal sealed class StdioClientSessionTransport(
99
StdioClientTransportOptions options, Process process, string endpointName, Queue<string> stderrRollingLog, ILoggerFactory? loggerFactory) :
10-
StreamClientSessionTransport(process.StandardInput.BaseStream, process.StandardOutput.BaseStream, encoding: null, endpointName, loggerFactory)
10+
StreamClientSessionTransport(process.StandardInput.BaseStream, process.StandardOutput.BaseStream, endpointName, loggerFactory)
1111
{
1212
private readonly StdioClientTransportOptions _options = options;
1313
private readonly Process _process = process;

src/ModelContextProtocol.Core/Client/StreamClientSessionTransport.cs

Lines changed: 79 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
using Microsoft.Extensions.Logging;
22
using ModelContextProtocol.Protocol;
3+
using System.Buffers;
4+
using System.IO.Pipelines;
35
using System.Text;
46
using System.Text.Json;
57

@@ -12,7 +14,7 @@ internal class StreamClientSessionTransport : TransportBase
1214

1315
internal static UTF8Encoding NoBomUtf8Encoding { get; } = new(encoderShouldEmitUTF8Identifier: false);
1416

15-
private readonly TextReader _serverOutput;
17+
private readonly PipeReader _serverOutputPipe;
1618
private readonly Stream _serverInputStream;
1719
private readonly SemaphoreSlim _sendLock = new(1, 1);
1820
private CancellationTokenSource? _shutdownCts = new();
@@ -27,9 +29,6 @@ internal class StreamClientSessionTransport : TransportBase
2729
/// <param name="serverOutput">
2830
/// The server's output stream. Messages read from this stream will be received from the server.
2931
/// </param>
30-
/// <param name="encoding">
31-
/// The encoding used for reading and writing messages from the input and output streams. Defaults to UTF-8 without BOM if null.
32-
/// </param>
3332
/// <param name="endpointName">
3433
/// A name that identifies this transport endpoint in logs.
3534
/// </param>
@@ -40,18 +39,14 @@ internal class StreamClientSessionTransport : TransportBase
4039
/// This constructor starts a background task to read messages from the server output stream.
4140
/// The transport will be marked as connected once initialized.
4241
/// </remarks>
43-
public StreamClientSessionTransport(Stream serverInput, Stream serverOutput, Encoding? encoding, string endpointName, ILoggerFactory? loggerFactory)
42+
public StreamClientSessionTransport(Stream serverInput, Stream serverOutput, string endpointName, ILoggerFactory? loggerFactory)
4443
: base(endpointName, loggerFactory)
4544
{
4645
Throw.IfNull(serverInput);
4746
Throw.IfNull(serverOutput);
4847

4948
_serverInputStream = serverInput;
50-
#if NET
51-
_serverOutput = new StreamReader(serverOutput, encoding ?? NoBomUtf8Encoding);
52-
#else
53-
_serverOutput = new CancellableStreamReader(serverOutput, encoding ?? NoBomUtf8Encoding);
54-
#endif
49+
_serverOutputPipe = PipeReader.Create(serverOutput);
5550

5651
SetConnected();
5752

@@ -105,20 +100,41 @@ private async Task ReadMessagesAsync(CancellationToken cancellationToken)
105100

106101
while (true)
107102
{
108-
if (await _serverOutput.ReadLineAsync(cancellationToken).ConfigureAwait(false) is not string line)
109-
{
110-
LogTransportEndOfStream(Name);
111-
break;
112-
}
103+
ReadResult result = await _serverOutputPipe.ReadAsync(cancellationToken).ConfigureAwait(false);
104+
ReadOnlySequence<byte> buffer = result.Buffer;
113105

114-
if (string.IsNullOrWhiteSpace(line))
106+
SequencePosition? position;
107+
while ((position = buffer.PositionOf((byte)'\n')) != null)
115108
{
116-
continue;
109+
ReadOnlySequence<byte> line = buffer.Slice(0, position.Value);
110+
111+
// Trim trailing \r for Windows-style CRLF line endings.
112+
if (EndsWithCarriageReturn(line))
113+
{
114+
line = line.Slice(0, line.Length - 1);
115+
}
116+
117+
if (!line.IsEmpty)
118+
{
119+
if (Logger.IsEnabled(LogLevel.Trace))
120+
{
121+
LogTransportReceivedMessageSensitive(Name, GetString(line));
122+
}
123+
124+
await ProcessLineAsync(line, cancellationToken).ConfigureAwait(false);
125+
}
126+
127+
// Advance past the '\n'.
128+
buffer = buffer.Slice(buffer.GetPosition(1, position.Value));
117129
}
118130

119-
LogTransportReceivedMessageSensitive(Name, line);
131+
_serverOutputPipe.AdvanceTo(buffer.Start, buffer.End);
120132

121-
await ProcessMessageAsync(line, cancellationToken).ConfigureAwait(false);
133+
if (result.IsCompleted)
134+
{
135+
LogTransportEndOfStream(Name);
136+
break;
137+
}
122138
}
123139
}
124140
catch (OperationCanceledException)
@@ -137,25 +153,38 @@ private async Task ReadMessagesAsync(CancellationToken cancellationToken)
137153
}
138154
}
139155

140-
private async Task ProcessMessageAsync(string line, CancellationToken cancellationToken)
156+
private async Task ProcessLineAsync(ReadOnlySequence<byte> line, CancellationToken cancellationToken)
141157
{
142158
try
143159
{
144-
var message = (JsonRpcMessage?)JsonSerializer.Deserialize(line.AsSpan().Trim(), McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(JsonRpcMessage)));
145-
if (message != null)
160+
JsonRpcMessage? message;
161+
if (line.IsSingleSegment)
162+
{
163+
message = JsonSerializer.Deserialize(line.First.Span, McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(JsonRpcMessage))) as JsonRpcMessage;
164+
}
165+
else
166+
{
167+
var reader = new Utf8JsonReader(line, isFinalBlock: true, state: default);
168+
message = JsonSerializer.Deserialize(ref reader, McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(JsonRpcMessage))) as JsonRpcMessage;
169+
}
170+
171+
if (message is not null)
146172
{
147173
await WriteMessageAsync(message, cancellationToken).ConfigureAwait(false);
148174
}
149175
else
150176
{
151-
LogTransportMessageParseUnexpectedTypeSensitive(Name, line);
177+
if (Logger.IsEnabled(LogLevel.Trace))
178+
{
179+
LogTransportMessageParseUnexpectedTypeSensitive(Name, GetString(line));
180+
}
152181
}
153182
}
154183
catch (JsonException ex)
155184
{
156185
if (Logger.IsEnabled(LogLevel.Trace))
157186
{
158-
LogTransportMessageParseFailedSensitive(Name, line, ex);
187+
LogTransportMessageParseFailedSensitive(Name, GetString(line), ex);
159188
}
160189
else
161190
{
@@ -164,6 +193,32 @@ private async Task ProcessMessageAsync(string line, CancellationToken cancellati
164193
}
165194
}
166195

196+
private static string GetString(in ReadOnlySequence<byte> sequence) =>
197+
sequence.IsSingleSegment
198+
? Encoding.UTF8.GetString(sequence.First.Span)
199+
: Encoding.UTF8.GetString(sequence.ToArray());
200+
201+
private static bool EndsWithCarriageReturn(in ReadOnlySequence<byte> sequence)
202+
{
203+
if (sequence.IsSingleSegment)
204+
{
205+
ReadOnlySpan<byte> span = sequence.First.Span;
206+
return span.Length > 0 && span[span.Length - 1] == (byte)'\r';
207+
}
208+
209+
// Multi-segment: find the last non-empty segment to check its last byte.
210+
ReadOnlyMemory<byte> last = default;
211+
foreach (ReadOnlyMemory<byte> segment in sequence)
212+
{
213+
if (!segment.IsEmpty)
214+
{
215+
last = segment;
216+
}
217+
}
218+
219+
return !last.IsEmpty && last.Span[last.Length - 1] == (byte)'\r';
220+
}
221+
167222
protected virtual async ValueTask CleanupAsync(Exception? error = null, CancellationToken cancellationToken = default)
168223
{
169224
LogTransportShuttingDown(Name);

src/ModelContextProtocol.Core/Client/StreamClientTransport.cs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,6 @@ public Task<ITransport> ConnectAsync(CancellationToken cancellationToken = defau
5050
return Task.FromResult<ITransport>(new StreamClientSessionTransport(
5151
_serverInput,
5252
_serverOutput,
53-
encoding: null,
5453
"Client (stream)",
5554
_loggerFactory));
5655
}

0 commit comments

Comments
 (0)