forked from mcp-use/mcp-use
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmiddleware_example.py
More file actions
267 lines (206 loc) · 9.06 KB
/
middleware_example.py
File metadata and controls
267 lines (206 loc) · 9.06 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
"""
MCP Server Middleware Examples
This file demonstrates the middleware system for MCP servers, showing common patterns
like logging, authentication, rate limiting, and validation.
Key concepts demonstrated:
- on_initialize: Intercept client connections during MCP handshake
- on_request: Intercept ALL requests (wraps other hooks)
- on_call_tool: Intercept only tool calls
- Typed context: Each hook gets fully-typed context.message (IDE autocomplete works!)
- Middleware order: First added = outermost (sees requests first, responses last)
Session IDs come from the SDK via the mcp-session-id header (context.session_id).
Spec: https://modelcontextprotocol.io/specification/2025-11-25/basic/transports#session-management
Run with: python middleware_example.py
"""
import time
from collections import defaultdict
from datetime import datetime
from typing import Any
from mcp_use.server import MCPServer
from mcp_use.server.middleware import CallNext, Middleware, ServerMiddlewareContext
# =============================================================================
# CONNECTION MIDDLEWARE (on_initialize)
# These run during the MCP handshake, before any tools/resources are accessed.
# =============================================================================
class ConnectionGuard(Middleware):
"""Log incoming client connections.
Uses on_initialize to intercept the MCP handshake. The context.message is
typed as InitializeRequestParams, giving you access to:
- clientInfo.name: Client application name
- clientInfo.version: Client version
- protocolVersion: MCP protocol version
- capabilities: What the client supports
"""
async def on_initialize(self, context, call_next):
client_name = context.message.clientInfo.name
print(f"Incoming connection from: {client_name}")
return await call_next(context)
class ClientCapabilitiesGuard(Middleware):
"""Reject clients that don't support required capabilities.
Demonstrates:
- Typed context: context.message.capabilities is fully typed
- Early rejection: Raise an exception to reject the connection
"""
async def on_initialize(self, context, call_next) -> Any:
print(f"Client capabilities: {context.message.capabilities}")
capabilities = context.message.capabilities
if not capabilities.elicitation:
raise ValueError("Client must support elicitation")
if not capabilities.sampling:
raise ValueError("Client must support sampling")
if not capabilities.roots:
raise ValueError("Client must support roots")
return await call_next(context)
# =============================================================================
# REQUEST MIDDLEWARE (on_request)
# These run for EVERY request type, wrapping all other hooks.
# =============================================================================
class LoggingMiddleware(Middleware):
"""Log all requests with timing.
Demonstrates:
- on_request: Runs for ALL requests (tools, resources, prompts, etc.)
- on_call_tool: Additional logging specific to tool calls
- Hook nesting: on_request wraps on_call_tool when both are defined
- Error handling: Catch, log, and re-raise exceptions
"""
async def on_request(self, context: ServerMiddlewareContext[Any], call_next: CallNext[Any, Any]) -> Any:
start = time.time()
sid = context.session_id or "anonymous"
print(f"[{context.method}] sid={sid}")
try:
result = await call_next(context)
print(f"[{context.method}] ok ({(time.time() - start) * 1000:.1f}ms)")
return result
except Exception as e:
print(f"[{context.method}] failed: {e}")
raise
async def on_call_tool(self, context, call_next):
# This runs INSIDE on_request for tool calls
print(f" tool={context.message.name} args={context.message.arguments or {}}")
return await call_next(context)
# =============================================================================
# TOOL-SPECIFIC MIDDLEWARE (on_call_tool)
# These only run for tool calls, not for resources or prompts.
# =============================================================================
class AuthenticationMiddleware(Middleware):
"""Check x-api-key header for tool calls.
Demonstrates:
- on_call_tool: Only intercepts tool execution
- Header access: context.headers (available on HTTP transports)
- Configurable middleware: Pass valid keys via constructor
"""
def __init__(self, valid_api_keys: set[str] | None = None):
self.valid_api_keys = valid_api_keys or {"test-key-123"}
async def on_call_tool(self, context, call_next):
api_key = context.headers.get("x-api-key") if context.headers else None
if api_key and api_key not in self.valid_api_keys:
raise PermissionError("Invalid API key")
return await call_next(context)
class RateLimitingMiddleware(Middleware):
"""Limit tool calls per session.
Demonstrates:
- Session-based state: Use context.session_id to track per-client limits
- Stateful middleware: Store request history in instance variables
- Sliding window: Only count requests from last 60 seconds
"""
def __init__(self, max_requests_per_minute: int = 30):
self.max = max_requests_per_minute
self.seen: dict[str, list[datetime]] = defaultdict(list)
async def on_call_tool(self, context, call_next):
sid = context.session_id or "anonymous"
now = datetime.now()
self.seen[sid] = [t for t in self.seen[sid] if (now - t).total_seconds() < 60]
if len(self.seen[sid]) >= self.max:
raise Exception(f"Rate limit exceeded ({self.max}/min) for session {sid[:8]}")
self.seen[sid].append(now)
return await call_next(context)
class ValidationMiddleware(Middleware):
"""Validate tool arguments before execution.
Demonstrates:
- Typed context: context.message is CallToolRequestParams
- Tool-specific validation: Check message.name to apply rules per tool
- Early rejection: Raise before call_next to skip execution
"""
async def on_call_tool(self, context, call_next):
if context.message.name == "echo":
msg = (context.message.arguments or {}).get("message", "")
if len(msg) > 200:
raise ValueError("message too long")
return await call_next(context)
# =============================================================================
# SERVER SETUP
# Middleware order matters: first added = outermost (sees requests first)
# =============================================================================
server = MCPServer(
name="Middleware Demo Server",
version="1.0.0",
instructions="Minimal middleware demo",
middleware=[
# 1. Logging first - sees all requests including rejected ones
LoggingMiddleware(),
# 2. Auth early - reject unauthorized before expensive operations
AuthenticationMiddleware(),
# 3. Rate limiting - protect resources from abuse
RateLimitingMiddleware(max_requests_per_minute=10),
# 4. Validation - check data before handler runs
ValidationMiddleware(),
# 5-6. Connection guards - run during handshake (on_initialize)
ConnectionGuard(),
ClientCapabilitiesGuard(),
],
debug=True,
pretty_print_jsonrpc=True,
)
# =============================================================================
# TOOLS, RESOURCES, AND PROMPTS
# =============================================================================
@server.tool()
def echo(message: str) -> str:
"""Echo back a message."""
return f"Echo: {message}"
@server.tool()
def session_id() -> dict[str, Any]:
"""Return current session ID (None on stdio)."""
sid = server._get_session_id_from_request()
return {
"session_id": sid,
"session_id_short": sid[:8] if sid else None,
"is_anonymous": sid is None,
"note": "From mcp-session-id header",
}
@server.resource(uri="info://middleware", name="middleware_info", title="Middleware Info", mime_type="text/plain")
def middleware_info() -> str:
"""List active middleware."""
return f"""Active Middleware:
1. LoggingMiddleware
2. AuthenticationMiddleware
3. RateLimitingMiddleware
4. ValidationMiddleware
Generated at: {datetime.now().isoformat()}
"""
@server.resource(uri="data://config", name="config", title="Server Configuration", mime_type="application/json")
def get_config() -> str:
"""Server configuration."""
return """{
"server": "Middleware Demo",
"version": "1.0.0",
"features": ["logging", "auth", "rate-limit", "validation"]
}"""
@server.prompt(name="help")
def help_prompt() -> str:
"""Help prompt."""
return """# Middleware Demo Server
Middleware:
- Logging (request + tools)
- Auth (x-api-key)
- Rate limiting (per session)
- Validation (echo length)
Tools:
- echo(message)
- session_id()
Resources:
- info://middleware
- data://config
"""
if __name__ == "__main__":
server.run(transport="streamable-http", port=8000)