Skip to content

Commit 4d38420

Browse files
authored
cache: add zstd codec layer for retrieved docs
Closes #46
1 parent cf8fc97 commit 4d38420

7 files changed

Lines changed: 275 additions & 10 deletions

File tree

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Cache support utilities."""
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
"""Versioned codecs for cache-at-rest payloads."""
2+
3+
from __future__ import annotations
4+
5+
from collections.abc import Callable
6+
from dataclasses import dataclass
7+
8+
import zstandard as zstd
9+
10+
# Codec ids in stable preference order. Kept as a tuple so this shared
# module-level value cannot be mutated in place by accident.
_SUPPORTED_CODECS: tuple[str, ...] = ("none", "zstd", "zstd-dict-v1")
11+
12+
13+
@dataclass(frozen=True)
14+
class _Codec:
15+
encode: Callable[[str, object | None], bytes]
16+
decode: Callable[[bytes, object | None], str]
17+
18+
19+
def list_supported() -> list[str]:
    """Return the supported codec ids, in stable preference order.

    A new list is built on every call, so callers may modify the result
    without affecting the module-level codec table.
    """
    return [*_SUPPORTED_CODECS]
23+
24+
25+
def encode(text: str, codec: str, *, dictionary: object | None = None) -> bytes:
    """Encode ``text`` into cache payload bytes with the named codec.

    Args:
        text: The text to encode.
        codec: Codec id to look up in the codec registry.
        dictionary: Optional dictionary object handed through to the codec.

    Raises:
        ValueError: If ``codec`` is not a registered codec id.
    """
    try:
        codec_impl = _REGISTRY[codec]
    except KeyError as missing:
        raise ValueError(f"Unsupported cache codec: {codec}") from missing
    else:
        # Run the codec outside the try so its own errors are never
        # mistaken for an unknown-codec lookup failure.
        return codec_impl.encode(text, dictionary)
33+
34+
35+
def decode(blob: bytes, codec: str, *, dictionary: object | None = None) -> str:
    """Decode cache payload bytes back to text with the named codec.

    Args:
        blob: The stored payload bytes.
        codec: Codec id stored alongside the cache row.
        dictionary: Optional dictionary object handed through to the codec.

    Raises:
        ValueError: If ``codec`` is not a registered codec id.
    """
    try:
        codec_impl = _REGISTRY[codec]
    except KeyError as missing:
        raise ValueError(f"Unsupported cache codec: {codec}") from missing
    else:
        # Run the codec outside the try so its own errors are never
        # mistaken for an unknown-codec lookup failure.
        return codec_impl.decode(blob, dictionary)
43+
44+
45+
def _encode_none(text: str, dictionary: object | None) -> bytes:
    """Encode ``text`` as plain UTF-8 bytes; a dictionary is not accepted."""
    _reject_dictionary("none", dictionary)
    utf8_bytes = text.encode("utf-8")
    return utf8_bytes
48+
49+
50+
def _decode_none(blob: bytes, dictionary: object | None) -> str:
    """Decode plain UTF-8 payload bytes to text; a dictionary is not accepted."""
    _reject_dictionary("none", dictionary)
    decoded_text = blob.decode("utf-8")
    return decoded_text
53+
54+
55+
def _encode_zstd(text: str, dictionary: object | None) -> bytes:
    """Compress the UTF-8 encoding of ``text`` with dictionary-less zstd.

    Raises:
        ValueError: If a dictionary is supplied, or if zstd reports an error.
    """
    _reject_dictionary("zstd", dictionary)
    raw = text.encode("utf-8")
    try:
        return zstd.ZstdCompressor().compress(raw)
    except zstd.ZstdError as failure:
        raise ValueError(f"zstd encode failed: {failure}") from failure
61+
62+
63+
def _decode_zstd(blob: bytes, dictionary: object | None) -> str:
    """Decompress a dictionary-less zstd payload and decode it as UTF-8.

    Raises:
        ValueError: If a dictionary is supplied, or if zstd reports an error.
    """
    _reject_dictionary("zstd", dictionary)
    try:
        raw = zstd.ZstdDecompressor().decompress(blob)
    except zstd.ZstdError as failure:
        raise ValueError(f"zstd decode failed: {failure}") from failure
    return raw.decode("utf-8")
69+
70+
71+
def _encode_zstd_dict(text: str, dictionary: object | None) -> bytes:
    """Compress the UTF-8 encoding of ``text`` with zstd and a dictionary.

    Raises:
        ValueError: If zstd reports an error. Errors raised while coercing
            ``dictionary`` propagate unchanged unless they are zstd errors.
    """
    try:
        # Coercion stays inside the try so a zstd error raised while
        # building the dictionary is wrapped like any other zstd failure.
        compressor = zstd.ZstdCompressor(dict_data=_coerce_dictionary(dictionary))
        return compressor.compress(text.encode("utf-8"))
    except zstd.ZstdError as failure:
        raise ValueError(f"zstd dictionary encode failed: {failure}") from failure
78+
79+
80+
def _decode_zstd_dict(blob: bytes, dictionary: object | None) -> str:
    """Decompress a dictionary-based zstd payload and decode it as UTF-8.

    Raises:
        ValueError: If zstd reports an error. Errors raised while coercing
            ``dictionary`` propagate unchanged unless they are zstd errors.
    """
    try:
        # Coercion stays inside the try so a zstd error raised while
        # building the dictionary is wrapped like any other zstd failure.
        decompressor = zstd.ZstdDecompressor(dict_data=_coerce_dictionary(dictionary))
        raw = decompressor.decompress(blob)
    except zstd.ZstdError as failure:
        raise ValueError(f"zstd dictionary decode failed: {failure}") from failure
    return raw.decode("utf-8")
89+
90+
91+
def _reject_dictionary(codec: str, dictionary: object | None) -> None:
    """Raise ``ValueError`` when a dictionary is passed to a codec that takes none.

    Args:
        codec: Codec id, used only in the error message.
        dictionary: The caller-supplied dictionary; must be ``None``.
    """
    if dictionary is None:
        return
    raise ValueError(f"Codec {codec!r} does not use a dictionary")
94+
95+
96+
def _coerce_dictionary(dictionary: object | None) -> zstd.ZstdCompressionDict:
    """Turn a caller-supplied dictionary into a ``zstd.ZstdCompressionDict``.

    Accepts an existing ``ZstdCompressionDict`` (returned as-is) or raw
    dictionary content as ``bytes``, ``bytearray`` or ``memoryview``.

    Raises:
        ValueError: If ``dictionary`` is ``None``.
        TypeError: If ``dictionary`` is of any other type.
    """
    if dictionary is None:
        raise ValueError("Codec 'zstd-dict-v1' requires an explicit dictionary")
    if isinstance(dictionary, zstd.ZstdCompressionDict):
        return dictionary
    if isinstance(dictionary, (bytes, bytearray, memoryview)):
        # bytes is passed through untouched; the other buffer types are
        # copied into an immutable bytes object first.
        raw = dictionary if isinstance(dictionary, bytes) else bytes(dictionary)
        return zstd.ZstdCompressionDict(raw)
    raise TypeError(f"Unsupported zstd dictionary object: {type(dictionary).__name__}")
106+
107+
108+
# Maps each codec id to the encode/decode pair that implements it.
_REGISTRY: dict[str, _Codec] = {
    "none": _Codec(encode=_encode_none, decode=_decode_none),
    "zstd": _Codec(encode=_encode_zstd, decode=_decode_zstd),
    "zstd-dict-v1": _Codec(encode=_encode_zstd_dict, decode=_decode_zstd_dict),
}
113+
114+
__all__ = ["decode", "encode", "list_supported"]

src/mcp_server_python_docs/services/persistent_cache.py

Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,13 @@
1010

1111
from pydantic import ValidationError
1212

13+
from mcp_server_python_docs.cache.codec import decode as decode_cache_payload
14+
from mcp_server_python_docs.cache.codec import encode as encode_cache_payload
1315
from mcp_server_python_docs.models import GetDocsResult
1416

1517
logger = logging.getLogger(__name__)
1618
_NO_ANCHOR_KEY = "\x00mcp-python-docs:no-anchor\x00"
19+
DEFAULT_RETRIEVED_DOCS_CACHE_CODEC = "zstd"
1720

1821

1922
class CacheStats(NamedTuple):
@@ -25,8 +28,15 @@ class CacheStats(NamedTuple):
2528
class PersistentDocsCache:
2629
"""Persist get_docs results by index fingerprint, version, and request identity."""
2730

28-
def __init__(self, cache_path: Path, index_path: Path) -> None:
31+
def __init__(
32+
self,
33+
cache_path: Path,
34+
index_path: Path,
35+
*,
36+
default_codec: str = DEFAULT_RETRIEVED_DOCS_CACHE_CODEC,
37+
) -> None:
2938
self._cache_path = Path(cache_path)
39+
self._default_codec = default_codec
3040
# Set after fingerprint stat succeeds; stays "" if init fails so the
3141
# cache disables cleanly without leaking partial state.
3242
self._fingerprint = ""
@@ -47,9 +57,11 @@ def __init__(self, cache_path: Path, index_path: Path) -> None:
4757
"CREATE TABLE IF NOT EXISTS retrieved_docs_cache ("
4858
"index_fingerprint TEXT NOT NULL, version TEXT NOT NULL, slug TEXT NOT NULL, "
4959
"anchor TEXT NOT NULL, max_chars INTEGER NOT NULL, start_index INTEGER NOT NULL, "
50-
"result_json TEXT NOT NULL, created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP, "
60+
"result_json TEXT NOT NULL, compression TEXT NOT NULL DEFAULT 'none', "
61+
"created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP, "
5162
"PRIMARY KEY (index_fingerprint, version, slug, anchor, max_chars, start_index))"
5263
)
64+
self._ensure_compression_column()
5365
self._conn.execute(
5466
"DELETE FROM retrieved_docs_cache WHERE index_fingerprint != ?",
5567
(self._fingerprint,),
@@ -74,6 +86,18 @@ def _fingerprint_index(index_path: Path) -> str:
7486
def _anchor_key(anchor: str | None) -> str:
7587
return _NO_ANCHOR_KEY if anchor is None else anchor
7688

89+
def _ensure_compression_column(self) -> None:
90+
if self._conn is None:
91+
return
92+
columns = {
93+
row[1] for row in self._conn.execute("PRAGMA table_info(retrieved_docs_cache)")
94+
}
95+
if "compression" not in columns:
96+
self._conn.execute(
97+
"ALTER TABLE retrieved_docs_cache "
98+
"ADD COLUMN compression TEXT NOT NULL DEFAULT 'none'"
99+
)
100+
77101
def stats(self) -> CacheStats:
    """Return the current hit, miss and write counters as a ``CacheStats``."""
    counters = (self._hits, self._misses, self._writes)
    return CacheStats(*counters)
79103

@@ -87,7 +111,8 @@ def get(
87111
with self._lock:
88112
try:
89113
row = self._conn.execute(
90-
"SELECT result_json FROM retrieved_docs_cache WHERE index_fingerprint = ? "
114+
"SELECT result_json, compression FROM retrieved_docs_cache "
115+
"WHERE index_fingerprint = ? "
91116
"AND version = ? AND slug = ? AND anchor = ? AND max_chars = ? "
92117
"AND start_index = ?",
93118
(
@@ -107,8 +132,10 @@ def get(
107132
self._misses += 1
108133
return None
109134
try:
110-
result = GetDocsResult.model_validate_json(row[0])
111-
except (ValidationError, ValueError) as e:
135+
payload = row[0].encode("utf-8") if isinstance(row[0], str) else bytes(row[0])
136+
result_json = decode_cache_payload(payload, row[1])
137+
result = GetDocsResult.model_validate_json(result_json)
138+
except (ValidationError, ValueError, TypeError) as e:
112139
self._misses += 1
113140
logger.warning("Persistent docs cache entry ignored: %s", e)
114141
return None
@@ -123,15 +150,16 @@ def put(self, *, result: GetDocsResult, max_chars: int, start_index: int) -> Non
123150
self._conn.execute(
124151
"INSERT OR REPLACE INTO retrieved_docs_cache "
125152
"(index_fingerprint, version, slug, anchor, max_chars, start_index, "
126-
"result_json) VALUES (?, ?, ?, ?, ?, ?, ?)",
153+
"result_json, compression) VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
127154
(
128155
self._fingerprint,
129156
result.version,
130157
result.slug,
131158
self._anchor_key(result.anchor),
132159
max_chars,
133160
start_index,
134-
result.model_dump_json(),
161+
encode_cache_payload(result.model_dump_json(), self._default_codec),
162+
self._default_codec,
135163
),
136164
)
137165
self._conn.commit()

tests/cache/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Cache tests."""

tests/cache/test_codec.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
"""Cache codec coverage."""
2+
3+
from __future__ import annotations
4+
5+
import zstandard as zstd
6+
7+
from mcp_server_python_docs.cache.codec import decode, encode, list_supported
8+
9+
10+
def _test_dictionary() -> zstd.ZstdCompressionDict:
    """Train a small zstd dictionary from 64 synthetic, repetitive samples."""
    shared_tail = "arguments return values exceptions examples. "
    samples = []
    for i in range(64):
        sentence = (
            f"Python documentation section {i}: json dumps loads encoder decoder "
            + shared_tail
        )
        # Repeat each sentence so every sample has some bulk for training.
        samples.append(sentence.encode("utf-8") * 8)
    return zstd.train_dictionary(512, samples)
20+
21+
22+
def test_list_supported_is_stable() -> None:
    """The advertised codec ids and their order must not drift."""
    expected_ids = ["none", "zstd", "zstd-dict-v1"]
    assert list_supported() == expected_ids
24+
25+
26+
def test_none_round_trips_text() -> None:
    """The "none" codec stores plain UTF-8 bytes and decodes them back."""
    original = '{"content":"plain json payload","version":"3.13"}'
    stored = encode(original, "none")
    assert stored == original.encode("utf-8")
    assert decode(stored, "none") == original
31+
32+
33+
def test_zstd_round_trips_text() -> None:
    """The "zstd" codec changes the stored bytes and decodes them back."""
    original = '{"content":"compressed json payload","version":"3.13"}'
    stored = encode(original, "zstd")
    assert stored != original.encode("utf-8")
    assert decode(stored, "zstd") == original
38+
39+
40+
def test_zstd_dict_v1_round_trips_with_explicit_dictionary() -> None:
    """The "zstd-dict-v1" codec round-trips when given the same dictionary."""
    trained = _test_dictionary()
    original = "Python documentation section 7: json dumps loads encoder decoder arguments."
    stored = encode(original, "zstd-dict-v1", dictionary=trained)
    restored = decode(stored, "zstd-dict-v1", dictionary=trained)
    assert restored == original
45+
46+
47+
def test_none_decodes_payload_from_prior_server_version() -> None:
    """Uncompressed JSON bytes still decode through the "none" codec."""
    legacy_bytes = b'{"content":"legacy uncompressed json","version":"3.12"}'
    expected_text = legacy_bytes.decode("utf-8")
    assert decode(legacy_bytes, "none") == expected_text

tests/test_mcp_get_docs_cache_smoke.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -140,11 +140,12 @@ def test_get_docs_cache_restart_and_corrupt_cache_fallback(tmp_path: Path):
140140

141141
with sqlite3.connect(cache_path) as conn:
142142
rows = conn.execute(
143-
"SELECT version, slug, anchor, max_chars, start_index, length(result_json) "
143+
"SELECT version, slug, anchor, max_chars, start_index, "
144+
"length(result_json), compression "
144145
"FROM retrieved_docs_cache"
145146
).fetchall()
146147
assert len(rows) == 1
147-
version, slug, anchor, max_chars, start_index, result_json_length = rows[0]
148+
version, slug, anchor, max_chars, start_index, result_json_length, compression = rows[0]
148149
assert (version, slug, anchor, max_chars, start_index) == (
149150
"3.13",
150151
"library/json.html",
@@ -153,6 +154,7 @@ def test_get_docs_cache_restart_and_corrupt_cache_fallback(tmp_path: Path):
153154
0,
154155
)
155156
assert result_json_length > 0
157+
assert compression == "zstd"
156158

157159
restarted_page = _tool_structured_content(
158160
_run_server(

tests/test_persistent_docs_cache.py

Lines changed: 71 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
from mcp_server_python_docs.models import GetDocsResult
1212
from mcp_server_python_docs.services.content import ContentService
13-
from mcp_server_python_docs.services.persistent_cache import PersistentDocsCache
13+
from mcp_server_python_docs.services.persistent_cache import _NO_ANCHOR_KEY, PersistentDocsCache
1414

1515

1616
def _doc(db, version: str, content: str, default: int = 0) -> None:
@@ -68,6 +68,76 @@ def test_cache_survives_restart_and_miss_falls_back(populated_db, tmp_path: Path
6868
assert restarted.stats().hits == 1
6969

7070

71+
def test_current_default_codec_reads_identically_after_restart(tmp_path: Path):
    """A row written with the default codec reads back equal from a new cache.

    Puts one result, checks the stored ``compression`` value is "zstd",
    then opens a second ``PersistentDocsCache`` and expects a cache hit
    equal to the original result.
    """
    from contextlib import closing

    index_path, cache = _cache(tmp_path)
    expected = _result("compressed docs payload")
    cache.put(result=expected, max_chars=500, start_index=0)

    # sqlite3's own connection context manager only scopes a transaction
    # and never closes the handle; closing() releases it deterministically.
    with closing(sqlite3.connect(cache.cache_path)) as conn:
        compression = conn.execute("SELECT compression FROM retrieved_docs_cache").fetchone()[0]
    assert compression == "zstd"

    restarted = PersistentDocsCache(tmp_path / "retrieved.sqlite3", index_path)
    assert (
        restarted.get(
            version="3.12",
            slug="library/json.html",
            anchor=None,
            max_chars=500,
            start_index=0,
        )
        == expected
    )
    assert restarted.stats().hits == 1
92+
93+
94+
def test_legacy_uncompressed_cache_row_migrates_and_reads(tmp_path: Path):
    """A cache file created without the ``compression`` column stays readable.

    Builds the table without ``compression``, inserts one plain-JSON row,
    then opens it through ``PersistentDocsCache`` and expects the row to
    be returned, the column to exist, and its value to be "none".
    """
    from contextlib import closing

    index_path = tmp_path / "index.db"
    index_path.write_bytes(b"index-1")
    fingerprint = PersistentDocsCache._fingerprint_index(index_path)
    cache_path = tmp_path / "retrieved.sqlite3"
    expected = _result("legacy docs payload")
    # closing() releases the handle; the inner `conn` context commits the
    # seeded row on success, which sqlite3's context manager alone would do
    # without ever closing the connection.
    with closing(sqlite3.connect(cache_path)) as conn, conn:
        conn.execute(
            "CREATE TABLE retrieved_docs_cache ("
            "index_fingerprint TEXT NOT NULL, version TEXT NOT NULL, slug TEXT NOT NULL, "
            "anchor TEXT NOT NULL, max_chars INTEGER NOT NULL, start_index INTEGER NOT NULL, "
            "result_json TEXT NOT NULL, created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP, "
            "PRIMARY KEY (index_fingerprint, version, slug, anchor, max_chars, start_index))"
        )
        conn.execute(
            "INSERT INTO retrieved_docs_cache "
            "(index_fingerprint, version, slug, anchor, max_chars, start_index, result_json) "
            "VALUES (?, ?, ?, ?, ?, ?, ?)",
            (
                fingerprint,
                expected.version,
                expected.slug,
                _NO_ANCHOR_KEY,
                500,
                0,
                expected.model_dump_json(),
            ),
        )

    migrated = PersistentDocsCache(cache_path, index_path)
    assert (
        migrated.get(
            version="3.12",
            slug="library/json.html",
            anchor=None,
            max_chars=500,
            start_index=0,
        )
        == expected
    )
    # Read-only inspection: only the close matters here.
    with closing(sqlite3.connect(cache_path)) as conn:
        columns = {row[1] for row in conn.execute("PRAGMA table_info(retrieved_docs_cache)")}
        compression = conn.execute("SELECT compression FROM retrieved_docs_cache").fetchone()[0]
    assert "compression" in columns
    assert compression == "none"
139+
140+
71141
def test_cache_key_includes_python_version(populated_db, tmp_path: Path):
72142
_doc(populated_db, "3.12", "docs for 3.12")
73143
_doc(populated_db, "3.13", "docs for 3.13", 1)

0 commit comments

Comments
 (0)