Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,11 @@ public override async Task SendMessageAsync(
messageId = messageWithId.Id.ToString();
}

if (_logger.IsEnabled(LogLevel.Trace))
{
LogTransportSendingMessageSensitive(Name, JsonSerializer.Serialize(message, McpJsonUtilities.JsonContext.Default.JsonRpcMessage));
}

using var httpRequestMessage = new HttpRequestMessage(HttpMethod.Post, _messageEndpoint);
StreamableHttpClientSessionTransport.CopyAdditionalHeaders(httpRequestMessage.Headers, _options.AdditionalHeaders, sessionId: null, protocolVersion: null);
var response = await _httpClient.SendAsync(httpRequestMessage, message, cancellationToken).ConfigureAwait(false);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,11 @@ public override async Task SendMessageAsync(JsonRpcMessage message, Cancellation

var json = JsonSerializer.Serialize(message, McpJsonUtilities.JsonContext.Default.JsonRpcMessage);

if (Logger.IsEnabled(LogLevel.Trace))
{
LogTransportSendingMessageSensitive(Name, json);
}

using var _ = await _sendLock.LockAsync(cancellationToken).ConfigureAwait(false);
try
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,11 @@ internal async Task<HttpResponseMessage> SendHttpRequestAsync(JsonRpcMessage mes
$"Call {nameof(McpClient)}.{nameof(McpClient.ResumeSessionAsync)} to resume existing sessions.");
}

if (_logger.IsEnabled(LogLevel.Trace))
{
LogTransportSendingMessageSensitive(Name, JsonSerializer.Serialize(message, McpJsonUtilities.JsonContext.Default.JsonRpcMessage));
}

using var sendCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, _connectionCts.Token);
cancellationToken = sendCts.Token;

Expand Down
3 changes: 3 additions & 0 deletions src/ModelContextProtocol.Core/Protocol/TransportBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,9 @@ protected void SetDisconnected(Exception? error = null)
[LoggerMessage(Level = LogLevel.Error, Message = "{EndpointName} transport send failed for message ID '{MessageId}'.")]
private protected partial void LogTransportSendFailed(string endpointName, string messageId, Exception exception);

[LoggerMessage(Level = LogLevel.Trace, Message = "{EndpointName} transport sending message. Message: '{Message}'.")]
private protected partial void LogTransportSendingMessageSensitive(string endpointName, string message);

[LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} transport reading messages.")]
private protected partial void LogTransportEnteringReadMessagesLoop(string endpointName);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,12 @@ public override async Task SendMessageAsync(JsonRpcMessage message, Cancellation

try
{
await JsonSerializer.SerializeAsync(_outputStream, message, McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(JsonRpcMessage)), cancellationToken).ConfigureAwait(false);
var json = JsonSerializer.Serialize(message, McpJsonUtilities.JsonContext.Default.JsonRpcMessage);
if (Logger.IsEnabled(LogLevel.Trace))
{
LogTransportSendingMessageSensitive(Name, json);
}
await _outputStream.WriteAsync(Encoding.UTF8.GetBytes(json), cancellationToken).ConfigureAwait(false);
await _outputStream.WriteAsync(s_newlineBytes, cancellationToken).ConfigureAwait(false);
await _outputStream.FlushAsync(cancellationToken).ConfigureAwait(false);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using ModelContextProtocol.Protocol;
using Microsoft.Extensions.Logging;
using ModelContextProtocol.Protocol;
using ModelContextProtocol.Server;
using ModelContextProtocol.Tests.Utils;
using System.IO.Pipelines;
Expand Down Expand Up @@ -193,4 +194,72 @@ public async Task SendMessageAsync_Should_Preserve_Unicode_Characters()
Assert.True(magnifyingGlassFound, "Magnifying glass emoji not found in result");
Assert.True(rocketFound, "Rocket emoji not found in result");
}

[Fact]
public async Task SendMessageAsync_Should_Log_At_Trace_Level()
{
// Arrange
var mockLoggerProvider = new MockLoggerProvider();
using var traceLoggerFactory = Microsoft.Extensions.Logging.LoggerFactory.Create(builder =>
{
builder.AddProvider(mockLoggerProvider);
builder.SetMinimumLevel(LogLevel.Trace);
});
using var output = new MemoryStream();

await using var transport = new StreamServerTransport(
new Pipe().Reader.AsStream(),
output,
loggerFactory: traceLoggerFactory);

// Act
var message = new JsonRpcRequest { Method = "test", Id = new RequestId(44) };
await transport.SendMessageAsync(message, TestContext.Current.CancellationToken);

// Assert
var traceLogMessages = mockLoggerProvider.LogMessages
.Where(x => x.LogLevel == LogLevel.Trace && x.Message.Contains("transport sending message"))
.ToList();

Assert.NotEmpty(traceLogMessages);
Assert.Contains(traceLogMessages, x => x.Message.Contains("\"method\":\"test\"") && x.Message.Contains("\"id\":44"));
}

[Fact]
public async Task ReadMessagesAsync_Should_Log_Received_At_Trace_Level()
{
// Arrange
var mockLoggerProvider = new MockLoggerProvider();
using var traceLoggerFactory = Microsoft.Extensions.Logging.LoggerFactory.Create(builder =>
{
builder.AddProvider(mockLoggerProvider);
builder.SetMinimumLevel(LogLevel.Trace);
});

var message = new JsonRpcRequest { Method = "test", Id = new RequestId(99) };
var json = JsonSerializer.Serialize(message, McpJsonUtilities.DefaultOptions);

Pipe pipe = new();
using var input = pipe.Reader.AsStream();

await using var transport = new StreamServerTransport(
input,
Stream.Null,
loggerFactory: traceLoggerFactory);

// Act
await pipe.Writer.WriteAsync(Encoding.UTF8.GetBytes($"{json}\n"), TestContext.Current.CancellationToken);

// Wait for the message to be processed
var canRead = await transport.MessageReader.WaitToReadAsync(TestContext.Current.CancellationToken);
Assert.True(canRead, "Nothing to read here from transport message reader");

// Assert
var traceLogMessages = mockLoggerProvider.LogMessages
.Where(x => x.LogLevel == LogLevel.Trace && x.Message.Contains("transport received message"))
.ToList();

Assert.NotEmpty(traceLogMessages);
Assert.Contains(traceLogMessages, x => x.Message.Contains("\"method\":\"test\"") && x.Message.Contains("\"id\":99"));
}
}