Skip to content

Commit e4889b7

Browse files
committed
fix(types): type converted MCPServer handler results
1 parent b33c811 commit e4889b7

File tree

6 files changed

+52
-20
lines changed

6 files changed

+52
-20
lines changed

src/mcp/server/mcpserver/prompts/base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,14 +69,14 @@ class Prompt(BaseModel):
6969
title: str | None = Field(None, description="Human-readable title of the prompt")
7070
description: str | None = Field(None, description="Description of what the prompt does")
7171
arguments: list[PromptArgument] | None = Field(None, description="Arguments that can be passed to the prompt")
72-
fn: Callable[..., PromptResult | Awaitable[PromptResult]] = Field(exclude=True)
72+
fn: Callable[..., PromptResult] = Field(exclude=True)
7373
icons: list[Icon] | None = Field(default=None, description="Optional list of icons for this prompt")
7474
context_kwarg: str | None = Field(None, description="Name of the kwarg that should receive context", exclude=True)
7575

7676
@classmethod
7777
def from_function(
7878
cls,
79-
fn: Callable[..., PromptResult | Awaitable[PromptResult]],
79+
fn: Callable[..., PromptResult],
8080
name: str | None = None,
8181
title: str | None = None,
8282
description: str | None = None,

src/mcp/server/mcpserver/resources/templates.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def from_function(
8282
context_kwarg=context_kwarg,
8383
)
8484

85-
def matches(self, uri: str) -> dict[str, Any] | None:
85+
def matches(self, uri: str) -> dict[str, str] | None:
8686
"""Check if URI matches template and extract parameters.
8787
8888
Extracted parameters are URL-decoded to handle percent-encoded characters.

src/mcp/server/mcpserver/server.py

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,10 @@
44

55
import base64
66
import inspect
7-
import json
87
import re
98
from collections.abc import AsyncIterator, Awaitable, Callable, Iterable, Sequence
109
from contextlib import AbstractAsyncContextManager, asynccontextmanager
11-
from typing import Any, Generic, Literal, TypeVar, overload
10+
from typing import Any, Generic, Literal, TypeVar, cast, overload
1211

1312
import anyio
1413
import pydantic_core
@@ -36,6 +35,7 @@
3635
from mcp.server.mcpserver.resources import FunctionResource, Resource, ResourceManager
3736
from mcp.server.mcpserver.tools import Tool, ToolManager
3837
from mcp.server.mcpserver.utilities.context_injection import find_context_parameter
38+
from mcp.server.mcpserver.utilities.func_metadata import ConvertedToolResult
3939
from mcp.server.mcpserver.utilities.logging import configure_logging, get_logger
4040
from mcp.server.sse import SseServerTransport
4141
from mcp.server.stdio import stdio_server
@@ -308,18 +308,13 @@ async def _handle_call_tool(
308308
if isinstance(result, CallToolResult):
309309
return result
310310
if isinstance(result, tuple) and len(result) == 2:
311-
unstructured_content, structured_content = result
312-
return CallToolResult(
313-
content=list(unstructured_content), # type: ignore[arg-type]
314-
structured_content=structured_content, # type: ignore[arg-type]
311+
unstructured_content, structured_content = cast(
312+
tuple[Sequence[ContentBlock], dict[str, Any]],
313+
result,
315314
)
316-
if isinstance(result, dict): # pragma: no cover
317-
# TODO: this code path is unreachable — convert_result never returns a raw dict.
318-
# The call_tool return type (Sequence[ContentBlock] | dict[str, Any]) is wrong
319-
# and needs to be cleaned up.
320315
return CallToolResult(
321-
content=[TextContent(type="text", text=json.dumps(result, indent=2))],
322-
structured_content=result,
316+
content=list(unstructured_content),
317+
structured_content=structured_content,
323318
)
324319
return CallToolResult(content=list(result))
325320

@@ -390,7 +385,7 @@ async def list_tools(self) -> list[MCPTool]:
390385

391386
async def call_tool(
392387
self, name: str, arguments: dict[str, Any], context: Context[LifespanResultT, Any] | None = None
393-
) -> Sequence[ContentBlock] | dict[str, Any]:
388+
) -> ConvertedToolResult:
394389
"""Call a tool by name with arguments."""
395390
if context is None:
396391
context = Context(mcp_server=self)

src/mcp/server/mcpserver/tools/base.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,13 @@
44
import inspect
55
from collections.abc import Callable
66
from functools import cached_property
7-
from typing import TYPE_CHECKING, Any
7+
from typing import TYPE_CHECKING, Any, Literal, overload
88

99
from pydantic import BaseModel, Field
1010

1111
from mcp.server.mcpserver.exceptions import ToolError
1212
from mcp.server.mcpserver.utilities.context_injection import find_context_parameter
13-
from mcp.server.mcpserver.utilities.func_metadata import FuncMetadata, func_metadata
13+
from mcp.server.mcpserver.utilities.func_metadata import ConvertedToolResult, FuncMetadata, func_metadata
1414
from mcp.shared.exceptions import UrlElicitationRequiredError
1515
from mcp.shared.tool_name_validation import validate_and_warn_tool_name
1616
from mcp.types import Icon, ToolAnnotations
@@ -89,6 +89,22 @@ def from_function(
8989
meta=meta,
9090
)
9191

92+
@overload
93+
async def run(
94+
self,
95+
arguments: dict[str, Any],
96+
context: Context[LifespanContextT, RequestT],
97+
convert_result: Literal[True],
98+
) -> ConvertedToolResult: ...
99+
100+
@overload
101+
async def run(
102+
self,
103+
arguments: dict[str, Any],
104+
context: Context[LifespanContextT, RequestT],
105+
convert_result: Literal[False] = False,
106+
) -> Any: ...
107+
92108
async def run(
93109
self,
94110
arguments: dict[str, Any],

src/mcp/server/mcpserver/tools/tool_manager.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
from __future__ import annotations
22

33
from collections.abc import Callable
4-
from typing import TYPE_CHECKING, Any
4+
from typing import TYPE_CHECKING, Any, Literal, overload
55

66
from mcp.server.mcpserver.exceptions import ToolError
77
from mcp.server.mcpserver.tools.base import Tool
8+
from mcp.server.mcpserver.utilities.func_metadata import ConvertedToolResult
89
from mcp.server.mcpserver.utilities.logging import get_logger
910
from mcp.types import Icon, ToolAnnotations
1011

@@ -77,6 +78,24 @@ def remove_tool(self, name: str) -> None:
7778
raise ToolError(f"Unknown tool: {name}")
7879
del self._tools[name]
7980

81+
@overload
82+
async def call_tool(
83+
self,
84+
name: str,
85+
arguments: dict[str, Any],
86+
context: Context[LifespanContextT, RequestT],
87+
convert_result: Literal[True],
88+
) -> ConvertedToolResult: ...
89+
90+
@overload
91+
async def call_tool(
92+
self,
93+
name: str,
94+
arguments: dict[str, Any],
95+
context: Context[LifespanContextT, RequestT],
96+
convert_result: Literal[False] = False,
97+
) -> Any: ...
98+
8099
async def call_tool(
81100
self,
82101
name: str,

src/mcp/server/mcpserver/utilities/func_metadata.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from collections.abc import Awaitable, Callable, Sequence
55
from itertools import chain
66
from types import GenericAlias
7-
from typing import Annotated, Any, cast, get_args, get_origin, get_type_hints
7+
from typing import Annotated, Any, TypeAlias, cast, get_args, get_origin, get_type_hints
88

99
import anyio
1010
import anyio.to_thread
@@ -28,6 +28,8 @@
2828

2929
logger = get_logger(__name__)
3030

31+
ConvertedToolResult: TypeAlias = CallToolResult | Sequence[ContentBlock] | tuple[Sequence[ContentBlock], dict[str, Any]]
32+
3133

3234
class StrictJsonSchema(GenerateJsonSchema):
3335
"""A JSON schema generator that raises exceptions instead of emitting warnings.

0 commit comments

Comments
 (0)