Skip to content

Commit 8814bc5

Browse files
feat(mcp): add MCP server for AI assistant integration
- Implement MCP server to enable AI assistants to search plot specifications and fetch implementation code - Update import paths and configuration in main.py and pyproject.toml - Add unit tests for MCP server tools
1 parent 2298302 commit 8814bc5

File tree

6 files changed

+31
-31
lines changed

6 files changed

+31
-31
lines changed

api/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
http_exception_handler,
2323
pyplots_exception_handler,
2424
)
25+
from api.mcp.server import mcp_server # noqa: E402
2526
from api.routers import ( # noqa: E402
2627
debug_router,
2728
download_router,
@@ -35,7 +36,6 @@
3536
stats_router,
3637
)
3738
from core.database import close_db, init_db, is_db_configured # noqa: E402
38-
from pyplots_mcp.server import mcp_server # noqa: E402
3939

4040

4141
# Configure logging
File renamed without changes.

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ skip-magic-trailing-comma = true
144144
[tool.ruff.lint.isort]
145145
force-single-line = false
146146
lines-after-imports = 2
147-
known-first-party = ["api", "core", "automation", "pyplots_mcp"]
147+
known-first-party = ["api", "core", "automation"]
148148
split-on-trailing-comma = false
149149

150150
# ===== Pytest Configuration =====
@@ -175,5 +175,5 @@ build-backend = "setuptools.build_meta"
175175
# ===== Setuptools Configuration =====
176176

177177
[tool.setuptools.packages.find]
178-
include = ["api*", "core*", "automation*", "pyplots_mcp*"]
178+
include = ["api*", "core*", "automation*"]
179179
exclude = ["tests*", "docs*", "app*", "specs*", "rules*", "plots*"]
Lines changed: 28 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,12 @@
1010

1111
# Import the tool functions from the module
1212
# Note: These are FunctionTool objects, we need to access .fn to get the actual callable
13-
from pyplots_mcp.server import get_implementation as get_implementation_tool
14-
from pyplots_mcp.server import get_spec_detail as get_spec_detail_tool
15-
from pyplots_mcp.server import get_tag_values as get_tag_values_tool
16-
from pyplots_mcp.server import list_libraries as list_libraries_tool
17-
from pyplots_mcp.server import list_specs as list_specs_tool
18-
from pyplots_mcp.server import search_specs_by_tags as search_specs_by_tags_tool
13+
from api.mcp.server import get_implementation as get_implementation_tool
14+
from api.mcp.server import get_spec_detail as get_spec_detail_tool
15+
from api.mcp.server import get_tag_values as get_tag_values_tool
16+
from api.mcp.server import list_libraries as list_libraries_tool
17+
from api.mcp.server import list_specs as list_specs_tool
18+
from api.mcp.server import search_specs_by_tags as search_specs_by_tags_tool
1919

2020

2121
# Extract the actual functions from the FunctionTool wrappers
@@ -40,8 +40,8 @@ async def __aexit__(self, exc_type, exc_val, exc_tb):
4040
pass
4141

4242
with (
43-
patch("pyplots_mcp.server.get_db_context", return_value=MockContextManager()),
44-
patch("pyplots_mcp.server.is_db_configured", return_value=True),
43+
patch("api.mcp.server.get_db_context", return_value=MockContextManager()),
44+
patch("api.mcp.server.is_db_configured", return_value=True),
4545
):
4646
yield mock_session
4747

@@ -92,7 +92,7 @@ async def test_list_specs(mock_db_context, mock_spec):
9292
mock_repo = MagicMock()
9393
mock_repo.get_all = AsyncMock(return_value=[mock_spec])
9494

95-
with patch("pyplots_mcp.server.SpecRepository", return_value=mock_repo):
95+
with patch("api.mcp.server.SpecRepository", return_value=mock_repo):
9696
result = await list_specs(limit=10, offset=0)
9797

9898
assert len(result) == 1
@@ -109,7 +109,7 @@ async def test_list_specs_pagination(mock_db_context):
109109
mock_repo = MagicMock()
110110
mock_repo.get_all = AsyncMock(return_value=specs)
111111

