Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 32 additions & 1 deletion src/mcp/client/streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from collections.abc import AsyncGenerator, Awaitable, Callable
from contextlib import asynccontextmanager
from dataclasses import dataclass
from types import TracebackType

import anyio
import httpx
Expand All @@ -17,6 +18,7 @@
from mcp.client._transport import TransportStreams
from mcp.shared._context_streams import ContextReceiveStream, ContextSendStream, create_context_streams
from mcp.shared._httpx_utils import create_mcp_http_client
from mcp.shared._stream_protocols import WriteStream
from mcp.shared.message import ClientMessageMetadata, SessionMessage
from mcp.types import (
INTERNAL_ERROR,
Expand Down Expand Up @@ -512,6 +514,35 @@ def get_session_id(self) -> str | None:
return self.session_id # pragma: no cover


class _SessionAwareWriteStream:
"""Write-stream wrapper that exposes the transport session ID."""

def __init__(self, inner: WriteStream[SessionMessage], transport: StreamableHTTPTransport) -> None:
self._inner = inner
self._transport = transport

async def send(self, item: SessionMessage) -> None:
await self._inner.send(item)

async def aclose(self) -> None:
await self._inner.aclose()

def get_session_id(self) -> str | None:
return self._transport.session_id

async def __aenter__(self) -> _SessionAwareWriteStream:
await self._inner.__aenter__()
return self

async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> bool | None:
return await self._inner.__aexit__(exc_type, exc_val, exc_tb)


# TODO(Marcelo): I've dropped the `get_session_id` callback because it breaks the Transport protocol. Is that needed?
# It's a completely wrong abstraction, so removal is a good idea. But if we need the client to find the session ID,
# we should think about a better way to do it. I believe we can achieve it with other means.
Expand Down Expand Up @@ -581,7 +612,7 @@ def start_get_stream() -> None:
)

try:
yield read_stream, write_stream
yield read_stream, _SessionAwareWriteStream(write_stream, transport)
finally:
if transport.session_id and terminate_on_close:
await transport.terminate_session(client)
Expand Down
31 changes: 18 additions & 13 deletions src/mcp/server/lowlevel/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,13 +66,14 @@
from mcp.server.streamable_http import EventStore
from mcp.server.streamable_http_manager import StreamableHTTPASGIApp, StreamableHTTPSessionManager
from mcp.server.transport_security import TransportSecuritySettings
from mcp.shared._otel import extract_trace_context, otel_span
from mcp.shared._otel import build_server_span_attributes, extract_trace_context, otel_span
from mcp.shared._stream_protocols import ReadStream, WriteStream
from mcp.shared.exceptions import MCPError
from mcp.shared.message import ServerMessageMetadata, SessionMessage
from mcp.shared.session import RequestResponder

logger = logging.getLogger(__name__)
MCP_SESSION_ID_HEADER = "mcp-session-id"

Check warning on line 76 in src/mcp/server/lowlevel/server.py

View check run for this annotation

Claude / Claude Code Review

Duplicate MCP_SESSION_ID_HEADER constant

This PR introduces a duplicate definition of `MCP_SESSION_ID_HEADER = "mcp-session-id"` at `src/mcp/server/lowlevel/server.py:76`, when the identical constant already exists in `src/mcp/server/streamable_http.py:51`. Since `server.py` already imports `EventStore` from `mcp.server.streamable_http`, adding `MCP_SESSION_ID_HEADER` to that import would be a one-line fix that eliminates a maintenance hazard.

LifespanResultT = TypeVar("LifespanResultT", default=Any)

Expand Down Expand Up @@ -454,28 +455,32 @@
# Extract W3C trace context from _meta (SEP-414).
meta = cast(dict[str, Any] | None, getattr(req.params, "meta", None)) if req.params else None
parent_context = extract_trace_context(meta) if meta is not None else None
request_data = None
close_sse_stream_cb = None
close_standalone_sse_stream_cb = None
if message.message_metadata is not None and isinstance(message.message_metadata, ServerMessageMetadata):
request_data = message.message_metadata.request_context
close_sse_stream_cb = message.message_metadata.close_sse_stream
close_standalone_sse_stream_cb = message.message_metadata.close_standalone_sse_stream
request_headers = getattr(request_data, "headers", None)
session_id = request_headers.get(MCP_SESSION_ID_HEADER) if request_headers is not None else None

with otel_span(
span_name,
kind=SpanKind.SERVER,
attributes={"mcp.method.name": req.method, "jsonrpc.request.id": message.request_id},
attributes=build_server_span_attributes(
service_name=self.name,
method=req.method,
request_id=message.request_id,
params=req.params,
session_id=session_id,
),
context=parent_context,
) as span:
if handler := self._request_handlers.get(req.method):
logger.debug("Dispatching request of type %s", type(req).__name__)

try:
# Extract request context and close_sse_stream from message metadata
request_data = None
close_sse_stream_cb = None
close_standalone_sse_stream_cb = None
if message.message_metadata is not None and isinstance(
message.message_metadata, ServerMessageMetadata
):
request_data = message.message_metadata.request_context
close_sse_stream_cb = message.message_metadata.close_sse_stream
close_standalone_sse_stream_cb = message.message_metadata.close_standalone_sse_stream

client_capabilities = session.client_params.capabilities if session.client_params else None
task_support = self._experimental_handlers.task_support if self._experimental_handlers else None
# Get task metadata from request params if present
Expand Down
52 changes: 52 additions & 0 deletions src/mcp/shared/_otel.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from opentelemetry.trace import SpanKind, get_tracer

_tracer = get_tracer("mcp-python-sdk")
MCP_RPC_SYSTEM = "mcp"


@contextmanager
Expand All @@ -34,3 +35,54 @@ def inject_trace_context(meta: dict[str, Any]) -> None:
def extract_trace_context(meta: dict[str, Any]) -> Context:
"""Extract W3C trace context from a `_meta` dict."""
return extract(meta)


def build_client_span_attributes(
*,
method: str,
request_id: str | int,
params: dict[str, Any] | None = None,
session_id: str | None = None,
) -> dict[str, Any]:
"""Build OTel attributes for an MCP client request span."""
attributes: dict[str, Any] = {
"rpc.system": MCP_RPC_SYSTEM,
"rpc.method": method,
"mcp.method.name": method,
"jsonrpc.request.id": request_id,
}

if params is not None and (resource_uri := params.get("uri")) is not None:
attributes["mcp.resource.uri"] = resource_uri

if session_id is not None:
attributes["mcp.session.id"] = session_id

return attributes


def build_server_span_attributes(
*,
service_name: str,
method: str,
request_id: str | int,
params: Any = None,
session_id: str | None = None,
) -> dict[str, Any]:
"""Build OTel attributes for an MCP server request span."""
attributes: dict[str, Any] = {
"rpc.system": MCP_RPC_SYSTEM,
"rpc.service": service_name,
"rpc.method": method,
"mcp.method.name": method,
"jsonrpc.request.id": request_id,
}

resource_uri = getattr(params, "uri", None)
if resource_uri is not None:
attributes["mcp.resource.uri"] = str(resource_uri)

if session_id is not None:
attributes["mcp.session.id"] = session_id

return attributes
18 changes: 15 additions & 3 deletions src/mcp/shared/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@
from collections.abc import Callable
from contextlib import AsyncExitStack
from types import TracebackType
from typing import Any, Generic, Protocol, TypeVar
from typing import Any, Generic, Protocol, TypeVar, cast

import anyio
from anyio.streams.memory import MemoryObjectSendStream
from opentelemetry.trace import SpanKind
from pydantic import BaseModel, TypeAdapter
from typing_extensions import Self

