Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion fastapi_mcp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,12 @@
__version__ = "0.0.0.dev0" # pragma: no cover

from .server import FastApiMCP
from .types import AuthConfig, OAuthMetadata
from .types import AuthConfig, HTTPRequestInfo, OAuthMetadata


__all__ = [
"FastApiMCP",
"AuthConfig",
"HTTPRequestInfo",
"OAuthMetadata",
]
36 changes: 36 additions & 0 deletions fastapi_mcp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,10 @@ def __init__(
"""
),
] = ["authorization"],
custom_tools: Annotated[
Optional[List[tuple[types.Tool, Any]]],
Doc("List of (Tool, handler) tuples for custom tools not derived from OpenAPI"),
] = None,
):
# Validate operation and tag filtering options
if include_operations is not None and exclude_operations is not None:
Expand All @@ -95,6 +99,7 @@ def __init__(
self.operation_map: Dict[str, Dict[str, Any]]
self.tools: List[types.Tool]
self.server: Server
self._custom_handlers: Dict[str, Any] = {}

self.fastapi = fastapi
self.name = name or self.fastapi.title or "FastAPI MCP"
Expand Down Expand Up @@ -123,6 +128,29 @@ def __init__(

self.setup_server()

if custom_tools:
for tool, handler in custom_tools:
self.add_tool(tool, handler)

def add_tool(
self,
tool: Annotated[types.Tool, Doc("MCP Tool definition")],
handler: Annotated[
Any,
Doc("Async handler: (name, arguments, http_request_info, server) -> list of MCP content types"),
],
) -> None:
"""Register a custom MCP tool with its handler.

Custom tools are listed alongside auto-generated OpenAPI tools.
The handler receives (name, arguments, http_request_info, server)
and must return a list of MCP content types.

Must be called before mount_http() / mount_sse().
"""
self.tools.append(tool)
self._custom_handlers[tool.name] = handler

def setup_server(self) -> None:
openapi_schema = get_openapi(
title=self.fastapi.title,
Expand Down Expand Up @@ -175,6 +203,14 @@ async def handle_call_tool(
except (LookupError, AttributeError) as e:
logger.error(f"Could not extract HTTP request info from context: {e}")

if name in self._custom_handlers:
return await self._custom_handlers[name](
name=name,
arguments=arguments,
http_request_info=http_request_info,
server=mcp_server,
)

return await self._execute_api_tool(
client=self._http_client,
tool_name=name,
Expand Down
202 changes: 202 additions & 0 deletions tests/test_custom_tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
import json
from typing import Any, Dict, List, Optional, Union

import pytest
import mcp.types as types
from mcp.server.lowlevel import Server
from mcp.shared.memory import create_connected_server_and_client_session
from fastapi import FastAPI

from fastapi_mcp import FastApiMCP
from fastapi_mcp.types import HTTPRequestInfo


# -- Helpers --

def _make_tool(name: str = "my_custom_tool", description: str = "A custom tool") -> types.Tool:
return types.Tool(
name=name,
description=description,
inputSchema={"type": "object", "properties": {"q": {"type": "string"}}},
)


async def _echo_handler(
name: str,
arguments: Dict[str, Any],
http_request_info: Optional[HTTPRequestInfo],
server: Server,
) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]:
return [types.TextContent(type="text", text=json.dumps({"name": name, "arguments": arguments}))]


async def _inspect_handler(
name: str,
arguments: Dict[str, Any],
http_request_info: Optional[HTTPRequestInfo],
server: Server,
) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]:
"""Handler that returns info about what it received for inspection."""
info = {
"name": name,
"arguments": arguments,
"has_request_info": http_request_info is not None,
"has_server": server is not None,
"server_type": type(server).__name__,
}
if http_request_info:
info["request_method"] = http_request_info.method
info["request_headers"] = http_request_info.headers
return [types.TextContent(type="text", text=json.dumps(info))]


async def _error_handler(
name: str,
arguments: Dict[str, Any],
http_request_info: Optional[HTTPRequestInfo],
server: Server,
) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]:
raise ValueError("Custom tool error")


# -- Fixtures --

@pytest.fixture
def simple_app() -> FastAPI:
app = FastAPI(title="Test App")

@app.get("/items", operation_id="list_items")
async def list_items():
return [{"id": 1, "name": "Item 1"}]

return app


@pytest.fixture
def mcp_with_add_tool(simple_app: FastAPI) -> FastApiMCP:
mcp = FastApiMCP(simple_app, name="Test MCP")
mcp.add_tool(_make_tool(), _echo_handler)
mcp.mount()
return mcp


@pytest.fixture
def mcp_with_constructor_tools(simple_app: FastAPI) -> FastApiMCP:
mcp = FastApiMCP(
simple_app,
name="Test MCP",
custom_tools=[(_make_tool(), _echo_handler)],
)
mcp.mount()
return mcp


# -- Tests --

@pytest.mark.asyncio
async def test_custom_tool_listed(mcp_with_add_tool: FastApiMCP):
"""Custom tool appears in list_tools alongside auto-generated tools."""
async with create_connected_server_and_client_session(mcp_with_add_tool.server) as client_session:
tools_result = await client_session.list_tools()
tool_names = [t.name for t in tools_result.tools]

assert "my_custom_tool" in tool_names
assert "list_items" in tool_names


@pytest.mark.asyncio
async def test_custom_tool_called(mcp_with_add_tool: FastApiMCP):
"""Calling a custom tool dispatches to its handler and returns expected content."""
async with create_connected_server_and_client_session(mcp_with_add_tool.server) as client_session:
response = await client_session.call_tool("my_custom_tool", {"q": "hello"})

assert not response.isError
assert len(response.content) > 0

text_content = next(c for c in response.content if isinstance(c, types.TextContent))
result = json.loads(text_content.text)

assert result["name"] == "my_custom_tool"
assert result["arguments"] == {"q": "hello"}


@pytest.mark.asyncio
async def test_custom_tool_receives_request_info(simple_app: FastAPI):
"""Handler receives HTTPRequestInfo (or None in test context)."""
mcp = FastApiMCP(simple_app, name="Test MCP")
mcp.add_tool(_make_tool("inspect_tool"), _inspect_handler)
mcp.mount()

async with create_connected_server_and_client_session(mcp.server) as client_session:
response = await client_session.call_tool("inspect_tool", {"q": "test"})

assert not response.isError
text_content = next(c for c in response.content if isinstance(c, types.TextContent))
result = json.loads(text_content.text)

assert result["name"] == "inspect_tool"
# In test context (memory transport), request info may not be available
assert "has_request_info" in result


@pytest.mark.asyncio
async def test_custom_tool_receives_server(simple_app: FastAPI):
"""Handler receives MCP Server instance."""
mcp = FastApiMCP(simple_app, name="Test MCP")
mcp.add_tool(_make_tool("inspect_tool"), _inspect_handler)
mcp.mount()

async with create_connected_server_and_client_session(mcp.server) as client_session:
response = await client_session.call_tool("inspect_tool", {"q": "test"})

assert not response.isError
text_content = next(c for c in response.content if isinstance(c, types.TextContent))
result = json.loads(text_content.text)

assert result["has_server"] is True
assert result["server_type"] == "Server"


@pytest.mark.asyncio
async def test_custom_tools_via_constructor(mcp_with_constructor_tools: FastApiMCP):
"""Tools passed via custom_tools= param work identically to add_tool()."""
async with create_connected_server_and_client_session(mcp_with_constructor_tools.server) as client_session:
tools_result = await client_session.list_tools()
tool_names = [t.name for t in tools_result.tools]

assert "my_custom_tool" in tool_names
assert "list_items" in tool_names

response = await client_session.call_tool("my_custom_tool", {"q": "constructor"})
assert not response.isError

text_content = next(c for c in response.content if isinstance(c, types.TextContent))
result = json.loads(text_content.text)
assert result["arguments"] == {"q": "constructor"}


@pytest.mark.asyncio
async def test_openapi_tools_still_work(mcp_with_add_tool: FastApiMCP):
"""Existing auto-generated tools are unaffected by custom tool registration."""
async with create_connected_server_and_client_session(mcp_with_add_tool.server) as client_session:
response = await client_session.call_tool("list_items", {})

assert not response.isError
assert len(response.content) > 0

text_content = next(c for c in response.content if isinstance(c, types.TextContent))
result = json.loads(text_content.text)
assert result == [{"id": 1, "name": "Item 1"}]


@pytest.mark.asyncio
async def test_custom_tool_error_handling(simple_app: FastAPI):
"""Handler raising an exception is properly surfaced."""
mcp = FastApiMCP(simple_app, name="Test MCP")
mcp.add_tool(_make_tool("error_tool"), _error_handler)
mcp.mount()

async with create_connected_server_and_client_session(mcp.server) as client_session:
response = await client_session.call_tool("error_tool", {})

assert response.isError