112-
with patch("pyplots_mcp.server.SpecRepository", return_value=mock_repo):
112+
with patch("api.mcp.server.SpecRepository", return_value=mock_repo):
113113
result = await list_specs(limit=2, offset=1)
114114

115115
assert len(result) == 2
@@ -123,7 +123,7 @@ async def test_search_specs_by_tags_spec_level(mock_db_context, mock_spec):
123123
mock_repo = MagicMock()
124124
mock_repo.search_by_tags = AsyncMock(return_value=[mock_spec])
125125

126-
with patch("pyplots_mcp.server.SpecRepository", return_value=mock_repo):
126+
with patch("api.mcp.server.SpecRepository", return_value=mock_repo):
127127
result = await search_specs_by_tags(plot_type=["scatter"], domain=["statistics"])
128128

129129
# Verify repository called with flattened list (order may vary)
@@ -139,7 +139,7 @@ async def test_search_specs_by_tags_impl_level(mock_db_context, mock_spec):
139139
mock_repo = MagicMock()
140140
mock_repo.get_all = AsyncMock(return_value=[mock_spec])
141141

142-
with patch("pyplots_mcp.server.SpecRepository", return_value=mock_repo):
142+
with patch("api.mcp.server.SpecRepository", return_value=mock_repo):
143143
result = await search_specs_by_tags(library=["matplotlib"], patterns=["data-generation"])
144144

145145
assert len(result) == 1
@@ -152,7 +152,7 @@ async def test_search_specs_by_tags_no_matches(mock_db_context, mock_spec):
152152
mock_repo = MagicMock()
153153
mock_repo.get_all = AsyncMock(return_value=[mock_spec])
154154

155-
with patch("pyplots_mcp.server.SpecRepository", return_value=mock_repo):
155+
with patch("api.mcp.server.SpecRepository", return_value=mock_repo):
156156
result = await search_specs_by_tags(library=["seaborn"]) # matplotlib impl, not seaborn
157157

158158
assert len(result) == 0
@@ -164,7 +164,7 @@ async def test_search_specs_by_tags_dataprep_styling(mock_db_context, mock_spec)
164164
mock_repo = MagicMock()
165165
mock_repo.get_all = AsyncMock(return_value=[mock_spec])
166166

167-
with patch("pyplots_mcp.server.SpecRepository", return_value=mock_repo):
167+
with patch("api.mcp.server.SpecRepository", return_value=mock_repo):
168168
# Test dataprep filter - should not match (mock_spec has no dataprep tags)
169169
result = await search_specs_by_tags(dataprep=["normalization"])
170170
assert len(result) == 0
@@ -181,7 +181,7 @@ async def test_get_spec_detail(mock_db_context, mock_spec):
181181
mock_repo = MagicMock()
182182
mock_repo.get_by_id = AsyncMock(return_value=mock_spec)
183183

184-
with patch("pyplots_mcp.server.SpecRepository", return_value=mock_repo):
184+
with patch("api.mcp.server.SpecRepository", return_value=mock_repo):
185185
result = await get_spec_detail("scatter-basic")
186186

187187
assert result["id"] == "scatter-basic"
@@ -197,7 +197,7 @@ async def test_get_spec_detail_not_found(mock_db_context):
197197
mock_repo = MagicMock()
198198
mock_repo.get_by_id = AsyncMock(return_value=None)
199199

200-
with patch("pyplots_mcp.server.SpecRepository", return_value=mock_repo):
200+
with patch("api.mcp.server.SpecRepository", return_value=mock_repo):
201201
with pytest.raises(ValueError, match="Specification 'invalid' not found"):
202202
await get_spec_detail("invalid")
203203

@@ -221,9 +221,9 @@ async def test_get_implementation(mock_db_context, mock_spec):
221221
mock_impl_repo.get_by_spec_and_library = AsyncMock(return_value=mock_impl)
222222

223223
with (
224-
patch("pyplots_mcp.server.SpecRepository", return_value=mock_spec_repo),
225-
patch("pyplots_mcp.server.LibraryRepository", return_value=mock_lib_repo),
226-
patch("pyplots_mcp.server.ImplRepository", return_value=mock_impl_repo),
224+
patch("api.mcp.server.SpecRepository", return_value=mock_spec_repo),
225+
patch("api.mcp.server.LibraryRepository", return_value=mock_lib_repo),
226+
patch("api.mcp.server.ImplRepository", return_value=mock_impl_repo),
227227
):
228228
result = await get_implementation("scatter-basic", "matplotlib")
229229

