Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ public override async Task SendMessageAsync(
messageId = messageWithId.Id.ToString();
}

LogTransportSendingMessageSensitive(message);

using var httpRequestMessage = new HttpRequestMessage(HttpMethod.Post, _messageEndpoint);
StreamableHttpClientSessionTransport.CopyAdditionalHeaders(httpRequestMessage.Headers, _options.AdditionalHeaders, sessionId: null, protocolVersion: null);
var response = await _httpClient.SendAsync(httpRequestMessage, message, cancellationToken).ConfigureAwait(false);
Expand Down Expand Up @@ -193,7 +195,10 @@ private async Task ProcessSseMessage(string data, CancellationToken cancellation
return;
}

LogTransportReceivedMessageSensitive(Name, data);
if (_logger.IsEnabled(LogLevel.Trace))
{
LogTransportReceivedMessageSensitive(Name, data);
}
Comment thread
stephentoub marked this conversation as resolved.
Outdated

try
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,11 @@ public override async Task SendMessageAsync(JsonRpcMessage message, Cancellation

var json = JsonSerializer.Serialize(message, McpJsonUtilities.JsonContext.Default.JsonRpcMessage);

if (Logger.IsEnabled(LogLevel.Trace))
{
LogTransportSendingMessageSensitive(Name, json);
}

using var _ = await _sendLock.LockAsync(cancellationToken).ConfigureAwait(false);
try
{
Expand Down Expand Up @@ -143,7 +148,10 @@ private async Task ReadMessagesAsync(CancellationToken cancellationToken)
continue;
}

LogTransportReceivedMessageSensitive(Name, line);
if (Logger.IsEnabled(LogLevel.Trace))
{
LogTransportReceivedMessageSensitive(Name, line);
}

await ProcessMessageAsync(line, cancellationToken).ConfigureAwait(false);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ internal async Task<HttpResponseMessage> SendHttpRequestAsync(JsonRpcMessage mes
$"Call {nameof(McpClient)}.{nameof(McpClient.ResumeSessionAsync)} to resume existing sessions.");
}

LogTransportSendingMessageSensitive(message);

using var sendCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, _connectionCts.Token);
cancellationToken = sendCts.Token;

Expand Down Expand Up @@ -342,7 +344,10 @@ private async Task<SseResponse> ProcessSseResponseAsync(

private async Task<JsonRpcMessageWithId?> ProcessMessageAsync(string data, JsonRpcRequest? relatedRpcRequest, CancellationToken cancellationToken)
{
LogTransportReceivedMessageSensitive(Name, data);
if (_logger.IsEnabled(LogLevel.Trace))
{
LogTransportReceivedMessageSensitive(Name, data);
}

try
{
Expand Down
16 changes: 16 additions & 0 deletions src/ModelContextProtocol.Core/Protocol/TransportBase.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
using System.Diagnostics;
using System.Text.Json;
using System.Threading.Channels;

namespace ModelContextProtocol.Protocol;
Expand Down Expand Up @@ -166,6 +167,21 @@ protected void SetDisconnected(Exception? error = null)
[LoggerMessage(Level = LogLevel.Error, Message = "{EndpointName} transport send failed for message ID '{MessageId}'.")]
private protected partial void LogTransportSendFailed(string endpointName, string messageId, Exception exception);

[LoggerMessage(Level = LogLevel.Trace, Message = "{EndpointName} transport sending message. Message: '{Message}'.")]
private protected partial void LogTransportSendingMessageSensitive(string endpointName, string message);

/// <summary>
/// Logs a sending message at Trace level if trace logging is enabled.
/// </summary>
/// <param name="message">The JSON-RPC message to log.</param>
private protected void LogTransportSendingMessageSensitive(JsonRpcMessage message)
{
if (_logger.IsEnabled(LogLevel.Trace))
{
LogTransportSendingMessageSensitive(Name, JsonSerializer.Serialize(message, McpJsonUtilities.JsonContext.Default.JsonRpcMessage));
}
}

[LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} transport reading messages.")]
private protected partial void LogTransportEnteringReadMessagesLoop(string endpointName);

Expand Down
12 changes: 10 additions & 2 deletions src/ModelContextProtocol.Core/Server/StreamServerTransport.cs
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,12 @@ public override async Task SendMessageAsync(JsonRpcMessage message, Cancellation

try
{
await JsonSerializer.SerializeAsync(_outputStream, message, McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(JsonRpcMessage)), cancellationToken).ConfigureAwait(false);
var json = JsonSerializer.Serialize(message, McpJsonUtilities.JsonContext.Default.JsonRpcMessage);
if (Logger.IsEnabled(LogLevel.Trace))
{
LogTransportSendingMessageSensitive(Name, json);
}
await _outputStream.WriteAsync(Encoding.UTF8.GetBytes(json), cancellationToken).ConfigureAwait(false);
await _outputStream.WriteAsync(s_newlineBytes, cancellationToken).ConfigureAwait(false);
await _outputStream.FlushAsync(cancellationToken).ConfigureAwait(false);
}
Expand Down Expand Up @@ -107,7 +112,10 @@ private async Task ReadMessagesAsync()
continue;
}

LogTransportReceivedMessageSensitive(Name, line);
if (Logger.IsEnabled(LogLevel.Trace))
{
LogTransportReceivedMessageSensitive(Name, line);
}

try
{
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using ModelContextProtocol.Protocol;
using Microsoft.Extensions.Logging;
using ModelContextProtocol.Protocol;
using ModelContextProtocol.Server;
using ModelContextProtocol.Tests.Utils;
using System.IO.Pipelines;
Expand Down Expand Up @@ -193,4 +194,72 @@ public async Task SendMessageAsync_Should_Preserve_Unicode_Characters()
Assert.True(magnifyingGlassFound, "Magnifying glass emoji not found in result");
Assert.True(rocketFound, "Rocket emoji not found in result");
}

[Fact]
public async Task SendMessageAsync_Should_Log_At_Trace_Level()
{
// Arrange
var mockLoggerProvider = new MockLoggerProvider();
Comment thread
halter73 marked this conversation as resolved.
Outdated
using var traceLoggerFactory = Microsoft.Extensions.Logging.LoggerFactory.Create(builder =>
{
builder.AddProvider(mockLoggerProvider);
builder.SetMinimumLevel(LogLevel.Trace);
});
using var output = new MemoryStream();

await using var transport = new StreamServerTransport(
new Pipe().Reader.AsStream(),
output,
loggerFactory: traceLoggerFactory);
Comment thread
stephentoub marked this conversation as resolved.
Outdated

// Act
var message = new JsonRpcRequest { Method = "test", Id = new RequestId(44) };
await transport.SendMessageAsync(message, TestContext.Current.CancellationToken);

// Assert
var traceLogMessages = mockLoggerProvider.LogMessages
.Where(x => x.LogLevel == LogLevel.Trace && x.Message.Contains("transport sending message"))
.ToList();

Assert.NotEmpty(traceLogMessages);
Assert.Contains(traceLogMessages, x => x.Message.Contains("\"method\":\"test\"") && x.Message.Contains("\"id\":44"));
}

[Fact]
public async Task ReadMessagesAsync_Should_Log_Received_At_Trace_Level()
{
// Arrange
var mockLoggerProvider = new MockLoggerProvider();
using var traceLoggerFactory = Microsoft.Extensions.Logging.LoggerFactory.Create(builder =>
{
builder.AddProvider(mockLoggerProvider);
builder.SetMinimumLevel(LogLevel.Trace);
});

var message = new JsonRpcRequest { Method = "test", Id = new RequestId(99) };
var json = JsonSerializer.Serialize(message, McpJsonUtilities.DefaultOptions);

Pipe pipe = new();
using var input = pipe.Reader.AsStream();

await using var transport = new StreamServerTransport(
input,
Stream.Null,
loggerFactory: traceLoggerFactory);

// Act
await pipe.Writer.WriteAsync(Encoding.UTF8.GetBytes($"{json}\n"), TestContext.Current.CancellationToken);

// Wait for the message to be processed
var canRead = await transport.MessageReader.WaitToReadAsync(TestContext.Current.CancellationToken);
Assert.True(canRead, "Nothing to read here from transport message reader");

// Assert
var traceLogMessages = mockLoggerProvider.LogMessages
.Where(x => x.LogLevel == LogLevel.Trace && x.Message.Contains("transport received message"))
.ToList();

Assert.NotEmpty(traceLogMessages);
Assert.Contains(traceLogMessages, x => x.Message.Contains("\"method\":\"test\"") && x.Message.Contains("\"id\":99"));
}
}