Skip to content

Commit 67fba03

Browse files
ayhammoudaclaude
andcommitted
fix: harden cache constructor and tighten package-docs types
- persistent_cache: move _fingerprint_index() inside the try-except so a missing index.db disables the cache cleanly instead of raising FileNotFoundError out of the constructor and crashing server startup (CodeRabbit Major) - models: tighten PackageDocsSource.kind and PackageDocsResult .trust_boundary to Literal types so invalid values fail validation at construction (CodeRabbit nitpick); add PackageKind alias and propagate it through package_docs._source signature, with a cast at the dynamic _ALLOWED-derived call site (runtime-safe — the allowlist enumeration matches the Literal exactly) Tests: cache disables gracefully on missing index, PackageDocsSource rejects unknown kind, PackageDocsResult rejects unknown trust_boundary. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent a525e7b commit 67fba03

5 files changed

Lines changed: 84 additions & 7 deletions

File tree

src/mcp_server_python_docs/models.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -167,13 +167,31 @@ class DetectPythonVersionResult(BaseModel):
167167

168168
# --- lookup_package_docs models ---
169169

170+
# Bounded kind vocabulary: hardcoded values from `_source` calls plus the
171+
# normalized members of `_ALLOWED` (with spaces → underscores) in
172+
# services/package_docs.py. Keep these in sync.
173+
PackageKind = Literal[
174+
"pypi",
175+
"docs",
176+
"documentation",
177+
"homepage",
178+
"home_page",
179+
"source",
180+
"source_code",
181+
"repository",
182+
"repo",
183+
]
184+
170185

171186
class PackageDocsSource(BaseModel):
172187
"""A package-declared documentation or project source URL."""
173188

174189
label: str = Field(description="Label from PyPI metadata or a normalized core metadata field")
175190
url: str = Field(description="HTTP(S) URL declared by the package on PyPI")
176-
kind: str = Field(description="Source category, such as docs, homepage, source, or pypi")
191+
kind: PackageKind = Field(
192+
description="Source category: pypi, docs, documentation, homepage, home_page, "
193+
"source, source_code, repository, or repo"
194+
)
177195
declared_by: str = Field(description="Where this source declaration came from")
178196

179197

