Skip to content

Commit 31c0525

Browse files
committed
fix(review): harden tool validation and lifecycle cleanup
1 parent faf6a78 commit 31c0525

9 files changed

Lines changed: 263 additions & 97 deletions

File tree

src/mcp_server_python_docs/__main__.py

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,40 @@
99
# === STDIO HYGIENE (HYGN-01, B3 blocker) ===
1010
# These MUST be the first imports and operations, before anything
1111
# that might write to stdout.
12+
import atexit
1213
import os
1314
import signal
1415
import sys
1516

1617
# Save the real stdout fd for the MCP framer, then redirect fd 1 to stderr.
1718
# After this, any print() or write to fd 1 goes to stderr, not the MCP pipe.
18-
_real_stdout_fd = os.dup(1)
19+
_saved_stdout_fd: int | None = os.dup(1)
20+
21+
22+
def _close_saved_stdout_fd() -> None:
23+
"""Close the saved stdout fd when the CLI exits without serving."""
24+
global _saved_stdout_fd
25+
if _saved_stdout_fd is None:
26+
return
27+
try:
28+
os.close(_saved_stdout_fd)
29+
except OSError:
30+
pass
31+
finally:
32+
_saved_stdout_fd = None
33+
34+
35+
def _consume_saved_stdout_fd() -> int:
36+
"""Hand off the saved stdout fd to the stdio MCP transport."""
37+
global _saved_stdout_fd
38+
if _saved_stdout_fd is None:
39+
raise RuntimeError("Saved stdout fd is not available")
40+
fd = _saved_stdout_fd
41+
_saved_stdout_fd = None
42+
return fd
43+
44+
45+
atexit.register(_close_saved_stdout_fd)
1946
os.dup2(2, 1)
2047
sys.stdout = sys.stderr
2148

@@ -59,12 +86,13 @@ def serve() -> None:
5986
from mcp_server_python_docs.server import create_server
6087

6188
mcp_server = create_server()
89+
saved_stdout_fd = _consume_saved_stdout_fd()
6290

6391
# Restore the real stdout fd for MCP protocol framing.
6492
# By this point all imports are done — no third-party code will
6593
# print to stdout during MCP communication.
66-
os.dup2(_real_stdout_fd, 1)
67-
os.close(_real_stdout_fd)
94+
os.dup2(saved_stdout_fd, 1)
95+
os.close(saved_stdout_fd)
6896

6997
try:
7098
mcp_server.run(transport="stdio")

src/mcp_server_python_docs/ingestion/publish.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,15 +26,15 @@
2626
def generate_build_path() -> Path:
2727
"""Generate a timestamped build artifact path (PUBL-01).
2828
29-
Returns a path like ``~/.cache/mcp-python-docs/build-20260416-143022.db``.
29+
Returns a path like ``~/.cache/mcp-python-docs/build-20260416-143022-123456.db``.
3030
Creates the cache directory if it does not exist.
3131
3232
Returns:
3333
Path to the new build artifact.
3434
"""
3535
cache_dir = get_cache_dir()
3636
cache_dir.mkdir(parents=True, exist_ok=True)
37-
timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
37+
timestamp = datetime.now().strftime("%Y%m%d-%H%M%S-%f")
3838
return cache_dir / f"build-{timestamp}.db"
3939

4040

