Skip to content

Commit 01a0157

Browse files
committed
Wire MCP server, SSE transport, and Dash app integration
1 parent 8c1f392 commit 01a0157

19 files changed

Lines changed: 2508 additions & 2 deletions

dash/_configs.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ def load_dash_env_vars():
3232
"DASH_DISABLE_VERSION_CHECK",
3333
"DASH_PRUNE_ERRORS",
3434
"DASH_COMPRESS",
35+
"DASH_MCP_ENABLED",
36+
"DASH_MCP_PATH",
3537
"HOST",
3638
"PORT",
3739
)

dash/dash.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -472,6 +472,8 @@ def __init__( # pylint: disable=too-many-statements
472472
on_error: Optional[Callable[[Exception], Any]] = None,
473473
use_async: Optional[bool] = None,
474474
health_endpoint: Optional[str] = None,
475+
enable_mcp: Optional[bool] = None,
476+
mcp_path: Optional[str] = None,
475477
**obsolete,
476478
):
477479

@@ -573,6 +575,13 @@ def __init__( # pylint: disable=too-many-statements
573575
# keep title as a class property for backwards compatibility
574576
self.title = title
575577

578+
# MCP (Model Context Protocol) configuration
579+
self._enable_mcp = get_combined_config("mcp_enabled", enable_mcp, True)
580+
_mcp_path = get_combined_config("mcp_path", mcp_path, "_mcp")
581+
self._mcp_path = (
582+
_mcp_path.lstrip("/") if isinstance(_mcp_path, str) else _mcp_path
583+
)
584+
576585
# list of dependencies - this one is used by the back end for dispatching
577586
self.callback_map: dict = {}
578587
# same deps as a list to catch duplicate outputs, and to send to the front end
@@ -793,6 +802,21 @@ def _setup_routes(self):
793802
hook.data["methods"],
794803
)
795804

805+
if self._enable_mcp:
806+
from .mcp import ( # pylint: disable=import-outside-toplevel
807+
enable_mcp_server,
808+
)
809+
810+
try:
811+
enable_mcp_server(self, self._mcp_path)
812+
except Exception as e: # pylint: disable=broad-exception-caught
813+
self._enable_mcp = False
814+
self.logger.warning(
815+
"MCP server could not be started at '%s': %s",
816+
self._mcp_path,
817+
e,
818+
)
819+
796820
# catch-all for front-end routes, used by dcc.Location
797821
self._add_url("<path:path>", self.index)
798822

@@ -2526,6 +2550,13 @@ def verify_url_part(served_part, url_part, part_name):
25262550

25272551
if not jupyter_dash or not jupyter_dash.in_ipython:
25282552
self.logger.info("Dash is running on %s://%s%s%s\n", *display_url)
2553+
if self._enable_mcp:
2554+
self.logger.info(
2555+
" * MCP available at %s://%s%s%s%s\n",
2556+
*display_url[:3],
2557+
self.config.routes_pathname_prefix,
2558+
self._mcp_path,
2559+
)
25292560

25302561
if self.config.extra_hot_reload_paths:
25312562
extra_files = flask_run_options["extra_files"] = []

dash/mcp/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
"""Dash MCP (Model Context Protocol) server integration."""
2+
3+
from dash.mcp._server import enable_mcp_server
4+
5+
__all__ = [
6+
enable_mcp_server,
7+
]

dash/mcp/_server.py

