Skip to content

Commit f84d24d

Browse files
committed
fix: address review feedback on plugin system
- Fix ruff lint errors (import order, unused import, type annotation) - Build tool→plugin map at registration time for O(1) dispatch - Make plugin dispatch authoritative (first declaring plugin owns it) - Track successfully started plugins; only shutdown those on exit
1 parent 828f1e7 commit f84d24d

4 files changed

Lines changed: 36 additions & 21 deletions

File tree

src/ros2_medkit_mcp/mcp_app.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414

1515
from ros2_medkit_mcp.client import SovdClient, SovdClientError
1616
from ros2_medkit_mcp.config import Settings
17-
from ros2_medkit_mcp.plugin import McpPlugin
1817
from ros2_medkit_mcp.models import (
1918
AppIdArgs,
2019
AreaComponentsArgs,
@@ -59,6 +58,7 @@
5958
UpdateExecutionArgs,
6059
filter_entities,
6160
)
61+
from ros2_medkit_mcp.plugin import McpPlugin
6262

6363
logger = logging.getLogger(__name__)
6464

@@ -631,14 +631,18 @@ async def download_rosbags_for_fault(
631631
}
632632

633633

634-
def register_tools(server: Server, client: SovdClient, plugins: list[McpPlugin] | None = None) -> None:
634+
def register_tools(
635+
server: Server, client: SovdClient, plugins: list[McpPlugin] | None = None
636+
) -> None:
635637
"""Register all MCP tools on the server.
636638
637639
Args:
638640
server: The MCP server to register tools on.
639641
client: The SOVD client for making API calls.
640642
plugins: Optional list of plugins providing additional tools.
641643
"""
644+
# Tool name → plugin mapping, built during list_tools and used for dispatch
645+
plugin_tool_map: dict[str, McpPlugin] = {}
642646

643647
@server.list_tools()
644648
async def list_tools() -> list[Tool]:
@@ -1497,7 +1501,10 @@ async def list_tools() -> list[Tool]:
14971501
if plugins:
14981502
for plugin in plugins:
14991503
try:
1500-
tools.extend(plugin.list_tools())
1504+
plugin_tools = plugin.list_tools()
1505+
tools.extend(plugin_tools)
1506+
for t in plugin_tools:
1507+
plugin_tool_map[t.name] = plugin
15011508
except Exception:
15021509
logger.exception("Failed to list tools from plugin: %s", plugin.name)
15031510
return tools
@@ -1804,15 +1811,10 @@ async def call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent]:
18041811
)
18051812

18061813
else:
1807-
# Try plugins before reporting unknown tool
1808-
if plugins:
1809-
for plugin in plugins:
1810-
try:
1811-
plugin_tool_names = {t.name for t in plugin.list_tools()}
1812-
if normalized_name in plugin_tool_names:
1813-
return await plugin.call_tool(normalized_name, arguments)
1814-
except Exception:
1815-
logger.exception("Plugin %s failed to handle tool %s", plugin.name, normalized_name)
1814+
# Check plugin tool map before reporting unknown tool
1815+
plugin = plugin_tool_map.get(normalized_name)
1816+
if plugin is not None:
1817+
return await plugin.call_tool(normalized_name, arguments)
18161818
return format_error(f"Unknown tool: {name}")
18171819

18181820
except SovdClientError as e:
@@ -1868,7 +1870,9 @@ async def read_resource(uri: str) -> list[TextContent]:
18681870
raise ValueError(f"Unknown resource URI: {uri}")
18691871

18701872

1871-
def setup_mcp_app(server: Server, settings: Settings, client: SovdClient, plugins: list[McpPlugin] | None = None) -> None:
1873+
def setup_mcp_app(
1874+
server: Server, settings: Settings, client: SovdClient, plugins: list[McpPlugin] | None = None
1875+
) -> None:
18721876
"""Set up the complete MCP application.
18731877
18741878
Args:

src/ros2_medkit_mcp/server_http.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,8 @@ async def health_check(_request: Request) -> JSONResponse:
7979
}
8080
)
8181

82+
started_plugins: list = []
83+
8284
async def on_startup() -> None:
8385
"""Application startup handler."""
8486
logger.info("ros2_medkit MCP server starting (HTTP transport)")
@@ -87,14 +89,15 @@ async def on_startup() -> None:
8789
for plugin in plugins:
8890
try:
8991
await plugin.startup()
92+
started_plugins.append(plugin)
9093
logger.info("Plugin started: %s", plugin.name)
9194
except Exception:
9295
logger.exception("Failed to start plugin: %s", plugin.name)
9396

9497
async def on_shutdown() -> None:
9598
"""Application shutdown handler."""
96-
# Shutdown plugins
97-
for plugin in plugins:
99+
# Only shutdown plugins that started successfully
100+
for plugin in started_plugins:
98101
try:
99102
await plugin.shutdown()
100103
except Exception:

src/ros2_medkit_mcp/server_stdio.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,13 @@ async def run_server() -> None:
3434
client = SovdClient(settings)
3535
plugins = discover_plugins()
3636

37+
started_plugins = []
3738
try:
3839
# Start plugins
3940
for plugin in plugins:
4041
try:
4142
await plugin.startup()
43+
started_plugins.append(plugin)
4244
logger.info("Plugin started: %s", plugin.name)
4345
except Exception:
4446
logger.exception("Failed to start plugin: %s", plugin.name)
@@ -52,8 +54,8 @@ async def run_server() -> None:
5254
server.create_initialization_options(),
5355
)
5456
finally:
55-
# Shutdown plugins
56-
for plugin in plugins:
57+
# Only shutdown plugins that started successfully
58+
for plugin in started_plugins:
5759
try:
5860
await plugin.shutdown()
5961
except Exception:

tests/test_plugin_discovery.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
"""Tests for MCP plugin discovery and integration."""
22

3+
from typing import Any
34
from unittest.mock import MagicMock, patch
45

5-
import pytest
66
from mcp.types import TextContent, Tool
77

88
from ros2_medkit_mcp.plugin import discover_plugins
@@ -14,9 +14,15 @@ def name(self) -> str:
1414
return "fake"
1515

1616
def list_tools(self) -> list[Tool]:
17-
return [Tool(name="fake_tool", description="A fake tool", inputSchema={"type": "object", "properties": {}})]
18-
19-
async def call_tool(self, name: str, arguments: dict) -> list[TextContent]:
17+
return [
18+
Tool(
19+
name="fake_tool",
20+
description="A fake tool",
21+
inputSchema={"type": "object", "properties": {}},
22+
)
23+
]
24+
25+
async def call_tool(self, name: str, _arguments: dict[str, Any]) -> list[TextContent]:
2026
if name == "fake_tool":
2127
return [TextContent(type="text", text="fake result")]
2228
raise ValueError(f"Unknown tool: {name}")

0 commit comments

Comments
 (0)