@@ -238,7 +238,7 @@ async def test_get_implementation_spec_not_found(mock_db_context):
238238
mock_spec_repo = MagicMock()
239239
mock_spec_repo.get_by_id = AsyncMock(return_value=None)
240240

241-
with patch("pyplots_mcp.server.SpecRepository", return_value=mock_spec_repo):
241+
with patch("api.mcp.server.SpecRepository", return_value=mock_spec_repo):
242242
with pytest.raises(ValueError, match="Specification 'invalid' not found"):
243243
await get_implementation("invalid", "matplotlib")
244244

@@ -254,8 +254,8 @@ async def test_get_implementation_library_not_found(mock_db_context, mock_spec):
254254
mock_lib_repo.get_all = AsyncMock(return_value=[MagicMock(id="matplotlib")])
255255

256256
with (
257-
patch("pyplots_mcp.server.SpecRepository", return_value=mock_spec_repo),
258-
patch("pyplots_mcp.server.LibraryRepository", return_value=mock_lib_repo),
257+
patch("api.mcp.server.SpecRepository", return_value=mock_spec_repo),
258+
patch("api.mcp.server.LibraryRepository", return_value=mock_lib_repo),
259259
):
260260
with pytest.raises(ValueError, match="Library 'invalid' not found"):
261261
await get_implementation("scatter-basic", "invalid")
@@ -277,9 +277,9 @@ async def test_get_implementation_not_found(mock_db_context, mock_spec):
277277
mock_impl_repo.get_by_spec_and_library = AsyncMock(return_value=None)
278278

279279
with (
280-
patch("pyplots_mcp.server.SpecRepository", return_value=mock_spec_repo),
281-
patch("pyplots_mcp.server.LibraryRepository", return_value=mock_lib_repo),
282-
patch("pyplots_mcp.server.ImplRepository", return_value=mock_impl_repo),
280+
patch("api.mcp.server.SpecRepository", return_value=mock_spec_repo),
281+
patch("api.mcp.server.LibraryRepository", return_value=mock_lib_repo),
282+
patch("api.mcp.server.ImplRepository", return_value=mock_impl_repo),
283283
):
284284
with pytest.raises(ValueError, match="Implementation for 'scatter-basic' in library 'seaborn' not found"):
285285
await get_implementation("scatter-basic", "seaborn")
@@ -303,7 +303,7 @@ async def test_list_libraries(mock_db_context):
303303
mock_repo = MagicMock()
304304
mock_repo.get_all = AsyncMock(return_value=mock_libs)
305305

306-
with patch("pyplots_mcp.server.LibraryRepository", return_value=mock_repo):
306+
with patch("api.mcp.server.LibraryRepository", return_value=mock_repo):
307307
result = await list_libraries()
308308

309309
assert len(result) == 2
@@ -322,7 +322,7 @@ async def test_get_tag_values_spec_level(mock_db_context, mock_spec):
322322
mock_repo = MagicMock()
323323
mock_repo.get_all = AsyncMock(return_value=[mock_spec, mock_spec2])
324324

325-
with patch("pyplots_mcp.server.SpecRepository", return_value=mock_repo):
325+
with patch("api.mcp.server.SpecRepository", return_value=mock_repo):
326326
result = await get_tag_values("plot_type")
327327

328328
assert sorted(result) == ["bar", "histogram", "scatter"]
@@ -334,7 +334,7 @@ async def test_get_tag_values_impl_level(mock_db_context, mock_spec):
334334
mock_repo = MagicMock()
335335
mock_repo.get_all = AsyncMock(return_value=[mock_spec])
336336

337-
with patch("pyplots_mcp.server.SpecRepository", return_value=mock_repo):
337+
with patch("api.mcp.server.SpecRepository", return_value=mock_repo):
338338
result = await get_tag_values("patterns")
339339

340340
assert sorted(result) == ["data-generation"]

0 commit comments

Comments
 (0)