|
| 1 | +"""Model Context Protocol (MCP) server bridge. |
| 2 | +
|
| 3 | +Exposes every :class:`~automation_file.core.action_registry.ActionRegistry` |
| 4 | +entry as an MCP tool over JSON-RPC 2.0. The default transport is stdio — |
| 5 | +one JSON message per line — because that's what MCP host implementations |
| 6 | +(Claude Desktop, MCP CLIs) consume today. |
| 7 | +
|
| 8 | +Scope |
| 9 | +----- |
| 10 | +* ``initialize`` — handshake, returns ``serverInfo`` + capabilities |
| 11 | +* ``notifications/initialized`` — acknowledged as a no-op |
| 12 | +* ``tools/list`` — lists registered actions as MCP tools |
| 13 | +* ``tools/call`` — dispatches through the action registry |
| 14 | +
|
| 15 | +Errors surface as JSON-RPC error objects with a ``MCPServerException`` chain |
| 16 | +in the data field, so hosts can render them without having to parse the |
| 17 | +exception string. |
| 18 | +""" |
| 19 | + |
| 20 | +from __future__ import annotations |
| 21 | + |
| 22 | +import inspect |
| 23 | +import json |
| 24 | +import sys |
| 25 | +from collections.abc import Callable, Iterable |
| 26 | +from typing import Any, TextIO |
| 27 | + |
| 28 | +from automation_file.core.action_executor import executor |
| 29 | +from automation_file.core.action_registry import ActionRegistry |
| 30 | +from automation_file.exceptions import MCPServerException |
| 31 | +from automation_file.logging_config import file_automation_logger |
| 32 | + |
| 33 | +_JSONRPC_VERSION = "2.0" |
| 34 | +_PROTOCOL_VERSION = "2024-11-05" |
| 35 | + |
| 36 | +_PARSE_ERROR = -32700 |
| 37 | +_INVALID_REQUEST = -32600 |
| 38 | +_METHOD_NOT_FOUND = -32601 |
| 39 | +_INVALID_PARAMS = -32602 |
| 40 | +_INTERNAL_ERROR = -32603 |
| 41 | + |
| 42 | + |
| 43 | +class MCPServer: |
| 44 | + """Bridge between an MCP host and an :class:`ActionRegistry`.""" |
| 45 | + |
| 46 | + def __init__( |
| 47 | + self, |
| 48 | + registry: ActionRegistry | None = None, |
| 49 | + *, |
| 50 | + name: str = "automation_file", |
| 51 | + version: str = "1.0.0", |
| 52 | + ) -> None: |
| 53 | + self._registry = registry if registry is not None else executor.registry |
| 54 | + self._name = name |
| 55 | + self._version = version |
| 56 | + self._initialized = False |
| 57 | + |
| 58 | + def handle_message(self, message: dict[str, Any]) -> dict[str, Any] | None: |
| 59 | + """Dispatch a single decoded JSON-RPC message. |
| 60 | +
|
| 61 | + Returns the response dict for request messages, or ``None`` for |
| 62 | + notifications (which get no reply). Protocol-level errors return a |
| 63 | + JSON-RPC error object rather than raising. |
| 64 | + """ |
| 65 | + if not isinstance(message, dict) or message.get("jsonrpc") != _JSONRPC_VERSION: |
| 66 | + return _error_response(None, _INVALID_REQUEST, "invalid JSON-RPC envelope") |
| 67 | + |
| 68 | + method = message.get("method") |
| 69 | + msg_id = message.get("id") |
| 70 | + params = message.get("params") or {} |
| 71 | + |
| 72 | + if not isinstance(method, str): |
| 73 | + return _error_response(msg_id, _INVALID_REQUEST, "missing method") |
| 74 | + |
| 75 | + is_notification = msg_id is None |
| 76 | + try: |
| 77 | + if method == "initialize": |
| 78 | + result = self._handle_initialize(params) |
| 79 | + elif method == "notifications/initialized": |
| 80 | + self._initialized = True |
| 81 | + return None |
| 82 | + elif method == "tools/list": |
| 83 | + result = self._handle_tools_list() |
| 84 | + elif method == "tools/call": |
| 85 | + result = self._handle_tools_call(params) |
| 86 | + else: |
| 87 | + return _error_response(msg_id, _METHOD_NOT_FOUND, f"unknown method: {method}") |
| 88 | + except MCPServerException as error: |
| 89 | + return _error_response(msg_id, _INVALID_PARAMS, str(error)) |
| 90 | + except Exception as error: |
| 91 | + file_automation_logger.warning("mcp_server: internal error: %r", error) |
| 92 | + return _error_response(msg_id, _INTERNAL_ERROR, f"{type(error).__name__}: {error}") |
| 93 | + |
| 94 | + if is_notification: |
| 95 | + return None |
| 96 | + return {"jsonrpc": _JSONRPC_VERSION, "id": msg_id, "result": result} |
| 97 | + |
| 98 | + def serve_stdio( |
| 99 | + self, |
| 100 | + stdin: TextIO | None = None, |
| 101 | + stdout: TextIO | None = None, |
| 102 | + ) -> None: |
| 103 | + """Run the server over newline-delimited JSON on ``stdin`` / ``stdout``.""" |
| 104 | + reader = stdin if stdin is not None else sys.stdin |
| 105 | + writer = stdout if stdout is not None else sys.stdout |
| 106 | + for line in reader: |
| 107 | + stripped = line.strip() |
| 108 | + if not stripped: |
| 109 | + continue |
| 110 | + try: |
| 111 | + message = json.loads(stripped) |
| 112 | + except json.JSONDecodeError as error: |
| 113 | + self._write(writer, _error_response(None, _PARSE_ERROR, f"bad json: {error}")) |
| 114 | + continue |
| 115 | + response = self.handle_message(message) |
| 116 | + if response is not None: |
| 117 | + self._write(writer, response) |
| 118 | + |
| 119 | + def _handle_initialize(self, _params: dict[str, Any]) -> dict[str, Any]: |
| 120 | + return { |
| 121 | + "protocolVersion": _PROTOCOL_VERSION, |
| 122 | + "capabilities": {"tools": {"listChanged": False}}, |
| 123 | + "serverInfo": {"name": self._name, "version": self._version}, |
| 124 | + } |
| 125 | + |
| 126 | + def _handle_tools_list(self) -> dict[str, Any]: |
| 127 | + tools = [] |
| 128 | + for name, command in sorted(self._registry.event_dict.items()): |
| 129 | + tools.append( |
| 130 | + { |
| 131 | + "name": name, |
| 132 | + "description": _describe(command), |
| 133 | + "inputSchema": _schema_for(command), |
| 134 | + } |
| 135 | + ) |
| 136 | + return {"tools": tools} |
| 137 | + |
| 138 | + def _handle_tools_call(self, params: dict[str, Any]) -> dict[str, Any]: |
| 139 | + name = params.get("name") |
| 140 | + arguments = params.get("arguments") or {} |
| 141 | + if not isinstance(name, str) or not name: |
| 142 | + raise MCPServerException("tools/call requires a string 'name'") |
| 143 | + if not isinstance(arguments, dict): |
| 144 | + raise MCPServerException("'arguments' must be an object") |
| 145 | + command = self._registry.resolve(name) |
| 146 | + if command is None: |
| 147 | + raise MCPServerException(f"unknown tool: {name}") |
| 148 | + try: |
| 149 | + value = command(**arguments) |
| 150 | + except TypeError as error: |
| 151 | + raise MCPServerException(f"bad arguments for {name}: {error}") from error |
| 152 | + return { |
| 153 | + "content": [{"type": "text", "text": _serialise(value)}], |
| 154 | + "isError": False, |
| 155 | + } |
| 156 | + |
| 157 | + @staticmethod |
| 158 | + def _write(writer: TextIO, response: dict[str, Any]) -> None: |
| 159 | + writer.write(json.dumps(response, default=repr) + "\n") |
| 160 | + writer.flush() |
| 161 | + |
| 162 | + |
| 163 | +def _error_response(msg_id: object, code: int, message: str) -> dict[str, Any]: |
| 164 | + return { |
| 165 | + "jsonrpc": _JSONRPC_VERSION, |
| 166 | + "id": msg_id, |
| 167 | + "error": {"code": code, "message": message}, |
| 168 | + } |
| 169 | + |
| 170 | + |
| 171 | +def _describe(command: Callable[..., Any]) -> str: |
| 172 | + doc = inspect.getdoc(command) or "" |
| 173 | + return doc.splitlines()[0] if doc else "Registered automation_file action." |
| 174 | + |
| 175 | + |
| 176 | +def _schema_for(command: Callable[..., Any]) -> dict[str, Any]: |
| 177 | + try: |
| 178 | + signature = inspect.signature(command) |
| 179 | + except (TypeError, ValueError): |
| 180 | + return {"type": "object", "properties": {}, "additionalProperties": True} |
| 181 | + properties: dict[str, Any] = {} |
| 182 | + required: list[str] = [] |
| 183 | + for parameter in signature.parameters.values(): |
| 184 | + if parameter.kind in ( |
| 185 | + inspect.Parameter.VAR_POSITIONAL, |
| 186 | + inspect.Parameter.VAR_KEYWORD, |
| 187 | + ): |
| 188 | + continue |
| 189 | + if parameter.name in {"self", "cls"}: |
| 190 | + continue |
| 191 | + properties[parameter.name] = _json_schema_for(parameter.annotation) |
| 192 | + if parameter.default is inspect.Parameter.empty: |
| 193 | + required.append(parameter.name) |
| 194 | + schema: dict[str, Any] = { |
| 195 | + "type": "object", |
| 196 | + "properties": properties, |
| 197 | + "additionalProperties": True, |
| 198 | + } |
| 199 | + if required: |
| 200 | + schema["required"] = required |
| 201 | + return schema |
| 202 | + |
| 203 | + |
| 204 | +def _json_schema_for(annotation: Any) -> dict[str, Any]: |
| 205 | + if annotation is inspect.Parameter.empty: |
| 206 | + return {} |
| 207 | + mapping: dict[type, str] = { |
| 208 | + str: "string", |
| 209 | + int: "integer", |
| 210 | + float: "number", |
| 211 | + bool: "boolean", |
| 212 | + list: "array", |
| 213 | + dict: "object", |
| 214 | + } |
| 215 | + if isinstance(annotation, type) and annotation in mapping: |
| 216 | + return {"type": mapping[annotation]} |
| 217 | + return {} |
| 218 | + |
| 219 | + |
| 220 | +def _serialise(value: Any) -> str: |
| 221 | + try: |
| 222 | + return json.dumps(value, default=repr) |
| 223 | + except (TypeError, ValueError): |
| 224 | + return repr(value) |
| 225 | + |
| 226 | + |
| 227 | +def tools_from_registry(registry: ActionRegistry) -> Iterable[dict[str, Any]]: |
| 228 | + """Yield MCP-shaped tool descriptors for every entry in ``registry``. |
| 229 | +
|
| 230 | + Exposed separately so GUIs and tests can render the same catalogue |
| 231 | + without instantiating :class:`MCPServer`. |
| 232 | + """ |
| 233 | + server = MCPServer(registry) |
| 234 | + yield from server._handle_tools_list()["tools"] |
0 commit comments