from mcp.shared._otel import inject_trace_context, otel_span
from mcp.shared._otel import build_client_span_attributes, inject_trace_context, otel_span
from mcp.shared._stream_protocols import ReadStream, WriteStream
from mcp.shared.exceptions import MCPError
from mcp.shared.message import MessageMetadata, ServerMessageMetadata, SessionMessage
Expand Down Expand Up @@ -236,6 +236,13 @@ async def __aexit__(
self._task_group.cancel_scope.cancel()
return await self._task_group.__aexit__(exc_type, exc_val, exc_tb)

def _get_transport_session_id(self) -> str | None:
"""Return the transport session ID when the write stream exposes it."""
get_session_id = getattr(self._write_stream, "get_session_id", None)
if callable(get_session_id):
return cast("str | None", get_session_id())
return None

async def send_request(
self,
request: SendRequestT,
Expand Down Expand Up @@ -276,7 +283,12 @@ async def send_request(
with otel_span(
span_name,
kind=SpanKind.CLIENT,
attributes={"mcp.method.name": request.method, "jsonrpc.request.id": request_id},
attributes=build_client_span_attributes(
method=request.method,
request_id=request_id,
params=request_data.get("params"),
session_id=self._get_transport_session_id(),
),
):
# Inject W3C trace context into _meta (SEP-414).
meta: dict[str, Any] = request_data.setdefault("params", {}).setdefault("_meta", {})
Expand Down
36 changes: 36 additions & 0 deletions tests/shared/test_otel.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,44 @@ def greet(name: str) -> str:
client_span = next(s for s in spans if s["name"] == "MCP send tools/call greet")
server_span = next(s for s in spans if s["name"] == "MCP handle tools/call greet")

assert client_span["attributes"]["rpc.system"] == "mcp"
assert client_span["attributes"]["rpc.method"] == "tools/call"
assert client_span["attributes"]["mcp.method.name"] == "tools/call"
assert server_span["attributes"]["rpc.system"] == "mcp"
assert server_span["attributes"]["rpc.service"] == "test"
assert server_span["attributes"]["rpc.method"] == "tools/call"
assert server_span["attributes"]["mcp.method.name"] == "tools/call"

# Server span should be in the same trace as the client span (context propagation).
assert server_span["context"]["trace_id"] == client_span["context"]["trace_id"]


@pytest.mark.filterwarnings("ignore::RuntimeWarning")
async def test_resource_read_spans_include_resource_uri(capfire: CaptureLogfire):
"""Verify that resource reads include MCP resource and RPC attributes."""
server = MCPServer("test")

@server.resource("test://resource")
def test_resource() -> str:
return "hello"

async with Client(server) as client:
result = await client.read_resource("test://resource")

assert result.contents[0].uri == "test://resource"

spans = capfire.exporter.exported_spans_as_dict()

client_span = next(s for s in spans if s["name"] == "MCP send resources/read")
server_span = next(s for s in spans if s["name"] == "MCP handle resources/read")

assert client_span["attributes"]["rpc.system"] == "mcp"
assert client_span["attributes"]["rpc.method"] == "resources/read"
assert client_span["attributes"]["mcp.method.name"] == "resources/read"
assert client_span["attributes"]["mcp.resource.uri"] == "test://resource"

assert server_span["attributes"]["rpc.system"] == "mcp"
assert server_span["attributes"]["rpc.service"] == "test"
assert server_span["attributes"]["rpc.method"] == "resources/read"
assert server_span["attributes"]["mcp.method.name"] == "resources/read"
assert server_span["attributes"]["mcp.resource.uri"] == "test://resource"
21 changes: 21 additions & 0 deletions tests/shared/test_streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import requests
import uvicorn
from httpx_sse import ServerSentEvent
from logfire.testing import CaptureLogfire
from starlette.applications import Starlette
from starlette.requests import Request
from starlette.routing import Mount
Expand Down Expand Up @@ -1081,6 +1082,26 @@ async def test_streamable_http_client_resource_read(initialized_client_session:
assert response.contents[0].text == "Read test-resource"


@pytest.mark.anyio
@pytest.mark.filterwarnings("ignore::RuntimeWarning")
async def test_streamable_http_resource_read_spans_include_session_id(
capfire: CaptureLogfire, basic_server: None, basic_server_url: str
):
"""Verify streamable HTTP spans include the negotiated MCP session ID."""
async with streamable_http_client(f"{basic_server_url}/mcp") as (read_stream, write_stream):
async with ClientSession(read_stream, write_stream) as session:
await session.initialize()
response = await session.read_resource(uri="foobar://test-resource")

assert response.contents[0].uri == "foobar://test-resource"

spans = capfire.exporter.exported_spans_as_dict()
client_span = next(s for s in spans if s["name"] == "MCP send resources/read")

assert client_span["attributes"]["mcp.session.id"]
assert client_span["attributes"]["mcp.resource.uri"] == "foobar://test-resource"


@pytest.mark.anyio
async def test_streamable_http_client_tool_invocation(initialized_client_session: ClientSession):
"""Test client tool invocation."""
Expand Down
Loading