@@ -184,7 +202,7 @@ class PackageDocsResult(BaseModel):
184202
version: str = Field(description="Latest version reported by PyPI metadata")
185203
summary: str = Field(default="", description="Package summary from PyPI metadata")
186204
metadata_source: str = Field(description="Official PyPI JSON API URL used for lookup")
187-
trust_boundary: str = Field(
205+
trust_boundary: Literal["pypi-declared-metadata"] = Field(
188206
default="pypi-declared-metadata",
189207
description="Indicates results are limited to PyPI/project-declared metadata",
190208
)

src/mcp_server_python_docs/services/package_docs.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,16 @@
88
import json
99
import re
1010
from collections.abc import Callable
11-
from typing import Protocol
11+
from typing import Protocol, cast
1212
from urllib.error import HTTPError, URLError
1313
from urllib.parse import quote, urlparse
1414
from urllib.request import Request, urlopen
1515

16-
from mcp_server_python_docs.models import PackageDocsResult, PackageDocsSource
16+
from mcp_server_python_docs.models import (
17+
PackageDocsResult,
18+
PackageDocsSource,
19+
PackageKind,
20+
)
1721
from mcp_server_python_docs.services.observability import log_tool_call
1822

1923
_ALLOWED = {
@@ -57,7 +61,7 @@ def _http_url(url: object) -> str | None:
5761
return url.strip() if parsed.scheme in {"http", "https"} and parsed.netloc else None
5862

5963

60-
def _source(label: str, url: object, kind: str) -> PackageDocsSource | None:
64+
def _source(label: str, url: object, kind: PackageKind) -> PackageDocsSource | None:
6165
valid = _http_url(url)
6266
if valid is None:
6367
return None
@@ -133,7 +137,10 @@ def lookup(self, package: str) -> PackageDocsResult:
133137
for label, url in project_urls.items():
134138
lowered = str(label).strip().lower()
135139
if lowered in _ALLOWED:
136-
found = _source(str(label), url, lowered.replace(" ", "_"))
140+
# Runtime-safe: members of `_ALLOWED` (with spaces → underscores)
141+
# are exactly the non-pypi entries in the PackageKind Literal.
142+
derived_kind = cast(PackageKind, lowered.replace(" ", "_"))
143+
found = _source(str(label), url, derived_kind)
137144
if found is not None and found not in sources:
138145
sources.append(found)
139146
else:

src/mcp_server_python_docs/services/persistent_cache.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,9 @@ class PersistentDocsCache:
2727

2828
def __init__(self, cache_path: Path, index_path: Path) -> None:
2929
self._cache_path = Path(cache_path)
30-
self._fingerprint = self._fingerprint_index(Path(index_path))
30+
# Set after fingerprint stat succeeds; stays "" if init fails so the
31+
# cache disables cleanly without leaking partial state.
32+
self._fingerprint = ""
3133
self._hits = self._misses = self._writes = 0
3234
# ``check_same_thread=False`` lets multiple threads share the connection,
3335
# but per the Python sqlite3 docs writes must still be serialized by the
@@ -36,6 +38,7 @@ def __init__(self, cache_path: Path, index_path: Path) -> None:
3638
self._lock = threading.Lock()
3739
self._conn: sqlite3.Connection | None = None
3840
try:
41+
self._fingerprint = self._fingerprint_index(Path(index_path))
3942
self._cache_path.parent.mkdir(parents=True, exist_ok=True)
4043
self._conn = sqlite3.connect(str(self._cache_path), check_same_thread=False)
4144
self._conn.execute("PRAGMA journal_mode = WAL")

tests/test_package_docs.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@
44
import json
55
from urllib.error import HTTPError, URLError
66

7+
import pytest
8+
from pydantic import ValidationError
9+
10+
from mcp_server_python_docs.models import PackageDocsResult, PackageDocsSource
711
from mcp_server_python_docs.services.package_docs import (
812
_PYPI_METADATA_MAX_BYTES,
913
PackageDocsService,
@@ -135,3 +139,27 @@ def invalid_utf8(url: str, timeout: float):
135139
utf8_result = PackageDocsService(fetcher=invalid_utf8).lookup("demo")
136140
assert utf8_result.sources == []
137141
assert utf8_result.note == "Unable to retrieve PyPI metadata: UnicodeDecodeError."
142+
143+
144+
def test_package_docs_source_rejects_unknown_kind():
145+
"""``kind`` is a controlled vocabulary — unknown values must fail validation.
146+
147+
Regression for CodeRabbit nitpick: tightening ``kind`` from ``str`` to a
148+
``Literal`` formalizes the implicit contract enforced by the ``_ALLOWED``
149+
set in package_docs.py and the hardcoded values used by ``_source``.
150+
"""
151+
with pytest.raises(ValidationError):
152+
PackageDocsSource(
153+
label="bogus", url="https://x/", kind="bogus_kind", declared_by="x"
154+
)
155+
156+
157+
def test_package_docs_result_rejects_unknown_trust_boundary():
158+
"""``trust_boundary`` is fixed at construction — divergence must fail."""
159+
with pytest.raises(ValidationError):
160+
PackageDocsResult(
161+
package="x",
162+
version="1",
163+
metadata_source="https://pypi.org/pypi/x/json",
164+
trust_boundary="something-else",
165+
)

tests/test_persistent_docs_cache.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,27 @@ def test_invalid_cached_json_is_best_effort_miss(tmp_path: Path, caplog):
184184
assert cache.stats().misses == 1
185185

186186

187+
def test_cache_disables_gracefully_when_index_missing(tmp_path: Path, caplog):
188+
"""Constructor must not raise when index.db is missing.
189+
190+
Regression for CodeRabbit Major: ``_fingerprint_index()`` calls
191+
``Path.stat()`` which raises ``FileNotFoundError``; without guarding,
192+
this turns an optional cache into a startup failure for the whole server.
193+
"""
194+
cache_path = tmp_path / "cache.sqlite3"
195+
missing_index = tmp_path / "does-not-exist.db"
196+
assert not missing_index.exists()
197+
198+
with caplog.at_level(logging.WARNING):
199+
cache = PersistentDocsCache(cache_path, missing_index)
200+
201+
assert "Persistent docs cache disabled" in caplog.text
202+
# Cache should be disabled — get returns None, put is a no-op
203+
assert cache.get(
204+
version="3.13", slug="x", anchor=None, max_chars=100, start_index=0
205+
) is None
206+
207+
187208
def test_concurrent_puts_serialize_safely_without_lost_writes(tmp_path: Path):
188209
"""Concurrent put() must not race on the shared connection or stats counter.
189210

0 commit comments

Comments
 (0)