src/mcp_server_python_docs/retrieval/budget.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def apply_budget(
3131
next_start_index is None when not truncated (all remaining text fits).
3232
"""
3333
if not text or max_chars <= 0:
34-
return ("", bool(text), 0 if text else None)
34+
return ("", False, None)
3535

3636
if start_index >= len(text):
3737
return ("", False, None)

src/mcp_server_python_docs/retrieval/query.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,8 +174,10 @@ def build_match_expression(
174174

175175
# If expansion added new terms, OR-join all terms
176176
if expanded != original_tokens:
177+
original_query = fts5_escape(query)
177178
escaped_terms = [fts5_escape(term) for term in sorted(expanded)]
178-
return " OR ".join(escaped_terms)
179+
extra_terms = [term for term in escaped_terms if term != original_query]
180+
return " OR ".join([original_query, *extra_terms])
179181

180182
# No expansion -- use plain escaped query (implicit AND)
181183
return fts5_escape(query)

src/mcp_server_python_docs/server.py

Lines changed: 94 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,14 @@
1313
from collections.abc import AsyncIterator
1414
from contextlib import asynccontextmanager
1515
from pathlib import Path
16-
from typing import Literal
16+
from typing import Annotated, Literal
1717

1818
import platformdirs
1919
import yaml
2020
from mcp.server.fastmcp import Context, FastMCP
2121
from mcp.server.fastmcp.exceptions import ToolError
2222
from mcp.types import ToolAnnotations
23+
from pydantic import Field
2324

2425
from mcp_server_python_docs.app_context import AppContext
2526
from mcp_server_python_docs.detection import detect_python_version, match_to_indexed
@@ -81,50 +82,51 @@ async def app_lifespan(server: FastMCP) -> AsyncIterator[AppContext]:
8182
# Open read-only connection (STOR-06, STOR-07)
8283
db = get_readonly_connection(index_path)
8384

84-
# Check FTS5 (STOR-08)
85-
_assert_fts5(db)
86-
87-
# Construct service instances (Phase 5 — service layer wiring)
88-
search_svc = SearchService(db, synonyms)
89-
content_svc = ContentService(db)
90-
version_svc = VersionService(db)
91-
92-
# Detect user's Python version and match to indexed versions
93-
detected_ver, detected_src = detect_python_version()
94-
indexed_versions = [
95-
r[0] for r in db.execute("SELECT version FROM doc_sets ORDER BY version").fetchall()
96-
]
97-
matched = match_to_indexed(detected_ver, indexed_versions)
98-
if matched:
99-
logger.info("User Python %s matches indexed version — using as default", matched)
100-
else:
101-
logger.info(
102-
"User Python %s not in index %s — using normal default",
103-
detected_ver,
104-
indexed_versions,
105-
)
106-
10785
try:
108-
yield AppContext(
109-
db=db,
110-
index_path=index_path,
111-
synonyms=synonyms,
112-
search_service=search_svc,
113-
content_service=content_svc,
114-
version_service=version_svc,
115-
detected_python_version=matched,
116-
detected_python_source=detected_src,
117-
)
118-
except Exception:
119-
# HYGN-05: log lifespan errors, write last-error.log, re-raise original
120-
error_msg = traceback.format_exc()
121-
logger.error("Lifespan error: %s", error_msg)
86+
# Check FTS5 (STOR-08)
87+
_assert_fts5(db)
88+
89+
# Construct service instances (Phase 5 — service layer wiring)
90+
search_svc = SearchService(db, synonyms)
91+
content_svc = ContentService(db)
92+
version_svc = VersionService(db)
93+
94+
# Detect user's Python version and match to indexed versions
95+
detected_ver, detected_src = detect_python_version()
96+
indexed_versions = [
97+
r[0] for r in db.execute("SELECT version FROM doc_sets ORDER BY version").fetchall()
98+
]
99+
matched = match_to_indexed(detected_ver, indexed_versions)
100+
if matched:
101+
logger.info("User Python %s matches indexed version — using as default", matched)
102+
else:
103+
logger.info(
104+
"User Python %s not in index %s — using normal default",
105+
detected_ver,
106+
indexed_versions,
107+
)
108+
122109
try:
123-
error_log = cache_dir / "last-error.log"
124-
error_log.write_text(error_msg)
110+
yield AppContext(
111+
db=db,
112+
index_path=index_path,
113+
synonyms=synonyms,
114+
search_service=search_svc,
115+
content_service=content_svc,
116+
version_service=version_svc,
117+
detected_python_version=matched,
118+
detected_python_source=detected_src,
119+
)
125120
except Exception:
126-
pass
127-
raise
121+
# HYGN-05: log lifespan errors, write last-error.log, re-raise original
122+
error_msg = traceback.format_exc()
123+
logger.error("Lifespan error: %s", error_msg)
124+
try:
125+
error_log = cache_dir / "last-error.log"
126+
error_log.write_text(error_msg)
127+
except Exception:
128+
pass
129+
raise
128130
finally:
129131
db.close()
130132

@@ -137,6 +139,47 @@ async def app_lifespan(server: FastMCP) -> AsyncIterator[AppContext]:
137139
openWorldHint=False,
138140
)
139141

142+
SearchQueryParam = Annotated[
143+
str,
144+
Field(
145+
max_length=500,
146+
description="Search query - Python symbol (asyncio.TaskGroup) or concept (parse json)",
147+
),
148+
]
149+
VersionParam = Annotated[
150+
str | None,
151+
Field(description="Python version (e.g. '3.13'). Defaults to latest."),
152+
]
153+
SearchKindParam = Annotated[
154+
Literal["auto", "page", "symbol", "section", "example"],
155+
Field(
156+
description=(
157+
"Search type. Use 'symbol' for API lookups, "
158+
"'example' for code samples, 'auto' otherwise."
159+
)
160+
),
161+
]
162+
MaxResultsParam = Annotated[
163+
int,
164+
Field(ge=1, le=20, description="Maximum number of results to return."),
165+
]
166+
SlugParam = Annotated[
167+
str,
168+
Field(max_length=500, description="Page slug (e.g. 'library/asyncio-task.html')"),
169+
]
170+
AnchorParam = Annotated[
171+
str | None,
172+
Field(description="Section anchor for section-only retrieval"),
173+
]
174+
MaxCharsParam = Annotated[
175+
int,
176+
Field(ge=100, le=50000, description="Maximum characters to return"),
177+
]
178+
StartIndexParam = Annotated[
179+
int,
180+
Field(ge=0, description="Start position for pagination"),
181+
]
182+
140183

141184
def create_server() -> FastMCP:
142185
"""Create and configure the FastMCP server."""
@@ -147,10 +190,10 @@ def create_server() -> FastMCP:
147190

148191
@mcp.tool(annotations=_TOOL_ANNOTATIONS)
149192
def search_docs(
150-
query: str,
151-
version: str | None = None,
152-
kind: Literal["auto", "page", "symbol", "section", "example"] = "auto",
153-
max_results: int = 5,
193+
query: SearchQueryParam,
194+
version: VersionParam = None,
195+
kind: SearchKindParam = "auto",
196+
max_results: MaxResultsParam = 5,
154197
ctx: Context = None, # type: ignore[assignment]
155198
) -> SearchDocsResult:
156199
"""Search Python documentation. Use kind='symbol' for API lookups
@@ -168,11 +211,11 @@ def search_docs(
168211

169212
@mcp.tool(annotations=_TOOL_ANNOTATIONS)
170213
def get_docs(
171-
slug: str,
172-
version: str | None = None,
173-
anchor: str | None = None,
174-
max_chars: int = 8000,
175-
start_index: int = 0,
214+
slug: SlugParam,
215+
version: VersionParam = None,
216+
anchor: AnchorParam = None,
217+
max_chars: MaxCharsParam = 8000,
218+
start_index: StartIndexParam = 0,
176219
ctx: Context = None, # type: ignore[assignment]
177220
) -> GetDocsResult:
178221
"""Retrieve a documentation page or specific section. Provide anchor for
@@ -214,7 +257,6 @@ def detect_python_version(
214257
matches an indexed documentation set."""
215258
app_ctx: AppContext = ctx.request_context.lifespan_context
216259
detected_ver = app_ctx.detected_python_version
217-
detected_src = app_ctx.detected_python_source or "unknown"
218260

219261
# Re-run detection to get the raw version even if it didn't match
220262
from mcp_server_python_docs.detection import detect_python_version as _detect

src/mcp_server_python_docs/services/observability.py

Lines changed: 43 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
unstable, so we instrument at the service layer.
66
77
Log format (logfmt — D-10 from Phase 1):
8-
tool=search_docs version=3.13 latency_ms=12 result_count=5 truncated=false resolution=fts synonym_expansion=yes
8+
tool=search_docs version=3.13 latency_ms=12 result_count=5
9+
truncated=false resolution=fts synonym_expansion=yes
910
"""
1011
from __future__ import annotations
1112

@@ -14,6 +15,7 @@
1415
import sys
1516
import time
1617
from collections.abc import Callable
18+
from types import TracebackType
1719
from typing import Any
1820

1921

@@ -56,8 +58,15 @@ def decorator(fn: Callable) -> Callable:
5658
@functools.wraps(fn)
5759
def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
5860
start = time.monotonic()
61+
result: Any = None
62+
error: Exception | None = None
63+
error_tb: TracebackType | None = None
5964

60-
result = fn(self, *args, **kwargs)
65+
try:
66+
result = fn(self, *args, **kwargs)
67+
except Exception as exc:
68+
error = exc
69+
error_tb = exc.__traceback__
6170

6271
elapsed_ms = (time.monotonic() - start) * 1000
6372

@@ -77,37 +86,43 @@ def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
7786
version_val = kwargs.get("version")
7887
fields["version"] = version_val or "default"
7988

80-
# Extract result-specific fields
81-
if hasattr(result, "hits"):
82-
# SearchDocsResult
83-
fields["result_count"] = len(result.hits)
84-
# Resolution path from service state
85-
if hasattr(self, "_last_resolution"):
86-
fields["resolution"] = self._last_resolution
87-
else:
88-
fields["resolution"] = "fts"
89-
fields["truncated"] = False
90-
elif hasattr(result, "truncated"):
91-
# GetDocsResult
92-
fields["result_count"] = 1 if result.content else 0
93-
fields["truncated"] = result.truncated
94-
fields["resolution"] = "exact"
95-
elif hasattr(result, "versions"):
96-
# ListVersionsResult
97-
fields["result_count"] = len(result.versions)
98-
fields["truncated"] = False
99-
fields["resolution"] = "exact"
100-
101-
# Synonym expansion detection from service state
102-
if hasattr(self, "_last_synonym_expanded"):
103-
fields["synonym_expansion"] = (
104-
"yes" if self._last_synonym_expanded else "no"
105-
)
89+
if error is None:
90+
# Extract result-specific fields
91+
if hasattr(result, "hits"):
92+
# SearchDocsResult
93+
fields["result_count"] = len(result.hits)
94+
# Resolution path from service state
95+
if hasattr(self, "_last_resolution"):
96+
fields["resolution"] = self._last_resolution
97+
else:
98+
fields["resolution"] = "fts"
99+
fields["truncated"] = False
100+
elif hasattr(result, "truncated"):
101+
# GetDocsResult
102+
fields["result_count"] = 1 if result.content else 0
103+
fields["truncated"] = result.truncated
104+
fields["resolution"] = "exact"
105+
elif hasattr(result, "versions"):
106+
# ListVersionsResult
107+
fields["result_count"] = len(result.versions)
108+
fields["truncated"] = False
109+
fields["resolution"] = "exact"
110+
111+
# Synonym expansion detection from service state
112+
if hasattr(self, "_last_synonym_expanded"):
113+
fields["synonym_expansion"] = (
114+
"yes" if self._last_synonym_expanded else "no"
115+
)
116+
else:
117+
fields["error"] = type(error).__name__
106118

107119
# Write logfmt line to stderr (HYGN-01 safe — stderr only)
108120
log_line = _format_logfmt(**fields)
109121
print(log_line, file=sys.stderr)
110122

123+
if error is not None:
124+
raise error.with_traceback(error_tb)
125+
111126
return result
112127

113128
return wrapper

0 commit comments

Comments
 (0)