Lines changed: 277 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,277 @@
1+
"""Flask route setup, Streamable HTTP transport, and MCP message handling."""
2+
3+
from __future__ import annotations
4+
5+
import atexit
6+
import json
7+
import logging
8+
import uuid
9+
from typing import TYPE_CHECKING, Any
10+
11+
from flask import Response, request
12+
13+
from dash.mcp.types import MCPError
14+
15+
if TYPE_CHECKING:
16+
from dash import Dash
17+
18+
from dash import get_app
19+
20+
from mcp.types import (
21+
LATEST_PROTOCOL_VERSION,
22+
ErrorData,
23+
Implementation,
24+
InitializeResult,
25+
JSONRPCError,
26+
JSONRPCResponse,
27+
ResourcesCapability,
28+
ServerCapabilities,
29+
ToolsCapability,
30+
)
31+
32+
from dash.version import __version__
33+
from dash.mcp._sse import (
34+
close_sse_stream,
35+
create_sse_stream,
36+
shutdown_all_streams,
37+
)
38+
from dash.mcp.primitives import (
39+
call_tool,
40+
list_resource_templates,
41+
list_resources,
42+
list_tools,
43+
read_resource,
44+
)
45+
from dash.mcp.primitives.tools.callback_adapter_collection import (
46+
CallbackAdapterCollection,
47+
)
48+
49+
logger = logging.getLogger(__name__)
50+
51+
52+
def enable_mcp_server(app: Dash, mcp_path: str) -> None:
53+
"""
54+
Add MCP routes to a Dash/Flask app.
55+
56+
Registers a single Streamable HTTP endpoint for the MCP protocol.
57+
Uses ``app._add_url()`` so that ``routes_pathname_prefix`` is applied
58+
automatically.
59+
60+
Args:
61+
app: The Dash application instance.
62+
mcp_path: Route prefix for MCP endpoints.
63+
"""
64+
# Session storage: session_id -> metadata
65+
sessions: dict[str, dict[str, Any]] = {}
66+
67+
def _create_session() -> str:
68+
sid = str(uuid.uuid4())
69+
sessions[sid] = {}
70+
return sid
71+
72+
# -- Streamable HTTP endpoint --------------------------------------------
73+
74+
def mcp_handler() -> Response:
75+
if request.method == "POST":
76+
return _handle_post()
77+
if request.method == "GET":
78+
return _handle_get()
79+
if request.method == "DELETE":
80+
return _handle_delete()
81+
return Response(
82+
json.dumps({"error": "Method not allowed"}),
83+
content_type="application/json",
84+
status=405,
85+
)
86+
87+
def _handle_get() -> Response:
88+
session_id = request.headers.get("mcp-session-id")
89+
if not session_id or session_id not in sessions:
90+
return Response(
91+
json.dumps({"error": "Session not found"}),
92+
content_type="application/json",
93+
status=404,
94+
)
95+
return create_sse_stream(sessions, session_id)
96+
97+
def _handle_post() -> Response:
98+
content_type = request.content_type or ""
99+
if "application/json" not in content_type:
100+
return Response(
101+
json.dumps({"error": "Content-Type must be application/json"}),
102+
content_type="application/json",
103+
status=415,
104+
)
105+
106+
try:
107+
data = request.get_json()
108+
except Exception:
109+
return Response(
110+
json.dumps({"error": "Invalid JSON"}),
111+
content_type="application/json",
112+
status=400,
113+
)
114+
115+
method = data.get("method", "")
116+
request_id = data.get("id")
117+
session_id = request.headers.get("mcp-session-id")
118+
119+
stale_session = False
120+
if method == "initialize":
121+
session_id = _create_session()
122+
elif session_id and session_id not in sessions:
123+
stale_session = True
124+
sessions[session_id] = {}
125+
elif not session_id:
126+
session_id = _create_session()
127+
128+
response_data = _process_mcp_message(data)
129+
130+
if response_data is None:
131+
return Response("", status=202)
132+
133+
if stale_session:
134+
_inject_warning(response_data, _STALE_SESSION_WARNING)
135+
136+
return Response(
137+
json.dumps(response_data),
138+
content_type="application/json",
139+
status=200,
140+
headers={"mcp-session-id": session_id},
141+
)
142+
143+
def _handle_delete() -> Response:
144+
session_id = request.headers.get("mcp-session-id")
145+
if not session_id or session_id not in sessions:
146+
return Response(
147+
json.dumps({"error": "Session not found"}),
148+
content_type="application/json",
149+
status=404,
150+
)
151+
close_sse_stream(sessions[session_id])
152+
del sessions[session_id]
153+
logger.info("MCP session terminated: %s", session_id)
154+
return Response("", status=204)
155+
156+
# -- Register routes -----------------------------------------------------
157+
158+
from dash._get_app import with_app_context_factory
159+
160+
app._add_url(
161+
mcp_path, with_app_context_factory(mcp_handler, app), ["GET", "POST", "DELETE"]
162+
)
163+
164+
# Close all SSE streams on server shutdown so MCP clients see a
165+
# clean stream end and can reconnect promptly.
166+
atexit.register(shutdown_all_streams, sessions)
167+
168+
logger.info(
169+
"MCP routes registered at %s%s",
170+
app.config.routes_pathname_prefix,
171+
mcp_path,
172+
)
173+
174+
175+
_STALE_SESSION_WARNING = (
176+
"[Warning: your session was not recognised"
177+
" — the app may have restarted."
178+
" Please call tools/list to refresh your tool list."
179+
" Please ask the user to reconnect to the MCP server.]"
180+
)
181+
182+
183+
def _inject_warning(response_data: dict[str, Any], warning: str) -> None:
184+
"""Append a warning to a JSON-RPC response dict.
185+
186+
For successful ``tools/call`` responses the warning is added as an
187+
extra text content block so the agent sees it alongside the result.
188+
For error responses the warning is appended to the error message.
189+
Other responses (tools/list, resources/*) are left unchanged — the
190+
JSON-RPC spec forbids extra top-level keys.
191+
"""
192+
# tools/call success: result has a "content" list
193+
result = response_data.get("result")
194+
if isinstance(result, dict) and isinstance(result.get("content"), list):
195+
result["content"].append({"type": "text", "text": warning})
196+
return
197+
198+
# Error response
199+
error = response_data.get("error")
200+
if isinstance(error, dict) and "message" in error:
201+
error["message"] += " " + warning
202+
203+
204+
def _handle_initialize() -> InitializeResult:
205+
return InitializeResult(
206+
protocolVersion=LATEST_PROTOCOL_VERSION,
207+
capabilities=ServerCapabilities(
208+
tools=ToolsCapability(listChanged=True),
209+
resources=ResourcesCapability(),
210+
),
211+
serverInfo=Implementation(name="Plotly Dash", version=__version__),
212+
instructions=(
213+
"This is a Dash web application. "
214+
"Dash apps are stateless: calling a tool executes "
215+
"a callback and returns its result to you, but does "
216+
"NOT update the user's browser. "
217+
"Use tool results to answer questions about what "
218+
"the app would produce for given inputs."
219+
),
220+
)
221+
222+
223+
def _process_mcp_message(data: dict[str, Any]) -> dict[str, Any] | None:
224+
"""
225+
Process an MCP JSON-RPC message and return the response dict.
226+
227+
Returns ``None`` for notifications (no ``id`` field).
228+
"""
229+
method = data.get("method", "")
230+
params = data.get("params", {}) or {}
231+
request_id = data.get("id")
232+
233+
app = get_app()
234+
if not hasattr(app, "mcp_callback_map"):
235+
app.mcp_callback_map = CallbackAdapterCollection(app)
236+
237+
mcp_methods = {
238+
"initialize": _handle_initialize,
239+
"tools/list": lambda: list_tools(),
240+
"tools/call": lambda: call_tool(
241+
params.get("name", ""), params.get("arguments", {})
242+
),
243+
"resources/list": lambda: list_resources(),
244+
"resources/templates/list": lambda: list_resource_templates(),
245+
"resources/read": lambda: read_resource(params.get("uri", "")),
246+
}
247+
248+
try:
249+
handler = mcp_methods.get(method)
250+
if handler is None:
251+
if method.startswith("notifications/"):
252+
return None
253+
raise ValueError(f"Unknown method: {method}")
254+
255+
result = handler()
256+
257+
response = JSONRPCResponse(
258+
jsonrpc="2.0",
259+
id=request_id,
260+
result=result.model_dump(exclude_none=True, mode="json"),
261+
)
262+
return response.model_dump(exclude_none=True, mode="json")
263+
264+
except MCPError as e:
265+
logger.error("MCP error: %s", e)
266+
return JSONRPCError(
267+
jsonrpc="2.0",
268+
id=request_id,
269+
error=ErrorData(code=e.code, message=str(e)),
270+
).model_dump(exclude_none=True)
271+
except Exception as e:
272+
logger.error("MCP error: %s", e, exc_info=True)
273+
return JSONRPCError(
274+
jsonrpc="2.0",
275+
id=request_id,
276+
error=ErrorData(code=-32603, message=f"{type(e).__name__}: {e}"),
277+
).model_dump(exclude_none=True)

0 commit comments

Comments
 (0)