Skip to content
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CLAUDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ This document contains critical information about working with this codebase. Fo
- Bug fixes require regression tests
- IMPORTANT: The `tests/client/test_client.py` is the most well designed test file. Follow its patterns.
- IMPORTANT: Be minimal, and focus on E2E tests: Use the `mcp.client.Client` whenever possible.
- IMPORTANT: Do NOT test private functions (prefixed with `_`). Test them indirectly through the public API.

Test files mirror the source tree: `src/mcp/client/streamable_http.py` → `tests/client/test_streamable_http.py`
Add tests to the existing file for that module.
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ dependencies = [
"pyjwt[crypto]>=2.10.1",
"typing-extensions>=4.13.0",
"typing-inspection>=0.4.1",
"opentelemetry-api>=1.28.0",
]

[project.optional-dependencies]
Expand Down Expand Up @@ -71,6 +72,7 @@ dev = [
"coverage[toml]>=7.10.7,<=7.13",
"pillow>=12.0",
"strict-no-cover",
"opentelemetry-sdk>=1.28.0",
]
docs = [
"mkdocs>=1.6.1",
Expand Down
14 changes: 13 additions & 1 deletion src/mcp/shared/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from mcp.shared.exceptions import MCPError
from mcp.shared.message import MessageMetadata, ServerMessageMetadata, SessionMessage
from mcp.shared.response_router import ResponseRouter
from mcp.shared.tracing import end_span_error, end_span_ok, start_client_span
from mcp.types import (
CONNECTION_CLOSED,
INVALID_PARAMS,
Expand Down Expand Up @@ -260,6 +261,9 @@ async def send_request(
# Store the callback for this request
self._progress_callbacks[request_id] = progress_callback

method: str = request_data["method"]
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is the cast needed?

span = start_client_span(method, request_data.get("params"))

try:
jsonrpc_request = JSONRPCRequest(jsonrpc="2.0", id=request_id, **request_data)
await self._write_stream.send(SessionMessage(message=jsonrpc_request, metadata=metadata))
Expand All @@ -278,7 +282,15 @@ async def send_request(
if isinstance(response_or_error, JSONRPCError):
raise MCPError.from_jsonrpc_error(response_or_error)
else:
return result_type.model_validate(response_or_error.result, by_name=False)
result = result_type.model_validate(response_or_error.result, by_name=False)
if span is not None:
end_span_ok(span)
return result

except BaseException as exc:
if span is not None:
end_span_error(span, exc)
raise

finally:
self._response_streams.pop(request_id, None)
Expand Down
63 changes: 63 additions & 0 deletions src/mcp/shared/tracing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
from __future__ import annotations

from typing import Any

from opentelemetry import trace
from opentelemetry.trace import StatusCode

_tracer = trace.get_tracer("mcp")

_EXCLUDED_METHODS: frozenset[str] = frozenset({"notifications/message"})

# Semantic convention attribute keys
ATTR_MCP_METHOD_NAME = "mcp.method.name"
ATTR_ERROR_TYPE = "error.type"

# Methods that have a meaningful target name in params
_TARGET_PARAM_KEY: dict[str, str] = {
"tools/call": "name",
"prompts/get": "name",
"resources/read": "uri",
}


def _extract_target(method: str, params: dict[str, Any] | None) -> str | None:
"""Extract the target (e.g. tool name, prompt name) from request params."""
key = _TARGET_PARAM_KEY.get(method)
if key is None or params is None:
return None
value = params.get(key)
if isinstance(value, str):
return value
return None


def start_client_span(method: str, params: dict[str, Any] | None) -> trace.Span | None:
"""Start a CLIENT span for an outgoing MCP request.

Returns None if the method is excluded from tracing.
"""
if method in _EXCLUDED_METHODS:
return None

target = _extract_target(method, params)
span_name = f"{method} {target}" if target else method
span = _tracer.start_span(
span_name,
kind=trace.SpanKind.CLIENT,
attributes={ATTR_MCP_METHOD_NAME: method},
)
return span


def end_span_ok(span: trace.Span) -> None:
"""Mark a span as successful and end it."""
span.set_status(StatusCode.OK)
span.end()


def end_span_error(span: trace.Span, error: BaseException) -> None:
"""Mark a span as errored and end it."""
span.set_status(StatusCode.ERROR, str(error))
span.set_attribute(ATTR_ERROR_TYPE, type(error).__qualname__)
span.end()
129 changes: 129 additions & 0 deletions tests/shared/test_tracing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
from __future__ import annotations

from typing import Any

import anyio
import pytest
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
from opentelemetry.trace import SpanKind, StatusCode

from mcp import Client, types
from mcp.server.lowlevel.server import Server
from mcp.shared.exceptions import MCPError
from mcp.shared.tracing import ATTR_ERROR_TYPE, ATTR_MCP_METHOD_NAME

# Module-level provider + exporter — avoids the "Overriding of current
# TracerProvider is not allowed" warning that happens if you call
# set_tracer_provider() more than once.
_provider = TracerProvider()
_exporter = InMemorySpanExporter()
_provider.add_span_processor(SimpleSpanProcessor(_exporter))


@pytest.fixture(autouse=True)
def _otel_setup(monkeypatch: pytest.MonkeyPatch) -> InMemorySpanExporter:
"""Patch the module-level tracer to use our test provider and clear spans between tests."""
import mcp.shared.tracing as tracing_mod

monkeypatch.setattr(tracing_mod, "_tracer", _provider.get_tracer("mcp"))
_exporter.clear()
return _exporter


@pytest.mark.anyio
async def test_span_created_on_send_request(_otel_setup: InMemorySpanExporter) -> None:
"""Verify a CLIENT span is created when send_request() succeeds."""
exporter = _otel_setup

server = Server(name="test server")
async with Client(server) as client:
await client.send_ping()

spans = exporter.get_finished_spans()
# Filter to only the ping span (initialize also produces one)
ping_spans = [s for s in spans if s.attributes and s.attributes.get(ATTR_MCP_METHOD_NAME) == "ping"]
assert len(ping_spans) == 1

span = ping_spans[0]
assert span.name == "ping"
assert span.kind == SpanKind.CLIENT
assert span.status.status_code == StatusCode.OK


@pytest.mark.anyio
async def test_span_attributes_for_tool_call(_otel_setup: InMemorySpanExporter) -> None:
"""Verify span name includes tool name for tools/call requests."""
exporter = _otel_setup

server = Server(name="test server")

@server.list_tools()
async def handle_list_tools() -> list[types.Tool]:
return [types.Tool(name="echo", description="Echo tool", input_schema={"type": "object"})]

@server.call_tool()
async def handle_call_tool(name: str, arguments: dict[str, Any] | None) -> list[types.TextContent]:
return [types.TextContent(type="text", text=str(arguments))]

async with Client(server) as client:
await client.call_tool("echo", {"msg": "hi"})

spans = exporter.get_finished_spans()
tool_spans = [s for s in spans if s.attributes and s.attributes.get(ATTR_MCP_METHOD_NAME) == "tools/call"]
assert len(tool_spans) == 1

span = tool_spans[0]
assert span.name == "tools/call echo"
assert span.status.status_code == StatusCode.OK


@pytest.mark.anyio
async def test_span_error_on_failure(_otel_setup: InMemorySpanExporter) -> None:
"""Verify span records ERROR status when the request times out."""
exporter = _otel_setup

server = Server(name="test server")

@server.list_tools()
async def handle_list_tools() -> list[types.Tool]:
return [types.Tool(name="slow_tool", description="Slow", input_schema={"type": "object"})]

@server.call_tool()
async def handle_call_tool(name: str, arguments: dict[str, Any] | None) -> list[types.TextContent]:
await anyio.sleep(10)
return [] # pragma: no cover

async with Client(server) as client:
with pytest.raises(MCPError, match="Timed out"):
await client.session.send_request(
types.CallToolRequest(params=types.CallToolRequestParams(name="slow_tool", arguments={})),
types.CallToolResult,
request_read_timeout_seconds=0.01,
)

spans = exporter.get_finished_spans()
tool_spans = [s for s in spans if s.attributes and s.attributes.get(ATTR_MCP_METHOD_NAME) == "tools/call"]
assert len(tool_spans) == 1

span = tool_spans[0]
assert span.status.status_code == StatusCode.ERROR
assert span.attributes is not None
assert span.attributes.get(ATTR_ERROR_TYPE) == "MCPError"


@pytest.mark.anyio
async def test_no_span_for_excluded_method(_otel_setup: InMemorySpanExporter) -> None:
"""Verify no span is created for excluded methods (notifications/message)."""
exporter = _otel_setup

server = Server(name="test server")
async with Client(server) as client:
await client.send_ping()

spans = exporter.get_finished_spans()
excluded_spans = [
s for s in spans if s.attributes and s.attributes.get(ATTR_MCP_METHOD_NAME) == "notifications/message"
]
assert len(excluded_spans) == 0
65 changes: 65 additions & 0 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading