|
8 | 8 | import json |
9 | 9 | import re |
10 | 10 | from collections.abc import Callable |
11 | | -from typing import Protocol |
| 11 | +from typing import Protocol, cast |
12 | 12 | from urllib.error import HTTPError, URLError |
13 | 13 | from urllib.parse import quote, urlparse |
14 | 14 | from urllib.request import Request, urlopen |
15 | 15 |
|
16 | | -from mcp_server_python_docs.models import PackageDocsResult, PackageDocsSource |
| 16 | +from mcp_server_python_docs.models import ( |
| 17 | + PackageDocsResult, |
| 18 | + PackageDocsSource, |
| 19 | + PackageKind, |
| 20 | +) |
17 | 21 | from mcp_server_python_docs.services.observability import log_tool_call |
18 | 22 |
|
19 | 23 | _ALLOWED = { |
@@ -57,7 +61,7 @@ def _http_url(url: object) -> str | None: |
57 | 61 | return url.strip() if parsed.scheme in {"http", "https"} and parsed.netloc else None |
58 | 62 |
|
59 | 63 |
|
60 | | -def _source(label: str, url: object, kind: str) -> PackageDocsSource | None: |
| 64 | +def _source(label: str, url: object, kind: PackageKind) -> PackageDocsSource | None: |
61 | 65 | valid = _http_url(url) |
62 | 66 | if valid is None: |
63 | 67 | return None |
@@ -133,7 +137,10 @@ def lookup(self, package: str) -> PackageDocsResult: |
133 | 137 | for label, url in project_urls.items(): |
134 | 138 | lowered = str(label).strip().lower() |
135 | 139 | if lowered in _ALLOWED: |
136 | | - found = _source(str(label), url, lowered.replace(" ", "_")) |
| 140 | + # Runtime-safe: members of `_ALLOWED` (with spaces → underscores) |
| 141 | + # are exactly the non-pypi entries in the PackageKind Literal. |
| 142 | + derived_kind = cast(PackageKind, lowered.replace(" ", "_")) |
| 143 | + found = _source(str(label), url, derived_kind) |
137 | 144 | if found is not None and found not in sources: |
138 | 145 | sources.append(found) |
139 | 146 | else: |
|
0 commit comments